letta-nightly 0.6.9.dev20250120104049__py3-none-any.whl → 0.6.11.dev20250120212046__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of letta-nightly was flagged as possibly problematic by the registry.

Files changed (35)
  1. letta/__init__.py +1 -1
  2. letta/agent.py +40 -23
  3. letta/client/client.py +10 -2
  4. letta/errors.py +14 -0
  5. letta/functions/ast_parsers.py +105 -0
  6. letta/llm_api/anthropic.py +130 -82
  7. letta/llm_api/aws_bedrock.py +134 -0
  8. letta/llm_api/llm_api_tools.py +30 -7
  9. letta/orm/__init__.py +1 -0
  10. letta/orm/job.py +2 -4
  11. letta/orm/message.py +5 -1
  12. letta/orm/step.py +54 -0
  13. letta/schemas/embedding_config.py +1 -0
  14. letta/schemas/letta_message.py +24 -0
  15. letta/schemas/letta_response.py +1 -9
  16. letta/schemas/llm_config.py +1 -0
  17. letta/schemas/message.py +1 -0
  18. letta/schemas/providers.py +60 -3
  19. letta/schemas/step.py +31 -0
  20. letta/server/rest_api/app.py +21 -6
  21. letta/server/rest_api/routers/v1/agents.py +15 -2
  22. letta/server/rest_api/routers/v1/llms.py +2 -2
  23. letta/server/rest_api/routers/v1/runs.py +12 -2
  24. letta/server/server.py +9 -3
  25. letta/services/agent_manager.py +4 -3
  26. letta/services/job_manager.py +11 -13
  27. letta/services/provider_manager.py +19 -7
  28. letta/services/step_manager.py +87 -0
  29. letta/settings.py +21 -1
  30. {letta_nightly-0.6.9.dev20250120104049.dist-info → letta_nightly-0.6.11.dev20250120212046.dist-info}/METADATA +9 -6
  31. {letta_nightly-0.6.9.dev20250120104049.dist-info → letta_nightly-0.6.11.dev20250120212046.dist-info}/RECORD +34 -30
  32. letta/credentials.py +0 -149
  33. {letta_nightly-0.6.9.dev20250120104049.dist-info → letta_nightly-0.6.11.dev20250120212046.dist-info}/LICENSE +0 -0
  34. {letta_nightly-0.6.9.dev20250120104049.dist-info → letta_nightly-0.6.11.dev20250120212046.dist-info}/WHEEL +0 -0
  35. {letta_nightly-0.6.9.dev20250120104049.dist-info → letta_nightly-0.6.11.dev20250120212046.dist-info}/entry_points.txt +0 -0
letta/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.6.9"
+ __version__ = "0.6.11"


  # import clients
letta/agent.py CHANGED
@@ -1,4 +1,3 @@
- import inspect
  import json
  import time
  import traceback
@@ -20,6 +19,7 @@ from letta.constants import (
  REQ_HEARTBEAT_MESSAGE,
  )
  from letta.errors import ContextWindowExceededError
+ from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
  from letta.functions.functions import get_function_from_module
  from letta.helpers import ToolRulesSolver
  from letta.interface import AgentInterface
@@ -49,6 +49,8 @@ from letta.services.helpers.agent_manager_helper import check_supports_structure
  from letta.services.job_manager import JobManager
  from letta.services.message_manager import MessageManager
  from letta.services.passage_manager import PassageManager
+ from letta.services.provider_manager import ProviderManager
+ from letta.services.step_manager import StepManager
  from letta.services.tool_execution_sandbox import ToolExecutionSandbox
  from letta.streaming_interface import StreamingRefreshCLIInterface
  from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
@@ -130,8 +132,10 @@ class Agent(BaseAgent):
  # Create the persistence manager object based on the AgentState info
  self.message_manager = MessageManager()
  self.passage_manager = PassageManager()
+ self.provider_manager = ProviderManager()
  self.agent_manager = AgentManager()
  self.job_manager = JobManager()
+ self.step_manager = StepManager()

  # State needed for heartbeat pausing

@@ -223,15 +227,10 @@ class Agent(BaseAgent):
  function_response = callable_func(**function_args)
  self.update_memory_if_changed(agent_state_copy.memory)
  else:
- # TODO: Get rid of this. This whole piece is pretty shady, that we exec the function to just get the type hints for args.
- env = {}
- env.update(globals())
- exec(target_letta_tool.source_code, env)
- callable_func = env[target_letta_tool.json_schema["name"]]
- spec = inspect.getfullargspec(callable_func).annotations
- for name, arg in function_args.items():
- if isinstance(function_args[name], dict):
- function_args[name] = spec[name](**function_args[name])
+ # Parse the source code to extract function annotations
+ annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
+ # Coerce the function arguments to the correct types based on the annotations
+ function_args = coerce_dict_args_by_annotations(function_args, annotations)

  # execute tool in a sandbox
  # TODO: allow agent_state to specify which sandbox to execute tools in
@@ -355,7 +354,7 @@ class Agent(BaseAgent):
  if response_message.tool_calls is not None and len(response_message.tool_calls) > 1:
  # raise NotImplementedError(f">1 tool call not supported")
  # TODO eventually support sequential tool calling
- printd(f">1 tool call not supported, using index=0 only\n{response_message.tool_calls}")
+ self.logger.warning(f">1 tool call not supported, using index=0 only\n{response_message.tool_calls}")
  response_message.tool_calls = [response_message.tool_calls[0]]
  assert response_message.tool_calls is not None and len(response_message.tool_calls) > 0

@@ -384,7 +383,7 @@ class Agent(BaseAgent):
  openai_message_dict=response_message.model_dump(),
  )
  ) # extend conversation with assistant's reply
- printd(f"Function call message: {messages[-1]}")
+ self.logger.info(f"Function call message: {messages[-1]}")

  nonnull_content = False
  if response_message.content:
@@ -401,7 +400,7 @@ class Agent(BaseAgent):

  # Get the name of the function
  function_name = function_call.name
- printd(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")
+ self.logger.info(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")

  # Failure case 1: function name is wrong (not in agent_state.tools)
  target_letta_tool = None
@@ -467,7 +466,7 @@ class Agent(BaseAgent):
  heartbeat_request = True

  if not isinstance(heartbeat_request, bool) or heartbeat_request is None:
- printd(
+ self.logger.warning(
  f"{CLI_WARNING_PREFIX}'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}"
  )
  heartbeat_request = False
@@ -503,7 +502,7 @@ class Agent(BaseAgent):
  # Less detailed - don't provide full args, idea is that it should be in recent context so no need (just adds noise)
  error_msg = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
  error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
- printd(error_msg_user)
+ self.logger.error(error_msg_user)
  function_response = package_function_response(False, error_msg)
  self.last_function_response = function_response
  # TODO: truncate error message somehow
@@ -630,10 +629,10 @@ class Agent(BaseAgent):

  # Chain stops
  if not chaining:
- printd("No chaining, stopping after one step")
+ self.logger.info("No chaining, stopping after one step")
  break
  elif max_chaining_steps is not None and counter > max_chaining_steps:
- printd(f"Hit max chaining steps, stopping after {counter} steps")
+ self.logger.info(f"Hit max chaining steps, stopping after {counter} steps")
  break
  # Chain handlers
  elif token_warning:
@@ -713,7 +712,7 @@ class Agent(BaseAgent):
  input_message_sequence = in_context_messages + messages

  if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
- printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
+ self.logger.warning(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")

  # Step 2: send the conversation and available functions to the LLM
  response = self._get_ai_reply(
@@ -755,7 +754,7 @@ class Agent(BaseAgent):
  )

  if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window):
- printd(
+ self.logger.warning(
  f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
  )

@@ -765,10 +764,28 @@ class Agent(BaseAgent):
  self.agent_alerted_about_memory_pressure = True # it's up to the outer loop to handle this

  else:
- printd(
+ self.logger.warning(
  f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
  )

+ # Log step - this must happen before messages are persisted
+ step = self.step_manager.log_step(
+ actor=self.user,
+ provider_name=self.agent_state.llm_config.model_endpoint_type,
+ model=self.agent_state.llm_config.model,
+ context_window_limit=self.agent_state.llm_config.context_window,
+ usage=response.usage,
+ # TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
+ provider_id=(
+ self.provider_manager.get_anthropic_override_provider_id()
+ if self.agent_state.llm_config.model_endpoint_type == "anthropic"
+ else None
+ ),
+ job_id=job_id,
+ )
+ for message in all_new_messages:
+ message.step_id = step.id
+
  # Persisting into Messages
  self.agent_state = self.agent_manager.append_to_in_context_messages(
  all_new_messages, agent_id=self.agent_state.id, actor=self.user
@@ -790,11 +807,11 @@ class Agent(BaseAgent):
  )

  except Exception as e:
- printd(f"step() failed\nmessages = {messages}\nerror = {e}")
+ self.logger.error(f"step() failed\nmessages = {messages}\nerror = {e}")

  # If we got a context alert, try trimming the messages length, then try again
  if is_context_overflow_error(e):
- printd(
+ self.logger.warning(
  f"context window exceeded with limit {self.agent_state.llm_config.context_window}, running summarizer to trim messages"
  )
  # A separate API call to run a summarizer
@@ -811,7 +828,7 @@ class Agent(BaseAgent):
  )

  else:
- printd(f"step() failed with an unrecognized exception: '{str(e)}'")
+ self.logger.error(f"step() failed with an unrecognized exception: '{str(e)}'")
  raise e

  def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepResponse:
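The agent.py change above swaps exec()-based type-hint extraction for static AST parsing. Below is a minimal sketch of the new path; the helper names come from the ast_parsers.py diff further down, while the tool source itself is hypothetical:

from letta.functions.ast_parsers import (
    coerce_dict_args_by_annotations,
    get_function_annotations_from_source,
)

# Hypothetical tool source, standing in for target_letta_tool.source_code
tool_source = '''
def move_to(x: float, y: float, options: dict):
    """Move an agent to (x, y)."""
    return {"x": x, "y": y, "options": options}
'''

# Read the annotation strings without executing the tool's source code
annotations = get_function_annotations_from_source(tool_source, "move_to")
# -> {"x": "float", "y": "float", "options": "dict"}

# Arguments parsed from an LLM tool call often arrive as strings; coerce them
function_args = {"x": "1.5", "y": "2", "options": '{"speed": "fast"}'}
function_args = coerce_dict_args_by_annotations(function_args, annotations)
# -> {"x": 1.5, "y": 2.0, "options": {"speed": "fast"}}

Unlike the removed exec() block (whose own TODO called it "shady"), this never runs potentially untrusted tool source just to read its type hints.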
letta/client/client.py CHANGED
@@ -410,7 +410,8 @@ class RESTClient(AbstractClient):
  def __init__(
  self,
  base_url: str,
- token: str,
+ token: Optional[str] = None,
+ password: Optional[str] = None,
  api_prefix: str = "v1",
  debug: bool = False,
  default_llm_config: Optional[LLMConfig] = None,
@@ -426,11 +427,18 @@ class RESTClient(AbstractClient):
  default_llm_config (Optional[LLMConfig]): The default LLM configuration.
  default_embedding_config (Optional[EmbeddingConfig]): The default embedding configuration.
  headers (Optional[Dict]): The additional headers for the REST API.
+ token (Optional[str]): The token for the REST API when using managed letta service.
+ password (Optional[str]): The password for the REST API when using self hosted letta service.
  """
  super().__init__(debug=debug)
  self.base_url = base_url
  self.api_prefix = api_prefix
- self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"}
+ if token:
+ self.headers = {"accept": "application/json", "Authorization": f"Bearer {token}"}
+ elif password:
+ self.headers = {"accept": "application/json", "X-BARE-PASSWORD": f"password {password}"}
+ else:
+ self.headers = {"accept": "application/json"}
  if headers:
  self.headers.update(headers)
  self._default_llm_config = default_llm_config
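With this change a RESTClient can be constructed three ways. A short sketch, where the URLs and credentials are placeholders:

from letta.client.client import RESTClient

# Managed service: bearer token -> "Authorization: Bearer <token>" header
client = RESTClient(base_url="https://api.example-letta-host.com", token="my-api-token")

# Self-hosted server with password auth -> "X-BARE-PASSWORD: password <password>" header
client = RESTClient(base_url="http://localhost:8283", password="my-server-password")

# No credentials: only {"accept": "application/json"} is sent
client = RESTClient(base_url="http://localhost:8283")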
letta/errors.py CHANGED
@@ -62,6 +62,20 @@ class LLMError(LettaError):
  pass


+ class BedrockPermissionError(LettaError):
+ """Exception raised for errors in the Bedrock permission process."""
+
+ def __init__(self, message="User does not have access to the Bedrock model with the specified ID."):
+ super().__init__(message=message)
+
+
+ class BedrockError(LettaError):
+ """Exception raised for errors in the Bedrock process."""
+
+ def __init__(self, message="Error with Bedrock model."):
+ super().__init__(message=message)
+
+
  class LLMJSONParsingError(LettaError):
  """Exception raised for errors in the JSON parsing process."""

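The two new exception classes let callers distinguish Bedrock access problems from other Bedrock failures. A hedged sketch of how calling code might branch on them (anthropic_bedrock_chat_completions_request is defined in the anthropic.py diff below; `request` is assumed to be a prepared ChatCompletionRequest):

from letta.errors import BedrockError, BedrockPermissionError
from letta.llm_api.anthropic import anthropic_bedrock_chat_completions_request

try:
    response = anthropic_bedrock_chat_completions_request(data=request)
except BedrockPermissionError as e:
    # Raised on PermissionDeniedError: the AWS account lacks access to the model
    print(f"Check AWS model access: {e}")
except BedrockError as e:
    # Any other failure in the Bedrock call is wrapped in BedrockError
    print(f"Bedrock request failed: {e}")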
letta/functions/ast_parsers.py ADDED
@@ -0,0 +1,105 @@
+ import ast
+ import json
+ from typing import Dict
+
+ # Registry of known types for annotation resolution
+ BUILTIN_TYPES = {
+ "int": int,
+ "float": float,
+ "str": str,
+ "dict": dict,
+ "list": list,
+ "set": set,
+ "tuple": tuple,
+ "bool": bool,
+ }
+
+
+ def resolve_type(annotation: str):
+ """
+ Resolve a type annotation string into a Python type.
+
+ Args:
+ annotation (str): The annotation string (e.g., 'int', 'list', etc.).
+
+ Returns:
+ type: The corresponding Python type.
+
+ Raises:
+ ValueError: If the annotation is unsupported or invalid.
+ """
+ if annotation in BUILTIN_TYPES:
+ return BUILTIN_TYPES[annotation]
+
+ try:
+ parsed = ast.literal_eval(annotation)
+ if isinstance(parsed, type):
+ return parsed
+ raise ValueError(f"Annotation '{annotation}' is not a recognized type.")
+ except (ValueError, SyntaxError):
+ raise ValueError(f"Unsupported annotation: {annotation}")
+
+
+ def get_function_annotations_from_source(source_code: str, function_name: str) -> Dict[str, str]:
+ """
+ Parse the source code to extract annotations for a given function name.
+
+ Args:
+ source_code (str): The Python source code containing the function.
+ function_name (str): The name of the function to extract annotations for.
+
+ Returns:
+ Dict[str, str]: A dictionary of argument names to their annotation strings.
+
+ Raises:
+ ValueError: If the function is not found in the source code.
+ """
+ tree = ast.parse(source_code)
+ for node in ast.iter_child_nodes(tree):
+ if isinstance(node, ast.FunctionDef) and node.name == function_name:
+ annotations = {}
+ for arg in node.args.args:
+ if arg.annotation is not None:
+ annotation_str = ast.unparse(arg.annotation)
+ annotations[arg.arg] = annotation_str
+ return annotations
+ raise ValueError(f"Function '{function_name}' not found in the provided source code.")
+
+
+ def coerce_dict_args_by_annotations(function_args: dict, annotations: Dict[str, str]) -> dict:
+ """
+ Coerce arguments in a dictionary to their annotated types.
+
+ Args:
+ function_args (dict): The original function arguments.
+ annotations (Dict[str, str]): Argument annotations as strings.
+
+ Returns:
+ dict: The updated dictionary with coerced argument types.
+
+ Raises:
+ ValueError: If type coercion fails for an argument.
+ """
+ coerced_args = dict(function_args) # Shallow copy for mutation safety
+
+ for arg_name, value in coerced_args.items():
+ if arg_name in annotations:
+ annotation_str = annotations[arg_name]
+ try:
+ # Resolve the type from the annotation
+ arg_type = resolve_type(annotation_str)
+
+ # Handle JSON-like inputs for dict and list types
+ if arg_type in {dict, list} and isinstance(value, str):
+ try:
+ # First, try JSON parsing
+ value = json.loads(value)
+ except json.JSONDecodeError:
+ # Fall back to literal_eval for Python-specific literals
+ value = ast.literal_eval(value)
+
+ # Coerce the value to the resolved type
+ coerced_args[arg_name] = arg_type(value)
+ except (TypeError, ValueError, json.JSONDecodeError, SyntaxError) as e:
+ raise ValueError(f"Failed to coerce argument '{arg_name}' to {annotation_str}: {e}")
+ return coerced_args
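A few observations on the helpers' edge behavior, as implied by the source above (sketch only):

from letta.functions.ast_parsers import coerce_dict_args_by_annotations, resolve_type

# Builtin names resolve directly to types
assert resolve_type("int") is int

# Anything outside BUILTIN_TYPES raises rather than guessing,
# e.g. typing generics such as "List[int]"
try:
    resolve_type("List[int]")
except ValueError as e:
    print(e)  # Unsupported annotation: List[int]

# String payloads for container types are parsed first: JSON, then
# ast.literal_eval as the fallback for Python-style literals
args = coerce_dict_args_by_annotations({"tags": "['a', 'b']"}, {"tags": "list"})
assert args == {"tags": ["a", "b"]}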
letta/llm_api/anthropic.py CHANGED
@@ -1,8 +1,12 @@
  import json
  import re
- from typing import List, Optional, Union
+ from typing import List, Optional, Tuple, Union

- from letta.llm_api.helpers import make_post_request
+ import anthropic
+ from anthropic import PermissionDeniedError
+
+ from letta.errors import BedrockError, BedrockPermissionError
+ from letta.llm_api.aws_bedrock import get_bedrock_client
  from letta.schemas.message import Message
  from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
  from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall
@@ -10,6 +14,8 @@ from letta.schemas.openai.chat_completion_response import (
  Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
  )
  from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
+ from letta.services.provider_manager import ProviderManager
+ from letta.settings import model_settings
  from letta.utils import get_utc_time, smart_urljoin

  BASE_URL = "https://api.anthropic.com/v1"
@@ -195,7 +201,7 @@ def strip_xml_tags(string: str, tag: Optional[str]) -> str:


  def convert_anthropic_response_to_chatcompletion(
- response_json: dict, # REST response from Google AI API
+ response: anthropic.types.Message,
  inner_thoughts_xml_tag: Optional[str] = None,
  ) -> ChatCompletionResponse:
  """
@@ -232,65 +238,67 @@ def convert_anthropic_response_to_chatcompletion(
  }
  }
  """
- prompt_tokens = response_json["usage"]["input_tokens"]
- completion_tokens = response_json["usage"]["output_tokens"]
-
- finish_reason = remap_finish_reason(response_json["stop_reason"])
-
- if isinstance(response_json["content"], list):
- if len(response_json["content"]) > 1:
- # inner mono + function call
- assert len(response_json["content"]) == 2, response_json
- assert response_json["content"][0]["type"] == "text", response_json
- assert response_json["content"][1]["type"] == "tool_use", response_json
- content = strip_xml_tags(string=response_json["content"][0]["text"], tag=inner_thoughts_xml_tag)
+ prompt_tokens = response.usage.input_tokens
+ completion_tokens = response.usage.output_tokens
+ finish_reason = remap_finish_reason(response.stop_reason)
+
+ content = None
+ tool_calls = None
+
+ if len(response.content) > 1:
+ # inner mono + function call
+ assert len(response.content) == 2
+ text_block = response.content[0]
+ tool_block = response.content[1]
+ assert text_block.type == "text"
+ assert tool_block.type == "tool_use"
+ content = strip_xml_tags(string=text_block.text, tag=inner_thoughts_xml_tag)
+ tool_calls = [
+ ToolCall(
+ id=tool_block.id,
+ type="function",
+ function=FunctionCall(
+ name=tool_block.name,
+ arguments=json.dumps(tool_block.input, indent=2),
+ ),
+ )
+ ]
+ elif len(response.content) == 1:
+ block = response.content[0]
+ if block.type == "tool_use":
+ # function call only
  tool_calls = [
  ToolCall(
- id=response_json["content"][1]["id"],
+ id=block.id,
  type="function",
  function=FunctionCall(
- name=response_json["content"][1]["name"],
- arguments=json.dumps(response_json["content"][1]["input"], indent=2),
+ name=block.name,
+ arguments=json.dumps(block.input, indent=2),
  ),
  )
  ]
- elif len(response_json["content"]) == 1:
- if response_json["content"][0]["type"] == "tool_use":
- # function call only
- content = None
- tool_calls = [
- ToolCall(
- id=response_json["content"][0]["id"],
- type="function",
- function=FunctionCall(
- name=response_json["content"][0]["name"],
- arguments=json.dumps(response_json["content"][0]["input"], indent=2),
- ),
- )
- ]
- else:
- # inner mono only
- content = strip_xml_tags(string=response_json["content"][0]["text"], tag=inner_thoughts_xml_tag)
- tool_calls = None
+ else:
+ # inner mono only
+ content = strip_xml_tags(string=block.text, tag=inner_thoughts_xml_tag)
  else:
- raise RuntimeError("Unexpected type for content in response_json.")
+ raise RuntimeError("Unexpected empty content in response")

- assert response_json["role"] == "assistant", response_json
+ assert response.role == "assistant"
  choice = Choice(
  index=0,
  finish_reason=finish_reason,
  message=ChoiceMessage(
- role=response_json["role"],
+ role=response.role,
  content=content,
  tool_calls=tool_calls,
  ),
  )

  return ChatCompletionResponse(
- id=response_json["id"],
+ id=response.id,
  choices=[choice],
  created=get_utc_time(),
- model=response_json["model"],
+ model=response.model,
  usage=UsageStatistics(
  prompt_tokens=prompt_tokens,
  completion_tokens=completion_tokens,
@@ -299,23 +307,11 @@ def convert_anthropic_response_to_chatcompletion(
  )


- def anthropic_chat_completions_request(
- url: str,
- api_key: str,
+ def _prepare_anthropic_request(
  data: ChatCompletionRequest,
  inner_thoughts_xml_tag: Optional[str] = "thinking",
- ) -> ChatCompletionResponse:
- """https://docs.anthropic.com/claude/docs/tool-use"""
-
- url = smart_urljoin(url, "messages")
- headers = {
- "Content-Type": "application/json",
- "x-api-key": api_key,
- # NOTE: beta headers for tool calling
- "anthropic-version": "2023-06-01",
- "anthropic-beta": "tools-2024-04-04",
- }
-
+ ) -> dict:
+ """Prepare the request data for Anthropic API format."""
  # convert the tools
  anthropic_tools = None if data.tools is None else convert_tools_to_anthropic_format(data.tools)

@@ -325,57 +321,109 @@ def anthropic_chat_completions_request(
  if "functions" in data:
  raise ValueError(f"'functions' unexpected in Anthropic API payload")

- # If tools == None, strip from the payload
+ # Handle tools
  if "tools" in data and data["tools"] is None:
  data.pop("tools")
- data.pop("tool_choice", None) # extra safe, should exist always (default="auto")
- # Remap to our converted tools
- if anthropic_tools is not None:
+ data.pop("tool_choice", None)
+ elif anthropic_tools is not None:
  data["tools"] = anthropic_tools
-
- # TODO: Add support for other tool_choice options like "auto", "any"
  if len(anthropic_tools) == 1:
  data["tool_choice"] = {
- "type": "tool", # Changed from "function" to "tool"
- "name": anthropic_tools[0]["name"], # Directly specify name without nested "function" object
- "disable_parallel_tool_use": True, # Force single tool use
+ "type": "tool",
+ "name": anthropic_tools[0]["name"],
+ "disable_parallel_tool_use": True,
  }

  # Move 'system' to the top level
- # 'messages: Unexpected role "system". The Messages API accepts a top-level `system` parameter, not "system" as an input message role.'
  assert data["messages"][0]["role"] == "system", f"Expected 'system' role in messages[0]:\n{data['messages'][0]}"
  data["system"] = data["messages"][0]["content"]
  data["messages"] = data["messages"][1:]

- # set `content` to None if missing
+ # Process messages
  for message in data["messages"]:
  if "content" not in message:
  message["content"] = None

  # Convert to Anthropic format
-
  msg_objs = [Message.dict_to_message(user_id=None, agent_id=None, openai_message_dict=m) for m in data["messages"]]
  data["messages"] = [m.to_anthropic_dict(inner_thoughts_xml_tag=inner_thoughts_xml_tag) for m in msg_objs]

- # Handling Anthropic special requirement for 'user' message in front
- # messages: first message must use the "user" role'
+ # Ensure first message is user
  if data["messages"][0]["role"] != "user":
  data["messages"] = [{"role": "user", "content": DUMMY_FIRST_USER_MESSAGE}] + data["messages"]

- # Handle Anthropic's restriction on alternating user/assistant messages
+ # Handle alternating messages
  data["messages"] = merge_tool_results_into_user_messages(data["messages"])

- # Anthropic also wants max_tokens in the input
- # It's also part of ChatCompletions
+ # Validate max_tokens
  assert "max_tokens" in data, data

- # Remove extra fields used by OpenAI but not Anthropic
- data.pop("frequency_penalty", None)
- data.pop("logprobs", None)
- data.pop("n", None)
- data.pop("top_p", None)
- data.pop("presence_penalty", None)
- data.pop("user", None)
+ # Remove OpenAI-specific fields
+ for field in ["frequency_penalty", "logprobs", "n", "top_p", "presence_penalty", "user"]:
+ data.pop(field, None)
+
+ return data
+
+
+ def get_anthropic_endpoint_and_headers(
+ base_url: str,
+ api_key: str,
+ version: str = "2023-06-01",
+ beta: Optional[str] = "tools-2024-04-04",
+ ) -> Tuple[str, dict]:
+ """
+ Dynamically generate the Anthropic endpoint and headers.
+ """
+ url = smart_urljoin(base_url, "messages")
+
+ headers = {
+ "Content-Type": "application/json",
+ "x-api-key": api_key,
+ "anthropic-version": version,
+ }
+
+ # Add beta header if specified
+ if beta:
+ headers["anthropic-beta"] = beta
+
+ return url, headers
+
+
+ def anthropic_chat_completions_request(
+ data: ChatCompletionRequest,
+ inner_thoughts_xml_tag: Optional[str] = "thinking",
+ betas: List[str] = ["tools-2024-04-04"],
+ ) -> ChatCompletionResponse:
+ """https://docs.anthropic.com/claude/docs/tool-use"""
+ anthropic_client = None
+ anthropic_override_key = ProviderManager().get_anthropic_override_key()
+ if anthropic_override_key:
+ anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
+ elif model_settings.anthropic_api_key:
+ anthropic_client = anthropic.Anthropic()
+ data = _prepare_anthropic_request(data, inner_thoughts_xml_tag)
+ response = anthropic_client.beta.messages.create(
+ **data,
+ betas=betas,
+ )
+ return convert_anthropic_response_to_chatcompletion(response=response, inner_thoughts_xml_tag=inner_thoughts_xml_tag)
+

- response_json = make_post_request(url, headers, data)
- return convert_anthropic_response_to_chatcompletion(response_json=response_json, inner_thoughts_xml_tag=inner_thoughts_xml_tag)
+ def anthropic_bedrock_chat_completions_request(
+ data: ChatCompletionRequest,
+ inner_thoughts_xml_tag: Optional[str] = "thinking",
+ ) -> ChatCompletionResponse:
+ """Make a chat completion request to Anthropic via AWS Bedrock."""
+ data = _prepare_anthropic_request(data, inner_thoughts_xml_tag)
+
+ # Get the client
+ client = get_bedrock_client()
+
+ # Make the request
+ try:
+ response = client.messages.create(**data)
+ return convert_anthropic_response_to_chatcompletion(response=response, inner_thoughts_xml_tag=inner_thoughts_xml_tag)
+ except PermissionDeniedError:
+ raise BedrockPermissionError(f"User does not have access to the Bedrock model with the specified ID. {data['model']}")
+ except Exception as e:
+ raise BedrockError(f"Bedrock error: {e}")
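After this refactor, callers no longer pass url/api_key; the key is resolved inside anthropic_chat_completions_request from the ProviderManager override or model_settings.anthropic_api_key. A rough usage sketch (the model id is illustrative, and ChatCompletionRequest is assumed to accept OpenAI-style message dicts):

from letta.llm_api.anthropic import anthropic_chat_completions_request
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest

request = ChatCompletionRequest(
    model="claude-3-5-sonnet-20241022",  # illustrative model id
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    max_tokens=1024,  # _prepare_anthropic_request asserts this is present
)

response = anthropic_chat_completions_request(data=request)
print(response.choices[0].message.content)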