letta-nightly 0.6.9.dev20250119103943__py3-none-any.whl → 0.6.10.dev20250120193553__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +1 -1
- letta/agent.py +40 -23
- letta/client/client.py +10 -2
- letta/errors.py +14 -0
- letta/functions/ast_parsers.py +105 -0
- letta/llm_api/anthropic.py +130 -82
- letta/llm_api/aws_bedrock.py +134 -0
- letta/llm_api/llm_api_tools.py +30 -7
- letta/orm/__init__.py +1 -0
- letta/orm/job.py +2 -4
- letta/orm/message.py +5 -1
- letta/orm/step.py +54 -0
- letta/schemas/embedding_config.py +1 -0
- letta/schemas/letta_message.py +24 -0
- letta/schemas/letta_response.py +1 -9
- letta/schemas/llm_config.py +1 -0
- letta/schemas/message.py +1 -0
- letta/schemas/providers.py +60 -3
- letta/schemas/step.py +31 -0
- letta/server/rest_api/app.py +21 -6
- letta/server/rest_api/routers/v1/agents.py +15 -2
- letta/server/rest_api/routers/v1/llms.py +2 -2
- letta/server/rest_api/routers/v1/runs.py +12 -2
- letta/server/server.py +9 -3
- letta/services/agent_manager.py +4 -3
- letta/services/job_manager.py +11 -13
- letta/services/provider_manager.py +19 -7
- letta/services/step_manager.py +87 -0
- letta/settings.py +21 -1
- {letta_nightly-0.6.9.dev20250119103943.dist-info → letta_nightly-0.6.10.dev20250120193553.dist-info}/METADATA +8 -6
- {letta_nightly-0.6.9.dev20250119103943.dist-info → letta_nightly-0.6.10.dev20250120193553.dist-info}/RECORD +34 -30
- letta/credentials.py +0 -149
- {letta_nightly-0.6.9.dev20250119103943.dist-info → letta_nightly-0.6.10.dev20250120193553.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.9.dev20250119103943.dist-info → letta_nightly-0.6.10.dev20250120193553.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.9.dev20250119103943.dist-info → letta_nightly-0.6.10.dev20250120193553.dist-info}/entry_points.txt +0 -0
letta/__init__.py
CHANGED
letta/agent.py
CHANGED
```diff
@@ -1,4 +1,3 @@
-import inspect
 import json
 import time
 import traceback
@@ -20,6 +19,7 @@ from letta.constants import (
     REQ_HEARTBEAT_MESSAGE,
 )
 from letta.errors import ContextWindowExceededError
+from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
 from letta.functions.functions import get_function_from_module
 from letta.helpers import ToolRulesSolver
 from letta.interface import AgentInterface
@@ -49,6 +49,8 @@ from letta.services.helpers.agent_manager_helper import check_supports_structure
 from letta.services.job_manager import JobManager
 from letta.services.message_manager import MessageManager
 from letta.services.passage_manager import PassageManager
+from letta.services.provider_manager import ProviderManager
+from letta.services.step_manager import StepManager
 from letta.services.tool_execution_sandbox import ToolExecutionSandbox
 from letta.streaming_interface import StreamingRefreshCLIInterface
 from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
@@ -130,8 +132,10 @@ class Agent(BaseAgent):
         # Create the persistence manager object based on the AgentState info
         self.message_manager = MessageManager()
         self.passage_manager = PassageManager()
+        self.provider_manager = ProviderManager()
         self.agent_manager = AgentManager()
         self.job_manager = JobManager()
+        self.step_manager = StepManager()
 
         # State needed for heartbeat pausing
 
@@ -223,15 +227,10 @@ class Agent(BaseAgent):
                     function_response = callable_func(**function_args)
                     self.update_memory_if_changed(agent_state_copy.memory)
                 else:
-                    #
-
-
-
-                    callable_func = env[target_letta_tool.json_schema["name"]]
-                    spec = inspect.getfullargspec(callable_func).annotations
-                    for name, arg in function_args.items():
-                        if isinstance(function_args[name], dict):
-                            function_args[name] = spec[name](**function_args[name])
+                    # Parse the source code to extract function annotations
+                    annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
+                    # Coerce the function arguments to the correct types based on the annotations
+                    function_args = coerce_dict_args_by_annotations(function_args, annotations)
 
                     # execute tool in a sandbox
                     # TODO: allow agent_state to specify which sandbox to execute tools in
@@ -355,7 +354,7 @@
         if response_message.tool_calls is not None and len(response_message.tool_calls) > 1:
             # raise NotImplementedError(f">1 tool call not supported")
             # TODO eventually support sequential tool calling
-
+            self.logger.warning(f">1 tool call not supported, using index=0 only\n{response_message.tool_calls}")
             response_message.tool_calls = [response_message.tool_calls[0]]
         assert response_message.tool_calls is not None and len(response_message.tool_calls) > 0
 
@@ -384,7 +383,7 @@
                     openai_message_dict=response_message.model_dump(),
                 )
             )  # extend conversation with assistant's reply
-
+            self.logger.info(f"Function call message: {messages[-1]}")
 
             nonnull_content = False
             if response_message.content:
@@ -401,7 +400,7 @@
 
             # Get the name of the function
             function_name = function_call.name
-
+            self.logger.info(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")
 
             # Failure case 1: function name is wrong (not in agent_state.tools)
             target_letta_tool = None
@@ -467,7 +466,7 @@
                     heartbeat_request = True
 
                 if not isinstance(heartbeat_request, bool) or heartbeat_request is None:
-
+                    self.logger.warning(
                         f"{CLI_WARNING_PREFIX}'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}"
                     )
                     heartbeat_request = False
@@ -503,7 +502,7 @@
                 # Less detailed - don't provide full args, idea is that it should be in recent context so no need (just adds noise)
                 error_msg = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
                 error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
-
+                self.logger.error(error_msg_user)
                 function_response = package_function_response(False, error_msg)
                 self.last_function_response = function_response
                 # TODO: truncate error message somehow
@@ -630,10 +629,10 @@
 
             # Chain stops
             if not chaining:
-
+                self.logger.info("No chaining, stopping after one step")
                 break
             elif max_chaining_steps is not None and counter > max_chaining_steps:
-
+                self.logger.info(f"Hit max chaining steps, stopping after {counter} steps")
                 break
             # Chain handlers
            elif token_warning:
@@ -713,7 +712,7 @@
             input_message_sequence = in_context_messages + messages
 
             if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
-
+                self.logger.warning(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
 
             # Step 2: send the conversation and available functions to the LLM
             response = self._get_ai_reply(
@@ -755,7 +754,7 @@
             )
 
             if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window):
-
+                self.logger.warning(
                     f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
                 )
@@ -765,10 +764,28 @@
                 self.agent_alerted_about_memory_pressure = True  # it's up to the outer loop to handle this
 
             else:
-
+                self.logger.warning(
                     f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
                 )
 
+            # Log step - this must happen before messages are persisted
+            step = self.step_manager.log_step(
+                actor=self.user,
+                provider_name=self.agent_state.llm_config.model_endpoint_type,
+                model=self.agent_state.llm_config.model,
+                context_window_limit=self.agent_state.llm_config.context_window,
+                usage=response.usage,
+                # TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
+                provider_id=(
+                    self.provider_manager.get_anthropic_override_provider_id()
+                    if self.agent_state.llm_config.model_endpoint_type == "anthropic"
+                    else None
+                ),
+                job_id=job_id,
+            )
+            for message in all_new_messages:
+                message.step_id = step.id
+
             # Persisting into Messages
             self.agent_state = self.agent_manager.append_to_in_context_messages(
                 all_new_messages, agent_id=self.agent_state.id, actor=self.user
@@ -790,11 +807,11 @@
             )
 
         except Exception as e:
-
+            self.logger.error(f"step() failed\nmessages = {messages}\nerror = {e}")
 
             # If we got a context alert, try trimming the messages length, then try again
             if is_context_overflow_error(e):
-
+                self.logger.warning(
                     f"context window exceeded with limit {self.agent_state.llm_config.context_window}, running summarizer to trim messages"
                 )
                 # A separate API call to run a summarizer
@@ -811,7 +828,7 @@
                 )
 
             else:
-
+                self.logger.error(f"step() failed with an unrecognized exception: '{str(e)}'")
                 raise e
 
     def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepResponse:
```
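The most substantive change here is the tool-argument handling: instead of `inspect`-ing a live callable and only reconstructing dict-typed arguments, the agent now parses the tool's source with the new `ast_parsers` helpers and coerces every annotated argument before sandbox execution. A minimal sketch of that path; the `move_to` tool and its arguments are invented for illustration:

```python
from letta.functions.ast_parsers import (
    coerce_dict_args_by_annotations,
    get_function_annotations_from_source,
)

# Hypothetical tool source, standing in for target_letta_tool.source_code
source_code = '''
def move_to(x: int, y: int, options: dict):
    """Move the agent to a grid position."""
    return (x, y, options)
'''

# Arguments as an LLM tool call might deliver them: everything stringly typed
function_args = {"x": "3", "y": "7", "options": '{"speed": "fast"}'}

annotations = get_function_annotations_from_source(source_code, "move_to")
# {'x': 'int', 'y': 'int', 'options': 'dict'}

function_args = coerce_dict_args_by_annotations(function_args, annotations)
# {'x': 3, 'y': 7, 'options': {'speed': 'fast'}}
```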
letta/client/client.py
CHANGED
```diff
@@ -410,7 +410,8 @@ class RESTClient(AbstractClient):
     def __init__(
         self,
         base_url: str,
-        token: str,
+        token: Optional[str] = None,
+        password: Optional[str] = None,
         api_prefix: str = "v1",
         debug: bool = False,
         default_llm_config: Optional[LLMConfig] = None,
@@ -426,11 +427,18 @@ class RESTClient(AbstractClient):
             default_llm_config (Optional[LLMConfig]): The default LLM configuration.
             default_embedding_config (Optional[EmbeddingConfig]): The default embedding configuration.
             headers (Optional[Dict]): The additional headers for the REST API.
+            token (Optional[str]): The token for the REST API when using managed letta service.
+            password (Optional[str]): The password for the REST API when using self hosted letta service.
         """
         super().__init__(debug=debug)
         self.base_url = base_url
         self.api_prefix = api_prefix
-
+        if token:
+            self.headers = {"accept": "application/json", "Authorization": f"Bearer {token}"}
+        elif password:
+            self.headers = {"accept": "application/json", "X-BARE-PASSWORD": f"password {password}"}
+        else:
+            self.headers = {"accept": "application/json"}
         if headers:
             self.headers.update(headers)
         self._default_llm_config = default_llm_config
```
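`RESTClient` now accepts either a bearer token (managed Letta service) or a password (self-hosted server with password protection), and builds the matching auth header itself. A quick sketch of both modes; the URLs and credentials below are placeholders:

```python
from letta.client.client import RESTClient

# Managed service: token becomes "Authorization: Bearer <token>"
cloud_client = RESTClient(base_url="https://api.letta.com", token="sk-placeholder")

# Self-hosted server started with a password:
# password becomes "X-BARE-PASSWORD: password <password>"
local_client = RESTClient(base_url="http://localhost:8283", password="my-server-password")

# With neither, only the "accept: application/json" header is sent
anon_client = RESTClient(base_url="http://localhost:8283")
```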
letta/errors.py
CHANGED
```diff
@@ -62,6 +62,20 @@ class LLMError(LettaError):
     pass
 
 
+class BedrockPermissionError(LettaError):
+    """Exception raised for errors in the Bedrock permission process."""
+
+    def __init__(self, message="User does not have access to the Bedrock model with the specified ID."):
+        super().__init__(message=message)
+
+
+class BedrockError(LettaError):
+    """Exception raised for errors in the Bedrock process."""
+
+    def __init__(self, message="Error with Bedrock model."):
+        super().__init__(message=message)
+
+
 class LLMJSONParsingError(LettaError):
     """Exception raised for errors in the JSON parsing process."""
 
```
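Both exceptions are raised by the new Bedrock request path in letta/llm_api/anthropic.py (see that file's diff below). A caller that wants to distinguish a missing model grant from other Bedrock failures might write something like this sketch, where `request` stands in for a prepared `ChatCompletionRequest`:

```python
from letta.errors import BedrockError, BedrockPermissionError
from letta.llm_api.anthropic import anthropic_bedrock_chat_completions_request

try:
    response = anthropic_bedrock_chat_completions_request(data=request)
except BedrockPermissionError:
    # The AWS account lacks access to the requested Bedrock model ID
    print("Request model access in the AWS Bedrock console first.")
except BedrockError as e:
    # Any other Bedrock-side failure is wrapped in BedrockError
    print(f"Bedrock call failed: {e}")
```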
letta/functions/ast_parsers.py
ADDED
```diff
@@ -0,0 +1,105 @@
+import ast
+import json
+from typing import Dict
+
+# Registry of known types for annotation resolution
+BUILTIN_TYPES = {
+    "int": int,
+    "float": float,
+    "str": str,
+    "dict": dict,
+    "list": list,
+    "set": set,
+    "tuple": tuple,
+    "bool": bool,
+}
+
+
+def resolve_type(annotation: str):
+    """
+    Resolve a type annotation string into a Python type.
+
+    Args:
+        annotation (str): The annotation string (e.g., 'int', 'list', etc.).
+
+    Returns:
+        type: The corresponding Python type.
+
+    Raises:
+        ValueError: If the annotation is unsupported or invalid.
+    """
+    if annotation in BUILTIN_TYPES:
+        return BUILTIN_TYPES[annotation]
+
+    try:
+        parsed = ast.literal_eval(annotation)
+        if isinstance(parsed, type):
+            return parsed
+        raise ValueError(f"Annotation '{annotation}' is not a recognized type.")
+    except (ValueError, SyntaxError):
+        raise ValueError(f"Unsupported annotation: {annotation}")
+
+
+def get_function_annotations_from_source(source_code: str, function_name: str) -> Dict[str, str]:
+    """
+    Parse the source code to extract annotations for a given function name.
+
+    Args:
+        source_code (str): The Python source code containing the function.
+        function_name (str): The name of the function to extract annotations for.
+
+    Returns:
+        Dict[str, str]: A dictionary of argument names to their annotation strings.
+
+    Raises:
+        ValueError: If the function is not found in the source code.
+    """
+    tree = ast.parse(source_code)
+    for node in ast.iter_child_nodes(tree):
+        if isinstance(node, ast.FunctionDef) and node.name == function_name:
+            annotations = {}
+            for arg in node.args.args:
+                if arg.annotation is not None:
+                    annotation_str = ast.unparse(arg.annotation)
+                    annotations[arg.arg] = annotation_str
+            return annotations
+    raise ValueError(f"Function '{function_name}' not found in the provided source code.")
+
+
+def coerce_dict_args_by_annotations(function_args: dict, annotations: Dict[str, str]) -> dict:
+    """
+    Coerce arguments in a dictionary to their annotated types.
+
+    Args:
+        function_args (dict): The original function arguments.
+        annotations (Dict[str, str]): Argument annotations as strings.
+
+    Returns:
+        dict: The updated dictionary with coerced argument types.
+
+    Raises:
+        ValueError: If type coercion fails for an argument.
+    """
+    coerced_args = dict(function_args)  # Shallow copy for mutation safety
+
+    for arg_name, value in coerced_args.items():
+        if arg_name in annotations:
+            annotation_str = annotations[arg_name]
+            try:
+                # Resolve the type from the annotation
+                arg_type = resolve_type(annotation_str)
+
+                # Handle JSON-like inputs for dict and list types
+                if arg_type in {dict, list} and isinstance(value, str):
+                    try:
+                        # First, try JSON parsing
+                        value = json.loads(value)
+                    except json.JSONDecodeError:
+                        # Fall back to literal_eval for Python-specific literals
+                        value = ast.literal_eval(value)
+
+                # Coerce the value to the resolved type
+                coerced_args[arg_name] = arg_type(value)
+            except (TypeError, ValueError, json.JSONDecodeError, SyntaxError) as e:
+                raise ValueError(f"Failed to coerce argument '{arg_name}' to {annotation_str}: {e}")
+    return coerced_args
```
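Worth noting in `coerce_dict_args_by_annotations` is the two-stage string handling: `json.loads` first, then `ast.literal_eval` as a fallback for Python-style literals that are not valid JSON (single quotes, tuples, and so on). A small sketch of both paths using the new module:

```python
from letta.functions.ast_parsers import coerce_dict_args_by_annotations

annotations = {"payload": "dict", "ids": "list"}

# Valid JSON strings take the json.loads path
args = coerce_dict_args_by_annotations({"payload": '{"a": 1}', "ids": "[1, 2]"}, annotations)
assert args == {"payload": {"a": 1}, "ids": [1, 2]}

# Python-literal strings (single quotes) fall back to ast.literal_eval
args = coerce_dict_args_by_annotations({"payload": "{'a': 1}", "ids": "['x']"}, annotations)
assert args == {"payload": {"a": 1}, "ids": ["x"]}
```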
letta/llm_api/anthropic.py
CHANGED
```diff
@@ -1,8 +1,12 @@
 import json
 import re
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
-
+import anthropic
+from anthropic import PermissionDeniedError
+
+from letta.errors import BedrockError, BedrockPermissionError
+from letta.llm_api.aws_bedrock import get_bedrock_client
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall
@@ -10,6 +14,8 @@ from letta.schemas.openai.chat_completion_response import (
     Message as ChoiceMessage,  # NOTE: avoid conflict with our own Letta Message datatype
 )
 from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
+from letta.services.provider_manager import ProviderManager
+from letta.settings import model_settings
 from letta.utils import get_utc_time, smart_urljoin
 
 BASE_URL = "https://api.anthropic.com/v1"
@@ -195,7 +201,7 @@ def strip_xml_tags(string: str, tag: Optional[str]) -> str:
 
 
 def convert_anthropic_response_to_chatcompletion(
-
+    response: anthropic.types.Message,
     inner_thoughts_xml_tag: Optional[str] = None,
 ) -> ChatCompletionResponse:
     """
@@ -232,65 +238,67 @@ def convert_anthropic_response_to_chatcompletion(
         }
     }
     """
-    prompt_tokens =
-    completion_tokens =
-
-
-
-
-
-
-
-
-
+    prompt_tokens = response.usage.input_tokens
+    completion_tokens = response.usage.output_tokens
+    finish_reason = remap_finish_reason(response.stop_reason)
+
+    content = None
+    tool_calls = None
+
+    if len(response.content) > 1:
+        # inner mono + function call
+        assert len(response.content) == 2
+        text_block = response.content[0]
+        tool_block = response.content[1]
+        assert text_block.type == "text"
+        assert tool_block.type == "tool_use"
+        content = strip_xml_tags(string=text_block.text, tag=inner_thoughts_xml_tag)
+        tool_calls = [
+            ToolCall(
+                id=tool_block.id,
+                type="function",
+                function=FunctionCall(
+                    name=tool_block.name,
+                    arguments=json.dumps(tool_block.input, indent=2),
+                ),
+            )
+        ]
+    elif len(response.content) == 1:
+        block = response.content[0]
+        if block.type == "tool_use":
+            # function call only
             tool_calls = [
                 ToolCall(
-                    id=
+                    id=block.id,
                     type="function",
                     function=FunctionCall(
-                        name=
-                        arguments=json.dumps(
+                        name=block.name,
+                        arguments=json.dumps(block.input, indent=2),
                     ),
                 )
             ]
-
-
-
-        content = None
-        tool_calls = [
-            ToolCall(
-                id=response_json["content"][0]["id"],
-                type="function",
-                function=FunctionCall(
-                    name=response_json["content"][0]["name"],
-                    arguments=json.dumps(response_json["content"][0]["input"], indent=2),
-                ),
-            )
-        ]
-    else:
-        # inner mono only
-        content = strip_xml_tags(string=response_json["content"][0]["text"], tag=inner_thoughts_xml_tag)
-        tool_calls = None
+        else:
+            # inner mono only
+            content = strip_xml_tags(string=block.text, tag=inner_thoughts_xml_tag)
     else:
-        raise RuntimeError("Unexpected
+        raise RuntimeError("Unexpected empty content in response")
 
-    assert
+    assert response.role == "assistant"
     choice = Choice(
         index=0,
         finish_reason=finish_reason,
         message=ChoiceMessage(
-            role=
+            role=response.role,
             content=content,
             tool_calls=tool_calls,
         ),
     )
 
     return ChatCompletionResponse(
-        id=
+        id=response.id,
         choices=[choice],
         created=get_utc_time(),
-        model=
+        model=response.model,
         usage=UsageStatistics(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
@@ -299,23 +307,11 @@ def convert_anthropic_response_to_chatcompletion(
     )
 
 
-def anthropic_chat_completions_request(
-    url: str,
-    api_key: str,
+def _prepare_anthropic_request(
     data: ChatCompletionRequest,
     inner_thoughts_xml_tag: Optional[str] = "thinking",
-) ->
-    """
-
-    url = smart_urljoin(url, "messages")
-    headers = {
-        "Content-Type": "application/json",
-        "x-api-key": api_key,
-        # NOTE: beta headers for tool calling
-        "anthropic-version": "2023-06-01",
-        "anthropic-beta": "tools-2024-04-04",
-    }
-
+) -> dict:
+    """Prepare the request data for Anthropic API format."""
     # convert the tools
     anthropic_tools = None if data.tools is None else convert_tools_to_anthropic_format(data.tools)
 
@@ -325,57 +321,109 @@ def anthropic_chat_completions_request(
     if "functions" in data:
         raise ValueError(f"'functions' unexpected in Anthropic API payload")
 
-    #
+    # Handle tools
     if "tools" in data and data["tools"] is None:
         data.pop("tools")
-        data.pop("tool_choice", None)
-
-    if anthropic_tools is not None:
+        data.pop("tool_choice", None)
+    elif anthropic_tools is not None:
         data["tools"] = anthropic_tools
-
-        # TODO: Add support for other tool_choice options like "auto", "any"
         if len(anthropic_tools) == 1:
             data["tool_choice"] = {
-            "type": "tool",
-            "name": anthropic_tools[0]["name"],
-            "disable_parallel_tool_use": True,
+                "type": "tool",
+                "name": anthropic_tools[0]["name"],
+                "disable_parallel_tool_use": True,
             }
 
     # Move 'system' to the top level
-    # 'messages: Unexpected role "system". The Messages API accepts a top-level `system` parameter, not "system" as an input message role.'
     assert data["messages"][0]["role"] == "system", f"Expected 'system' role in messages[0]:\n{data['messages'][0]}"
     data["system"] = data["messages"][0]["content"]
     data["messages"] = data["messages"][1:]
 
-    #
+    # Process messages
     for message in data["messages"]:
         if "content" not in message:
             message["content"] = None
 
     # Convert to Anthropic format
-
     msg_objs = [Message.dict_to_message(user_id=None, agent_id=None, openai_message_dict=m) for m in data["messages"]]
     data["messages"] = [m.to_anthropic_dict(inner_thoughts_xml_tag=inner_thoughts_xml_tag) for m in msg_objs]
 
-    #
-    # messages: first message must use the "user" role'
+    # Ensure first message is user
     if data["messages"][0]["role"] != "user":
         data["messages"] = [{"role": "user", "content": DUMMY_FIRST_USER_MESSAGE}] + data["messages"]
 
-    # Handle
+    # Handle alternating messages
     data["messages"] = merge_tool_results_into_user_messages(data["messages"])
 
-    #
-    # It's also part of ChatCompletions
+    # Validate max_tokens
     assert "max_tokens" in data, data
 
-    # Remove
-
-
-
-    data
-
-
+    # Remove OpenAI-specific fields
+    for field in ["frequency_penalty", "logprobs", "n", "top_p", "presence_penalty", "user"]:
+        data.pop(field, None)
+
+    return data
+
+
+def get_anthropic_endpoint_and_headers(
+    base_url: str,
+    api_key: str,
+    version: str = "2023-06-01",
+    beta: Optional[str] = "tools-2024-04-04",
+) -> Tuple[str, dict]:
+    """
+    Dynamically generate the Anthropic endpoint and headers.
+    """
+    url = smart_urljoin(base_url, "messages")
+
+    headers = {
+        "Content-Type": "application/json",
+        "x-api-key": api_key,
+        "anthropic-version": version,
+    }
+
+    # Add beta header if specified
+    if beta:
+        headers["anthropic-beta"] = beta
+
+    return url, headers
+
+
+def anthropic_chat_completions_request(
+    data: ChatCompletionRequest,
+    inner_thoughts_xml_tag: Optional[str] = "thinking",
+    betas: List[str] = ["tools-2024-04-04"],
+) -> ChatCompletionResponse:
+    """https://docs.anthropic.com/claude/docs/tool-use"""
+    anthropic_client = None
+    anthropic_override_key = ProviderManager().get_anthropic_override_key()
+    if anthropic_override_key:
+        anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
+    elif model_settings.anthropic_api_key:
+        anthropic_client = anthropic.Anthropic()
+    data = _prepare_anthropic_request(data, inner_thoughts_xml_tag)
+    response = anthropic_client.beta.messages.create(
+        **data,
+        betas=betas,
+    )
+    return convert_anthropic_response_to_chatcompletion(response=response, inner_thoughts_xml_tag=inner_thoughts_xml_tag)
 
-
-
+def anthropic_bedrock_chat_completions_request(
+    data: ChatCompletionRequest,
+    inner_thoughts_xml_tag: Optional[str] = "thinking",
+) -> ChatCompletionResponse:
+    """Make a chat completion request to Anthropic via AWS Bedrock."""
+    data = _prepare_anthropic_request(data, inner_thoughts_xml_tag)
+
+    # Get the client
+    client = get_bedrock_client()
+
+    # Make the request
+    try:
+        response = client.messages.create(**data)
+        return convert_anthropic_response_to_chatcompletion(response=response, inner_thoughts_xml_tag=inner_thoughts_xml_tag)
+    except PermissionDeniedError:
+        raise BedrockPermissionError(f"User does not have access to the Bedrock model with the specified ID. {data['model']}")
+    except Exception as e:
+        raise BedrockError(f"Bedrock error: {e}")
```
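`get_anthropic_endpoint_and_headers` is a pure function, so its output is easy to sanity-check. A small sketch, assuming `smart_urljoin` joins the base URL and path segment with a slash (the API key below is a placeholder):

```python
from letta.llm_api.anthropic import BASE_URL, get_anthropic_endpoint_and_headers

url, headers = get_anthropic_endpoint_and_headers(
    base_url=BASE_URL,             # "https://api.anthropic.com/v1"
    api_key="sk-ant-placeholder",  # placeholder key
)
# Expected under the smart_urljoin assumption above:
assert url == "https://api.anthropic.com/v1/messages"
assert headers["anthropic-version"] == "2023-06-01"
assert headers["anthropic-beta"] == "tools-2024-04-04"

# Passing beta=None omits the beta header entirely
_, headers = get_anthropic_endpoint_and_headers(BASE_URL, "sk-ant-placeholder", beta=None)
assert "anthropic-beta" not in headers
```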