inspect-ai 0.3.88__py3-none-any.whl → 0.3.89__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. inspect_ai/_cli/eval.py +16 -0
  2. inspect_ai/_cli/score.py +1 -12
  3. inspect_ai/_cli/util.py +4 -2
  4. inspect_ai/_display/core/footer.py +2 -2
  5. inspect_ai/_display/plain/display.py +2 -2
  6. inspect_ai/_eval/context.py +7 -1
  7. inspect_ai/_eval/eval.py +51 -27
  8. inspect_ai/_eval/evalset.py +27 -10
  9. inspect_ai/_eval/loader.py +7 -8
  10. inspect_ai/_eval/run.py +23 -31
  11. inspect_ai/_eval/score.py +18 -1
  12. inspect_ai/_eval/task/log.py +5 -13
  13. inspect_ai/_eval/task/resolved.py +1 -0
  14. inspect_ai/_eval/task/run.py +231 -244
  15. inspect_ai/_eval/task/task.py +25 -2
  16. inspect_ai/_eval/task/util.py +1 -8
  17. inspect_ai/_util/constants.py +1 -0
  18. inspect_ai/_util/json.py +8 -3
  19. inspect_ai/_util/registry.py +30 -13
  20. inspect_ai/_view/www/App.css +5 -0
  21. inspect_ai/_view/www/dist/assets/index.css +55 -18
  22. inspect_ai/_view/www/dist/assets/index.js +550 -458
  23. inspect_ai/_view/www/log-schema.json +66 -0
  24. inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
  25. inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
  26. inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
  27. inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
  28. inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
  29. inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
  30. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
  31. inspect_ai/_view/www/src/types/log.d.ts +24 -6
  32. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
  33. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
  34. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
  35. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
  36. inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
  37. inspect_ai/agent/_agent.py +12 -0
  38. inspect_ai/agent/_as_tool.py +1 -1
  39. inspect_ai/agent/_bridge/bridge.py +9 -2
  40. inspect_ai/agent/_react.py +142 -74
  41. inspect_ai/agent/_run.py +13 -2
  42. inspect_ai/agent/_types.py +6 -0
  43. inspect_ai/approval/_apply.py +6 -7
  44. inspect_ai/approval/_approver.py +3 -3
  45. inspect_ai/approval/_auto.py +2 -2
  46. inspect_ai/approval/_call.py +20 -4
  47. inspect_ai/approval/_human/approver.py +3 -3
  48. inspect_ai/approval/_human/manager.py +2 -2
  49. inspect_ai/approval/_human/panel.py +3 -3
  50. inspect_ai/approval/_policy.py +3 -3
  51. inspect_ai/log/__init__.py +2 -0
  52. inspect_ai/log/_log.py +23 -2
  53. inspect_ai/log/_model.py +58 -0
  54. inspect_ai/log/_recorders/file.py +14 -3
  55. inspect_ai/log/_transcript.py +3 -0
  56. inspect_ai/model/__init__.py +2 -0
  57. inspect_ai/model/_call_tools.py +4 -1
  58. inspect_ai/model/_model.py +49 -3
  59. inspect_ai/model/_openai.py +151 -21
  60. inspect_ai/model/_providers/anthropic.py +20 -12
  61. inspect_ai/model/_providers/bedrock.py +3 -3
  62. inspect_ai/model/_providers/cloudflare.py +29 -108
  63. inspect_ai/model/_providers/google.py +21 -10
  64. inspect_ai/model/_providers/grok.py +23 -17
  65. inspect_ai/model/_providers/groq.py +61 -37
  66. inspect_ai/model/_providers/llama_cpp_python.py +8 -9
  67. inspect_ai/model/_providers/mistral.py +8 -3
  68. inspect_ai/model/_providers/ollama.py +8 -9
  69. inspect_ai/model/_providers/openai.py +53 -157
  70. inspect_ai/model/_providers/openai_compatible.py +195 -0
  71. inspect_ai/model/_providers/openrouter.py +4 -15
  72. inspect_ai/model/_providers/providers.py +11 -0
  73. inspect_ai/model/_providers/together.py +25 -23
  74. inspect_ai/model/_trim.py +83 -0
  75. inspect_ai/solver/_plan.py +5 -3
  76. inspect_ai/tool/_tool_def.py +8 -2
  77. inspect_ai/util/__init__.py +3 -0
  78. inspect_ai/util/_concurrency.py +15 -2
  79. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/METADATA +1 -1
  80. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/RECORD +84 -79
  81. inspect_ai/_eval/task/rundir.py +0 -78
  82. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  83. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/WHEEL +0 -0
  84. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/entry_points.txt +0 -0
  85. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/licenses/LICENSE +0 -0
  86. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,10 @@
+ import os
  from logging import getLogger
  from typing import Any

  from typing_extensions import override

+ from inspect_ai._util.constants import MODEL_NONE
  from inspect_ai._util.file import filesystem
  from inspect_ai._util.registry import registry_unqualified_name

@@ -71,9 +73,18 @@ class FileRecorder(Recorder):
  return s.replace("_", "-").replace("/", "-").replace(":", "-")

  # remove package from task name
- task = registry_unqualified_name(eval.task)
-
- return f"{clean(eval.created)}_{clean(task)}_{clean(eval.task_id)}"
+ task = registry_unqualified_name(eval.task) # noqa: F841
+
+ # derive log file pattern
+ log_file_pattern = os.getenv("INSPECT_EVAL_LOG_FILE_PATTERN", "{task}_{id}")
+
+ # compute and return log file name
+ log_file_name = f"{clean(eval.created)}_" + log_file_pattern
+ log_file_name = log_file_name.replace("{task}", clean(task))
+ log_file_name = log_file_name.replace("{id}", clean(eval.task_id))
+ model = clean(eval.model) if eval.model != MODEL_NONE else ""
+ log_file_name = log_file_name.replace("{model}", model)
+ return log_file_name

  def _log_file_path(self, eval: EvalSpec) -> str:
  return f"{self.log_dir}{self.fs.sep}{self._log_file_key(eval)}{self.suffix}"
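
The hunk above appears to be from inspect_ai/log/_recorders/file.py (item 54 in the file list): log file names can now be shaped with the INSPECT_EVAL_LOG_FILE_PATTERN environment variable, which supports {task}, {id} and {model} placeholders and defaults to "{task}_{id}"; the creation timestamp is always prepended. A minimal sketch (the pattern value below is illustrative, not from the diff):

    import os

    # Illustrative pattern; the variable name and placeholders come from the hunk above.
    os.environ["INSPECT_EVAL_LOG_FILE_PATTERN"] = "{task}_{model}_{id}"
    # The resulting log file key is the creation timestamp followed by the
    # expanded pattern, i.e. roughly timestamp_task_model_id plus the log suffix.
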
@@ -123,6 +123,9 @@ class ModelEvent(BaseEvent):
  model: str
  """Model name."""

+ role: str | None = Field(default=None)
+ """Model role."""
+
  input: list[ChatMessage]
  """Model input (list of messages)."""

@@ -47,6 +47,7 @@ from ._model_output import (
  )
  from ._providers.providers import *
  from ._registry import modelapi
+ from ._trim import trim_messages

  __all__ = [
  "GenerateConfig",
@@ -80,6 +81,7 @@ __all__ = [
  "call_tools",
  "execute_tools",
  "ExecuteToolsResult",
+ "trim_messages",
  "cache_clear",
  "cache_list_expired",
  "cache_path",
@@ -264,6 +264,7 @@ async def execute_tools(
  tuple[ExecuteToolsResult, ToolEvent, Exception | None]
  ]()

+ result_exception = None
  async with anyio.create_task_group() as tg:
  tg.start_soon(call_tool_task, call, messages, send_stream)
  event._set_cancel_fn(tg.cancel_scope.cancel)
@@ -348,7 +349,9 @@ async def call_tool(
  # if we have a tool approver, apply it now
  from inspect_ai.approval._apply import apply_tool_approval

- approved, approval = await apply_tool_approval(message, call, tool_def.viewer)
+ approved, approval = await apply_tool_approval(
+ message, call, tool_def.viewer, conversation
+ )
  if not approved:
  if approval and approval.decision == "terminate":
  from inspect_ai.solver._limit import SampleLimitExceededError
@@ -270,6 +270,7 @@ class Model:
  self.api = api
  self.config = config
  self.model_args = model_args
+ self._role: str | None = None

  # state indicating whether our lifetime is bound by a context manager
  self._context_bound = False
@@ -311,6 +312,14 @@ class Model:
  """Model name."""
  return self.api.model_name

+ @property
+ def role(self) -> str | None:
+ """Model role."""
+ return self._role
+
+ def _set_role(self, role: str) -> None:
+ self._role = role
+
  def __str__(self) -> str:
  return f"{ModelName(self)}"

@@ -716,7 +725,7 @@ class Model:
  )
  model_name = ModelName(self)
  async with concurrency(
- name=f"{model_name.api}",
+ name=str(model_name),
  concurrency=max_connections,
  key=f"Model{self.api.connection_key()}",
  ):
@@ -738,6 +747,7 @@ class Model:
  model = str(self)
  event = ModelEvent(
  model=model,
+ role=self.role,
  input=input,
  tools=tools,
  tool_choice=tool_choice,
@@ -828,6 +838,9 @@ class ModelName:

  def get_model(
  model: str | Model | None = None,
+ *,
+ role: str | None = None,
+ default: str | Model | None = None,
  config: GenerateConfig = GenerateConfig(),
  base_url: str | None = None,
  api_key: str | None = None,
@@ -858,6 +871,11 @@ def get_model(
  if `None` is passed then the model currently being
  evaluated is returned (or if there is no evaluation
  then the model referred to by `INSPECT_EVAL_MODEL`).
+ role: Optional named role for model (e.g. for roles specified
+ at the task or eval level). Provide a `default` as a fallback
+ in the case where the `role` hasn't been externally specified.
+ default: Optional. Fallback model in case the specified
+ `model` or `role` is not found.
  config: Configuration for model.
  base_url: Optional. Alternate base URL for model.
  api_key: Optional. API key for model.
@@ -878,6 +896,22 @@ def get_model(
  if model == "none":
  model = "none/none"

+ # resolve model role
+ if role is not None:
+ model_for_role = model_roles().get(role, None)
+ if model_for_role is not None:
+ return model_for_role
+
+ # if a default was specified then use it as the model if
+ # no model was passed
+ if model is None:
+ if isinstance(default, Model):
+ if role is not None:
+ default._set_role(role)
+ return default
+ else:
+ model = default
+
  # now try finding an 'ambient' model (active or env var)
  if model is None:
  # return active_model if there is one
@@ -901,6 +935,7 @@ def get_model(
  if memoize:
  model_cache_key = (
  model
+ + str(role)
  + config.model_dump_json(exclude_none=True)
  + str(base_url)
  + str(api_key)
@@ -941,10 +976,11 @@ def get_model(
  **model_args,
  )
  m = Model(modelapi_instance, config, model_args)
+ if role is not None:
+ m._set_role(role)
  if memoize:
  _models[model_cache_key] = m
  return m
-
  else:
  from_api = f" from {api_name}" if api_name else ""
  raise ValueError(f"Model name {model}{from_api} not recognized.")
@@ -1353,10 +1389,20 @@ def active_model() -> Model | None:
  return active_model_context_var.get(None)


- # shared contexts for asyncio tasks
+ def init_model_roles(roles: dict[str, Model]) -> None:
+ _model_roles.set(roles)
+
+
+ def model_roles() -> dict[str, Model]:
+ return _model_roles.get()
+
+
  active_model_context_var: ContextVar[Model | None] = ContextVar("active_model")

+ _model_roles: ContextVar[dict[str, Model]] = ContextVar("model_roles", default={})
+

+ # shared contexts for asyncio tasks
  def handle_sample_message_limit(input: str | list[ChatMessage]) -> None:
  from inspect_ai.log._samples import (
  active_sample_message_limit,
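
The hunks above appear to be from inspect_ai/model/_model.py: get_model() gains keyword-only role and default arguments backed by a per-eval model-roles registry (init_model_roles()/model_roles()), the resolved role is exposed as Model.role and factored into the memoization key, and it is recorded on ModelEvent (see the ModelEvent hunk earlier). A minimal usage sketch based only on the signature shown; the role name and fallback model string are illustrative:

    from inspect_ai.model import get_model

    # Use the model registered for the "grader" role (if one was configured for
    # this eval/task); otherwise fall back to the illustrative default below.
    grader = get_model(role="grader", default="openai/gpt-4o-mini")
    print(grader.role)  # "grader" (new Model.role property)
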
@@ -1,9 +1,18 @@
  import json
  import re
+ import socket
  from copy import copy
- from typing import Literal
-
- from openai import APIStatusError, OpenAIError
+ from typing import Any, Literal
+
+ import httpx
+ from openai import (
+ DEFAULT_CONNECTION_LIMITS,
+ DEFAULT_TIMEOUT,
+ APIStatusError,
+ APITimeoutError,
+ OpenAIError,
+ RateLimitError,
+ )
  from openai.types.chat import (
  ChatCompletion,
  ChatCompletionAssistantMessageParam,
@@ -38,9 +47,11 @@ from inspect_ai._util.content import (
  ContentReasoning,
  ContentText,
  )
+ from inspect_ai._util.http import is_retryable_http_status
  from inspect_ai._util.images import file_as_data_uri
  from inspect_ai._util.url import is_http_url
  from inspect_ai.model._call_tools import parse_tool_call
+ from inspect_ai.model._generate_config import GenerateConfig
  from inspect_ai.model._model_output import ChatCompletionChoice, Logprobs
  from inspect_ai.model._reasoning import parse_content_with_reasoning
  from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
@@ -146,24 +157,20 @@ async def openai_chat_completion_part(


  async def openai_chat_message(
- message: ChatMessage, model: str
+ message: ChatMessage, system_role: Literal["user", "system", "developer"] = "system"
  ) -> ChatCompletionMessageParam:
  if message.role == "system":
- # o1-mini does not support developer or system messages
- # (see Dec 17, 2024 changelog: https://platform.openai.com/docs/changelog)
- if is_o1_mini(model):
- return ChatCompletionUserMessageParam(role="user", content=message.text)
- # other o-series models use 'developer' rather than 'system' messages
- # https://platform.openai.com/docs/guides/reasoning#advice-on-prompting
- elif is_o_series(model):
- return ChatCompletionDeveloperMessageParam(
- role="developer", content=message.text
- )
- # gpt models use standard 'system' messages
- else:
- return ChatCompletionSystemMessageParam(
- role=message.role, content=message.text
- )
+ match system_role:
+ case "user":
+ return ChatCompletionUserMessageParam(role="user", content=message.text)
+ case "system":
+ return ChatCompletionSystemMessageParam(
+ role=message.role, content=message.text
+ )
+ case "developer":
+ return ChatCompletionDeveloperMessageParam(
+ role="developer", content=message.text
+ )
  elif message.role == "user":
  return ChatCompletionUserMessageParam(
  role=message.role,
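
openai_chat_message() no longer special-cases o1-mini and o-series model names; callers now pass the desired system_role explicitly. A sketch of the new (internal) parameter; choosing "developer" here mirrors the old o-series behavior but is an assumption about how a caller might use it:

    import asyncio

    from inspect_ai.model import ChatMessageSystem
    from inspect_ai.model._openai import openai_chat_message

    async def demo() -> None:
        # The caller (e.g. a provider) now decides the system role itself:
        # "developer" for reasoning models, "system" for regular chat models.
        msg = ChatMessageSystem(content="Answer concisely.")
        param = await openai_chat_message(msg, system_role="developer")
        print(param)  # {'role': 'developer', 'content': 'Answer concisely.'}

    asyncio.run(demo())
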
@@ -202,9 +209,54 @@ async def openai_chat_message(


  async def openai_chat_messages(
- messages: list[ChatMessage], model: str
+ messages: list[ChatMessage],
+ system_role: Literal["user", "system", "developer"] = "system",
  ) -> list[ChatCompletionMessageParam]:
- return [await openai_chat_message(message, model) for message in messages]
+ return [await openai_chat_message(message, system_role) for message in messages]
+
+
+ def openai_completion_params(
+ model: str, config: GenerateConfig, tools: bool
+ ) -> dict[str, Any]:
+ params: dict[str, Any] = dict(model=model)
+ if config.max_tokens is not None:
+ params["max_tokens"] = config.max_tokens
+ if config.frequency_penalty is not None:
+ params["frequency_penalty"] = config.frequency_penalty
+ if config.stop_seqs is not None:
+ params["stop"] = config.stop_seqs
+ if config.presence_penalty is not None:
+ params["presence_penalty"] = config.presence_penalty
+ if config.logit_bias is not None:
+ params["logit_bias"] = config.logit_bias
+ if config.seed is not None:
+ params["seed"] = config.seed
+ if config.temperature is not None:
+ params["temperature"] = config.temperature
+ if config.top_p is not None:
+ params["top_p"] = config.top_p
+ if config.num_choices is not None:
+ params["n"] = config.num_choices
+ if config.logprobs is not None:
+ params["logprobs"] = config.logprobs
+ if config.top_logprobs is not None:
+ params["top_logprobs"] = config.top_logprobs
+ if tools and config.parallel_tool_calls is not None:
+ params["parallel_tool_calls"] = config.parallel_tool_calls
+ if config.reasoning_effort is not None:
+ params["reasoning_effort"] = config.reasoning_effort
+ if config.response_schema is not None:
+ params["response_format"] = dict(
+ type="json_schema",
+ json_schema=dict(
+ name=config.response_schema.name,
+ schema=config.response_schema.json_schema.model_dump(exclude_none=True),
+ description=config.response_schema.description,
+ strict=config.response_schema.strict,
+ ),
+ )
+
+ return params


  def openai_assistant_content(message: ChatMessageAssistant) -> str:
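
The new openai_completion_params() helper maps GenerateConfig fields onto OpenAI chat-completions request parameters, presumably shared with the new openai_compatible.py provider listed above. A minimal sketch using the signature shown (values are illustrative):

    from inspect_ai.model import GenerateConfig
    from inspect_ai.model._openai import openai_completion_params

    config = GenerateConfig(temperature=0.2, max_tokens=1024, seed=42)
    params = openai_completion_params(model="gpt-4o-mini", config=config, tools=False)
    # params == {"model": "gpt-4o-mini", "max_tokens": 1024, "seed": 42, "temperature": 0.2}
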
@@ -496,6 +548,35 @@ def chat_message_assistant_from_openai(
  )


+ def model_output_from_openai(
+ completion: ChatCompletion,
+ choices: list[ChatCompletionChoice],
+ ) -> ModelOutput:
+ return ModelOutput(
+ model=completion.model,
+ choices=choices,
+ usage=(
+ ModelUsage(
+ input_tokens=completion.usage.prompt_tokens,
+ output_tokens=completion.usage.completion_tokens,
+ input_tokens_cache_read=(
+ completion.usage.prompt_tokens_details.cached_tokens
+ if completion.usage.prompt_tokens_details is not None
+ else None # openai only have cache read stats/pricing.
+ ),
+ reasoning_tokens=(
+ completion.usage.completion_tokens_details.reasoning_tokens
+ if completion.usage.completion_tokens_details is not None
+ else None
+ ),
+ total_tokens=completion.usage.total_tokens,
+ )
+ if completion.usage
+ else None
+ ),
+ )
+
+
  def chat_choices_from_openai(
  response: ChatCompletion, tools: list[ToolInfo]
  ) -> list[ChatCompletionChoice]:
@@ -517,6 +598,19 @@ def chat_choices_from_openai(
  ]


+ def openai_should_retry(ex: Exception) -> bool:
+ if isinstance(ex, RateLimitError):
+ return True
+ elif isinstance(ex, APIStatusError):
+ return is_retryable_http_status(ex.status_code)
+ elif isinstance(ex, OpenAIResponseError):
+ return ex.code in ["rate_limit_exceeded", "server_error"]
+ elif isinstance(ex, APITimeoutError):
+ return True
+ else:
+ return False
+
+
  def openai_handle_bad_request(
  model_name: str, e: APIStatusError
  ) -> ModelOutput | Exception:
@@ -559,3 +653,39 @@ def openai_media_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
  value = copy(value)
  value.update(data=BASE_64_DATA_REMOVED)
  return value
+
+
+ class OpenAIAsyncHttpxClient(httpx.AsyncClient):
+ """Custom async client that deals better with long running Async requests.
+
+ Based on Anthropic DefaultAsyncHttpClient implementation that they
+ released along with Claude 3.7 as well as the OpenAI DefaultAsyncHttpxClient
+
+ """
+
+ def __init__(self, **kwargs: Any) -> None:
+ # This is based on the openai DefaultAsyncHttpxClient:
+ # https://github.com/openai/openai-python/commit/347363ed67a6a1611346427bb9ebe4becce53f7e
+ kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
+ kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
+ kwargs.setdefault("follow_redirects", True)
+
+ # This is based on the anthrpopic changes for claude 3.7:
+ # https://github.com/anthropics/anthropic-sdk-python/commit/c5387e69e799f14e44006ea4e54fdf32f2f74393#diff-3acba71f89118b06b03f2ba9f782c49ceed5bb9f68d62727d929f1841b61d12bR1387-R1403
+
+ # set socket options to deal with long running reasoning requests
+ socket_options = [
+ (socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
+ (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 60),
+ (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5),
+ ]
+ TCP_KEEPIDLE = getattr(socket, "TCP_KEEPIDLE", None)
+ if TCP_KEEPIDLE is not None:
+ socket_options.append((socket.IPPROTO_TCP, TCP_KEEPIDLE, 60))
+
+ kwargs["transport"] = httpx.AsyncHTTPTransport(
+ limits=DEFAULT_CONNECTION_LIMITS,
+ socket_options=socket_options,
+ )
+
+ super().__init__(**kwargs)
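
OpenAIAsyncHttpxClient enables TCP keepalive (on top of the OpenAI SDK's default timeout and connection limits) so long-running reasoning requests aren't dropped on idle connections. A sketch of how such a client can be handed to the SDK via its standard http_client parameter; whether the provider wires it up exactly this way isn't shown in this diff:

    from openai import AsyncOpenAI

    from inspect_ai.model._openai import OpenAIAsyncHttpxClient

    client = AsyncOpenAI(
        api_key="sk-...",  # placeholder key
        http_client=OpenAIAsyncHttpxClient(),  # keepalive-enabled transport
    )
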
@@ -82,7 +82,6 @@ class AnthropicAPI(ModelAPI):
  parts = model_name.split("/")
  if len(parts) > 1:
  self.service: str | None = parts[0]
- model_name = "/".join(parts[1:])
  else:
  self.service = None

@@ -237,7 +236,7 @@ class AnthropicAPI(ModelAPI):

  # extract output
  output = await model_output_from_message(
- self.client, self.model_name, message, tools
+ self.client, self.service_model_name(), message, tools
  )

  # return output and call
@@ -249,7 +248,7 @@ class AnthropicAPI(ModelAPI):
  except APIStatusError as ex:
  if ex.status_code == 413:
  return ModelOutput.from_content(
- model=self.model_name,
+ model=self.service_model_name(),
  content=ex.message,
  stop_reason="model_length",
  error=ex.message,
@@ -261,7 +260,7 @@ class AnthropicAPI(ModelAPI):
  self, config: GenerateConfig
  ) -> tuple[dict[str, Any], dict[str, str], list[str]]:
  max_tokens = cast(int, config.max_tokens)
- params = dict(model=self.model_name, max_tokens=max_tokens)
+ params = dict(model=self.service_model_name(), max_tokens=max_tokens)
  headers: dict[str, str] = {}
  betas: list[str] = []
  # some params not compatible with thinking models
@@ -311,18 +310,22 @@ class AnthropicAPI(ModelAPI):
  return not self.is_claude_3() and not self.is_claude_3_5()

  def is_claude_3(self) -> bool:
- return re.search(r"claude-3-[a-zA-Z]", self.model_name) is not None
+ return re.search(r"claude-3-[a-zA-Z]", self.service_model_name()) is not None

  def is_claude_3_5(self) -> bool:
- return "claude-3-5-" in self.model_name
+ return "claude-3-5-" in self.service_model_name()

  def is_claude_3_7(self) -> bool:
- return "claude-3-7-" in self.model_name
+ return "claude-3-7-" in self.service_model_name()

  @override
  def connection_key(self) -> str:
  return str(self.api_key)

+ def service_model_name(self) -> str:
+ """Model name without any service prefix."""
+ return self.model_name.replace(f"{self.service}/", "", 1)
+
  @override
  def should_retry(self, ex: Exception) -> bool:
  if isinstance(ex, APIStatusError):
@@ -371,7 +374,11 @@ class AnthropicAPI(ModelAPI):
  # NOTE: Using case insensitive matching because the Anthropic Bedrock API seems to capitalize the work 'input' in its error message, other times it doesn't.
  if any(
  message in error.lower()
- for message in ["prompt is too long", "input is too long"]
+ for message in [
+ "prompt is too long",
+ "input is too long",
+ "input length and `max_tokens` exceed context limit",
+ ]
  ):
  if (
  isinstance(ex.body, dict)
@@ -392,7 +399,7 @@ class AnthropicAPI(ModelAPI):

  if content and stop_reason:
  return ModelOutput.from_content(
- model=self.model_name,
+ model=self.service_model_name(),
  content=content,
  stop_reason=stop_reason,
  error=error,
@@ -440,10 +447,11 @@ class AnthropicAPI(ModelAPI):

  # only certain claude models qualify
  if cache_prompt:
+ model_name = self.service_model_name()
  if (
- "claude-3-sonnet" in self.model_name
- or "claude-2" in self.model_name
- or "claude-instant" in self.model_name
+ "claude-3-sonnet" in model_name
+ or "claude-2" in model_name
+ or "claude-instant" in model_name
  ):
  cache_prompt = False

@@ -368,7 +368,7 @@ class BedrockAPI(ModelAPI):
  toolConfig=tool_config,
  )

- def model_call(response: dict[str, Any] | None = None) -> ModelCall:
+ def model_call(response: dict[str, Any] = {}) -> ModelCall:
  return ModelCall.create(
  request=replace_bytes_with_placeholder(
  request.model_dump(exclude_none=True)
@@ -388,14 +388,14 @@ class BedrockAPI(ModelAPI):
  # Look for an explicit validation exception
  if ex.response["Error"]["Code"] == "ValidationException":
  response = ex.response["Error"]["Message"]
- if "Too many input tokens" in response:
+ if "too many input tokens" in response.lower():
  return ModelOutput.from_content(
  model=self.model_name,
  content=response,
  stop_reason="model_length",
  )
  else:
- return ex, model_call(None)
+ return ex, model_call()
  else:
  raise ex