inspect-ai 0.3.91__py3-none-any.whl → 0.3.93__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. inspect_ai/_cli/eval.py +31 -0
  2. inspect_ai/_eval/eval.py +19 -2
  3. inspect_ai/_eval/evalset.py +4 -1
  4. inspect_ai/_eval/run.py +41 -0
  5. inspect_ai/_eval/task/generate.py +38 -44
  6. inspect_ai/_eval/task/log.py +26 -28
  7. inspect_ai/_eval/task/run.py +13 -20
  8. inspect_ai/_util/local_server.py +368 -0
  9. inspect_ai/_util/working.py +10 -4
  10. inspect_ai/_view/www/dist/assets/index.css +159 -146
  11. inspect_ai/_view/www/dist/assets/index.js +1020 -1061
  12. inspect_ai/_view/www/log-schema.json +4 -3
  13. inspect_ai/_view/www/package.json +1 -1
  14. inspect_ai/_view/www/src/@types/log.d.ts +3 -2
  15. inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +2 -2
  16. inspect_ai/_view/www/src/app/content/MetaDataView.module.css +1 -1
  17. inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +1 -1
  18. inspect_ai/_view/www/src/app/content/RenderedContent.tsx +1 -1
  19. inspect_ai/_view/www/src/app/log-view/LogView.tsx +11 -0
  20. inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +2 -9
  21. inspect_ai/_view/www/src/app/log-view/tabs/ModelsTab.tsx +51 -0
  22. inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.module.css +6 -0
  23. inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.tsx +143 -0
  24. inspect_ai/_view/www/src/app/plan/ModelCard.tsx +1 -2
  25. inspect_ai/_view/www/src/app/plan/PlanCard.tsx +29 -7
  26. inspect_ai/_view/www/src/app/plan/PlanDetailView.module.css +1 -1
  27. inspect_ai/_view/www/src/app/plan/PlanDetailView.tsx +1 -198
  28. inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -1
  29. inspect_ai/_view/www/src/app/usage/ModelUsagePanel.tsx +3 -2
  30. inspect_ai/_view/www/src/app/usage/TokenTable.module.css +4 -1
  31. inspect_ai/_view/www/src/app/usage/TokenTable.tsx +2 -2
  32. inspect_ai/_view/www/src/app/usage/UsageCard.module.css +8 -3
  33. inspect_ai/_view/www/src/app/usage/UsageCard.tsx +1 -35
  34. inspect_ai/_view/www/src/components/Card.css +0 -1
  35. inspect_ai/_view/www/src/constants.ts +2 -0
  36. inspect_ai/_view/www/src/utils/numeric.ts +17 -0
  37. inspect_ai/agent/_agent.py +3 -3
  38. inspect_ai/agent/_as_solver.py +20 -12
  39. inspect_ai/agent/_as_tool.py +15 -3
  40. inspect_ai/agent/_handoff.py +8 -1
  41. inspect_ai/agent/_run.py +11 -3
  42. inspect_ai/log/__init__.py +4 -0
  43. inspect_ai/log/_file.py +56 -0
  44. inspect_ai/log/_log.py +99 -0
  45. inspect_ai/log/_recorders/__init__.py +2 -0
  46. inspect_ai/log/_recorders/buffer/database.py +12 -11
  47. inspect_ai/log/_recorders/buffer/filestore.py +2 -2
  48. inspect_ai/log/_recorders/buffer/types.py +2 -2
  49. inspect_ai/log/_recorders/eval.py +20 -65
  50. inspect_ai/log/_recorders/file.py +28 -6
  51. inspect_ai/log/_recorders/recorder.py +7 -0
  52. inspect_ai/log/_recorders/types.py +1 -23
  53. inspect_ai/log/_samples.py +0 -8
  54. inspect_ai/log/_transcript.py +7 -1
  55. inspect_ai/log/_util.py +52 -0
  56. inspect_ai/model/__init__.py +5 -1
  57. inspect_ai/model/_call_tools.py +32 -12
  58. inspect_ai/model/_generate_config.py +14 -8
  59. inspect_ai/model/_model.py +21 -48
  60. inspect_ai/model/_model_output.py +25 -0
  61. inspect_ai/model/_openai.py +2 -0
  62. inspect_ai/model/_openai_responses.py +13 -1
  63. inspect_ai/model/_providers/anthropic.py +13 -23
  64. inspect_ai/model/_providers/openai_o1.py +8 -2
  65. inspect_ai/model/_providers/providers.py +18 -4
  66. inspect_ai/model/_providers/sglang.py +241 -0
  67. inspect_ai/model/_providers/vllm.py +207 -400
  68. inspect_ai/solver/__init__.py +7 -2
  69. inspect_ai/solver/_basic_agent.py +3 -10
  70. inspect_ai/solver/_task_state.py +26 -88
  71. inspect_ai/tool/_json_rpc_helpers.py +45 -17
  72. inspect_ai/tool/_mcp/_mcp.py +2 -0
  73. inspect_ai/tool/_mcp/_sandbox.py +8 -2
  74. inspect_ai/tool/_mcp/server.py +3 -1
  75. inspect_ai/tool/_tool_call.py +4 -1
  76. inspect_ai/tool/_tool_support_helpers.py +51 -12
  77. inspect_ai/tool/_tools/_bash_session.py +190 -68
  78. inspect_ai/tool/_tools/_computer/_computer.py +25 -1
  79. inspect_ai/tool/_tools/_text_editor.py +4 -3
  80. inspect_ai/tool/_tools/_web_browser/_web_browser.py +10 -3
  81. inspect_ai/util/__init__.py +12 -0
  82. inspect_ai/util/_limit.py +393 -0
  83. inspect_ai/util/_limited_conversation.py +57 -0
  84. {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/METADATA +1 -1
  85. {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/RECORD +90 -109
  86. {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/WHEEL +1 -1
  87. inspect_ai/solver/_limit.py +0 -39
  88. inspect_ai/tool/_tools/_computer/_resources/Dockerfile +0 -102
  89. inspect_ai/tool/_tools/_computer/_resources/README.md +0 -30
  90. inspect_ai/tool/_tools/_computer/_resources/entrypoint/entrypoint.sh +0 -18
  91. inspect_ai/tool/_tools/_computer/_resources/entrypoint/novnc_startup.sh +0 -20
  92. inspect_ai/tool/_tools/_computer/_resources/entrypoint/x11vnc_startup.sh +0 -48
  93. inspect_ai/tool/_tools/_computer/_resources/entrypoint/xfce_startup.sh +0 -13
  94. inspect_ai/tool/_tools/_computer/_resources/entrypoint/xvfb_startup.sh +0 -48
  95. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/globalStorage/state.vscdb +0 -0
  96. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +0 -9
  97. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-panel.xml +0 -61
  98. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +0 -10
  99. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +0 -91
  100. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +0 -10
  101. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Terminal.desktop +0 -10
  102. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +0 -10
  103. inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +0 -8
  104. inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +0 -12
  105. inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +0 -78
  106. inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +0 -22
  107. inspect_ai/tool/_tools/_computer/_resources/tool/_logger.py +0 -22
  108. inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +0 -42
  109. inspect_ai/tool/_tools/_computer/_resources/tool/_tool_result.py +0 -33
  110. inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +0 -341
  111. inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +0 -141
  112. inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +0 -65
  113. inspect_ai/tool/_tools/_computer/_resources/tool/requirements.txt +0 -0
  114. inspect_ai/tool/_tools/_computer/test_args.py +0 -151
  115. /inspect_ai/{tool/_tools/_computer/_resources/tool/__init__.py → _view/www/src/app/log-view/tabs/ModelsTab.module.css} +0 -0
  116. {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/entry_points.txt +0 -0
  117. {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/licenses/LICENSE +0 -0
  118. {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/top_level.txt +0 -0

inspect_ai/log/_util.py
@@ -0,0 +1,52 @@
+import textwrap
+from datetime import date, datetime, time
+from typing import Any
+
+from inspect_ai._util.content import (
+    ContentAudio,
+    ContentImage,
+    ContentReasoning,
+    ContentText,
+    ContentVideo,
+)
+from inspect_ai.model._chat_message import ChatMessage
+
+
+def text_input_only(inputs: str | list[ChatMessage]) -> str | list[ChatMessage]:
+    # Clean the input of any images
+    if isinstance(inputs, list):
+        input: list[ChatMessage] = []
+        for message in inputs:
+            if not isinstance(message.content, str):
+                filtered_content: list[
+                    ContentText
+                    | ContentReasoning
+                    | ContentImage
+                    | ContentAudio
+                    | ContentVideo
+                ] = []
+                for content in message.content:
+                    if content.type == "text":
+                        filtered_content.append(content)
+                    else:
+                        filtered_content.append(
+                            ContentText(text=f"({content.type.capitalize()})")
+                        )
+                message.content = filtered_content
+                input.append(message)
+            else:
+                input.append(message)
+
+        return input
+    else:
+        return inputs
+
+
+def thin_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
+    thinned: dict[str, Any] = {}
+    for key, value in metadata.items():
+        if isinstance(value, int | float | bool | date | time | datetime):
+            thinned[key] = value
+        elif isinstance(value, str):
+            thinned[key] = textwrap.shorten(value, width=1024, placeholder="...")
+    return thinned
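
A hedged usage sketch (not part of the diff; the metadata values are made up, and the module path assumes this new file is inspect_ai/log/_util.py as the file list suggests):

# thin_metadata keeps primitive values as-is and shortens long strings
from inspect_ai.log._util import thin_metadata

metadata = {"attempts": 3, "passed": True, "notes": "lorem ipsum " * 500}
thinned = thin_metadata(metadata)
# thinned["attempts"] -> 3, thinned["passed"] -> True
# thinned["notes"]    -> truncated to at most 1024 characters with a "..." placeholder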

inspect_ai/model/__init__.py
@@ -28,7 +28,11 @@ from ._chat_message import (
     ChatMessageUser,
 )
 from ._conversation import ModelConversation
-from ._generate_config import GenerateConfig, GenerateConfigArgs, ResponseSchema
+from ._generate_config import (
+    GenerateConfig,
+    GenerateConfigArgs,
+    ResponseSchema,
+)
 from ._model import (
     Model,
     ModelAPI,

inspect_ai/model/_call_tools.py
@@ -60,6 +60,7 @@ from inspect_ai.tool._tool_info import parse_docstring
 from inspect_ai.tool._tool_params import ToolParams
 from inspect_ai.util import OutputLimitExceededError
 from inspect_ai.util._anyio import inner_exception
+from inspect_ai.util._limit import LimitExceededError, apply_limits
 
 from ._chat_message import (
     ChatMessage,
@@ -171,10 +172,15 @@ async def execute_tools(
             tool_error = ToolCallError("is_a_directory", err)
         except OutputLimitExceededError as ex:
             tool_error = ToolCallError(
-                "output_limit",
-                f"The tool output limit of {ex.limit_str} was exceeded.",
+                "limit",
+                f"The tool exceeded its output limit of {ex.limit_str}.",
             )
             result = ex.truncated_output or ""
+        except LimitExceededError as ex:
+            tool_error = ToolCallError(
+                "limit",
+                f"The tool exceeded its {ex.type} limit of {ex.limit}.",
+            )
         except ToolParsingError as ex:
             tool_error = ToolCallError("parsing", ex.message)
         except ToolApprovalError as ex:
@@ -344,6 +350,7 @@ async def call_tool(
     tools: list[ToolDef], message: str, call: ToolCall, conversation: list[ChatMessage]
 ) -> tuple[ToolResult, list[ChatMessage], ModelOutput | None, str | None]:
     from inspect_ai.agent._handoff import AgentTool
+    from inspect_ai.log._transcript import SampleLimitEvent, transcript
 
     # if there was an error parsing the ToolCall, raise that
     if call.parse_error:
@@ -362,14 +369,11 @@
         )
         if not approved:
             if approval and approval.decision == "terminate":
-                from inspect_ai.solver._limit import SampleLimitExceededError
-
-                raise SampleLimitExceededError(
-                    "operator",
-                    value=1,
-                    limit=1,
-                    message="Tool call approver requested termination.",
+                message = "Tool call approver requested termination."
+                transcript()._event(
+                    SampleLimitEvent(type="operator", limit=1, message=message)
                 )
+                raise LimitExceededError("operator", value=1, limit=1, message=message)
             else:
                 raise ToolApprovalError(approval.explanation if approval else None)
         if approval and approval.modified:
@@ -454,9 +458,14 @@ async def agent_handoff(
     arguments = tool_params(arguments, agent_tool.agent)
     del arguments["state"]
 
-    # make the call
+    # run the agent with limits
+    limit_error: LimitExceededError | None = None
     agent_state = AgentState(messages=copy(agent_conversation))
-    agent_state = await agent_tool.agent(agent_state, **arguments)
+    try:
+        with apply_limits(agent_tool.limits):
+            agent_state = await agent_tool.agent(agent_state, **arguments)
+    except LimitExceededError as ex:
+        limit_error = ex
 
     # determine which messages are new and return only those (but exclude new
     # system messages as they an internal matter for the handed off to agent.
@@ -474,9 +483,20 @@
     if agent_tool.output_filter is not None:
         agent_messages = await agent_tool.output_filter(agent_messages)
 
+    if limit_error is not None:
+        agent_messages.append(
+            ChatMessageUser(
+                content=(
+                    f"The {agent_name} exceeded its {limit_error.type} limit of "
+                    f"{limit_error.limit}."
+                )
+            )
+        )
     # if we end with an assistant message then add a user message
     # so that the calling agent carries on
-    if len(agent_messages) == 0 or isinstance(agent_messages[-1], ChatMessageAssistant):
+    elif len(agent_messages) == 0 or isinstance(
+        agent_messages[-1], ChatMessageAssistant
+    ):
         agent_messages.append(
             ChatMessageUser(content=f"The {agent_name} agent has completed its work.")
         )
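
A hedged sketch (not from the diff) of the pattern agent_handoff() now uses: run an agent under a set of limits and turn a LimitExceededError into a conversation message instead of failing the sample. Here some_limits stands in for whatever limits the handoff was configured with.

from inspect_ai.util._limit import LimitExceededError, apply_limits

async def run_limited(agent, state, some_limits):
    try:
        with apply_limits(some_limits):
            state = await agent(state)
    except LimitExceededError as ex:
        # mirror agent_handoff(): report the exceeded limit rather than raising
        print(f"The agent exceeded its {ex.type} limit of {ex.limit}.")
    return state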

inspect_ai/model/_generate_config.py
@@ -106,6 +106,9 @@ class GenerateConfigArgs(TypedDict, total=False):
     response_schema: ResponseSchema | None
     """Request a response format as JSONSchema (output should still be validated). OpenAI, Google, and Mistral only."""
 
+    extra_body: dict[str, Any] | None
+    """Extra body to be sent with requests to OpenAI compatible servers. OpenAI, vLLM, and SGLang only."""
+
 
 class GenerateConfig(BaseModel):
     """Model generation options."""
@@ -138,28 +141,28 @@ class GenerateConfig(BaseModel):
     """Generates best_of completions server-side and returns the 'best' (the one with the highest log probability per token). vLLM only."""
 
     frequency_penalty: float | None = Field(default=None)
-    """Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. OpenAI, Google, Grok, Groq, and vLLM only."""
+    """Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. OpenAI, Google, Grok, Groq, vLLM, and SGLang only."""
 
     presence_penalty: float | None = Field(default=None)
-    """Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. OpenAI, Google, Grok, Groq, and vLLM only."""
+    """Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. OpenAI, Google, Grok, Groq, vLLM, and SGLang only."""
 
     logit_bias: dict[int, float] | None = Field(default=None)
-    """Map token Ids to an associated bias value from -100 to 100 (e.g. "42=10,43=-10"). OpenAI, Grok, and Grok only."""
+    """Map token Ids to an associated bias value from -100 to 100 (e.g. "42=10,43=-10"). OpenAI, Grok, Grok, and vLLM only."""
 
     seed: int | None = Field(default=None)
     """Random seed. OpenAI, Google, Mistral, Groq, HuggingFace, and vLLM only."""
 
     top_k: int | None = Field(default=None)
-    """Randomly sample the next word from the top_k most likely next words. Anthropic, Google, HuggingFace, and vLLM only."""
+    """Randomly sample the next word from the top_k most likely next words. Anthropic, Google, HuggingFace, vLLM, and SGLang only."""
 
     num_choices: int | None = Field(default=None)
-    """How many chat completion choices to generate for each input message. OpenAI, Grok, Google, TogetherAI, and vLLM only."""
+    """How many chat completion choices to generate for each input message. OpenAI, Grok, Google, TogetherAI, vLLM, and SGLang only."""
 
     logprobs: bool | None = Field(default=None)
-    """Return log probabilities of the output tokens. OpenAI, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM only."""
+    """Return log probabilities of the output tokens. OpenAI, Grok, TogetherAI, Huggingface, llama-cpp-python, vLLM, and SGLang only."""
 
     top_logprobs: int | None = Field(default=None)
-    """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Grok, Huggingface, and vLLM only."""
+    """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Grok, Huggingface, vLLM, and SGLang only."""
 
     parallel_tool_calls: bool | None = Field(default=None)
     """Whether to enable parallel function calling during tool use (defaults to True). OpenAI and Groq only."""
@@ -190,7 +193,10 @@
     """Include reasoning in chat message history sent to generate."""
 
     response_schema: ResponseSchema | None = Field(default=None)
-    """Request a response format as JSONSchema (output should still be validated). OpenAI, Google, and Mistral only."""
+    """Request a response format as JSONSchema (output should still be validated). OpenAI, Google, Mistral, vLLM, and SGLang only."""
+
+    extra_body: dict[str, Any] | None = Field(default=None)
+    """Extra body to be sent with requests to OpenAI compatible servers. OpenAI, vLLM, and SGLang only."""
 
     # migrate reasoning_history as a bool
     @model_validator(mode="before")
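
A hedged usage sketch of the new extra_body option (the payload keys shown are illustrative server options, not part of inspect_ai):

from inspect_ai.model import GenerateConfig

config = GenerateConfig(
    temperature=0.7,
    # forwarded verbatim as extra request-body fields on OpenAI-compatible servers
    extra_body={"chat_template_kwargs": {"enable_thinking": False}},
)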

inspect_ai/model/_model.py
@@ -57,6 +57,11 @@ from inspect_ai.tool._tool import ToolSource
 from inspect_ai.tool._tool_call import ToolCallModelInputHints
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
+from inspect_ai.util._limit import (
+    check_message_limit,
+    check_token_limit,
+    record_model_usage,
+)
 
 from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
 from ._call_tools import (
@@ -355,11 +360,15 @@
         Returns:
            ModelOutput
         """
-        # if we are the default model then enforce message limit if it
-        # exists (raise an exception if it is exceeded)
+        # if we are the default model then update the displayed message count
        is_active_model = self == active_model()
        if is_active_model:
-            handle_sample_message_limit(input)
+            set_total_messages(input)
+
+        # check message limit, raise exception if we're already at the limit to prevent
+        # a wasteful generate()
+        conversation_length = len(input) if isinstance(input, list) else 1
+        check_message_limit(conversation_length, raise_for_equal=True)
 
         # base config for this model
         base_config = self.config
@@ -666,7 +675,7 @@
         # record usage
         if output.usage:
             # record usage
-            record_model_usage(f"{self}", output.usage)
+            record_and_check_model_usage(f"{self}", output.usage)
 
             # send telemetry if its hooked up
             await send_telemetry(
@@ -1423,20 +1432,10 @@ _model_roles: ContextVar[dict[str, Model]] = ContextVar("model_roles", default={
 
 
 # shared contexts for asyncio tasks
-def handle_sample_message_limit(input: str | list[ChatMessage]) -> None:
-    from inspect_ai.log._samples import (
-        active_sample_message_limit,
-        set_active_sample_total_messages,
-    )
-    from inspect_ai.solver._limit import SampleLimitExceededError
+def set_total_messages(input: str | list[ChatMessage]) -> None:
+    from inspect_ai.log._samples import set_active_sample_total_messages
 
     total_messages = 1 if isinstance(input, str) else len(input)
-    message_limit = active_sample_message_limit()
-    if message_limit is not None:
-        if total_messages >= message_limit:
-            raise SampleLimitExceededError(
-                "message", value=total_messages, limit=message_limit
-            )
 
     # set total messages
     set_active_sample_total_messages(total_messages)
@@ -1450,16 +1449,13 @@ def init_sample_model_usage() -> None:
     sample_model_usage_context_var.set({})
 
 
-def record_model_usage(model: str, usage: ModelUsage) -> None:
-    from inspect_ai.log._samples import (
-        active_sample_token_limit,
-        set_active_sample_total_tokens,
-    )
-    from inspect_ai.solver._limit import SampleLimitExceededError
+def record_and_check_model_usage(model: str, usage: ModelUsage) -> None:
+    from inspect_ai.log._samples import set_active_sample_total_tokens
 
     # record usage
     set_model_usage(model, usage, sample_model_usage_context_var.get(None))
     set_model_usage(model, usage, model_usage_context_var.get(None))
+    record_model_usage(usage)
 
     # compute total tokens
     total_tokens = sample_total_tokens()
@@ -1467,38 +1463,15 @@ def record_model_usage(model: str, usage: ModelUsage) -> None:
     # update active sample
     set_active_sample_total_tokens(total_tokens)
 
-    # check for token limit overflow and raise
-    token_limit = active_sample_token_limit()
-    if token_limit is not None:
-        if total_tokens > token_limit:
-            raise SampleLimitExceededError(
-                "token", value=total_tokens, limit=token_limit
-            )
+    check_token_limit()
 
 
 def set_model_usage(
     model: str, usage: ModelUsage, model_usage: dict[str, ModelUsage] | None
 ) -> None:
     if model_usage is not None:
-        total_usage: ModelUsage | None = model_usage.get(model, None)
-        if not total_usage:
-            total_usage = ModelUsage()
-        total_usage.input_tokens += usage.input_tokens
-        total_usage.output_tokens += usage.output_tokens
-        total_usage.total_tokens += usage.total_tokens
-        if usage.input_tokens_cache_write is not None:
-            if total_usage.input_tokens_cache_write is None:
-                total_usage.input_tokens_cache_write = 0
-            total_usage.input_tokens_cache_write += usage.input_tokens_cache_write
-        if usage.input_tokens_cache_read is not None:
-            if total_usage.input_tokens_cache_read is None:
-                total_usage.input_tokens_cache_read = 0
-            total_usage.input_tokens_cache_read += usage.input_tokens_cache_read
-        if usage.reasoning_tokens is not None:
-            if total_usage.reasoning_tokens is None:
-                total_usage.reasoning_tokens = 0
-            total_usage.reasoning_tokens += usage.reasoning_tokens
-
+        total_usage = model_usage.get(model, ModelUsage())
+        total_usage += usage
         model_usage[model] = total_usage
 
 

inspect_ai/model/_model_output.py
@@ -30,6 +30,31 @@ class ModelUsage(BaseModel):
     reasoning_tokens: int | None = Field(default=None)
     """Number of tokens used for reasoning."""
 
+    def __add__(self, other: "ModelUsage") -> "ModelUsage":
+        def optional_sum(a: int | None, b: int | None) -> int | None:
+            if a is not None and b is not None:
+                return a + b
+            if a is not None:
+                return a
+            if b is not None:
+                return b
+            return None
+
+        return ModelUsage(
+            input_tokens=self.input_tokens + other.input_tokens,
+            output_tokens=self.output_tokens + other.output_tokens,
+            total_tokens=self.total_tokens + other.total_tokens,
+            input_tokens_cache_write=optional_sum(
+                self.input_tokens_cache_write, other.input_tokens_cache_write
+            ),
+            input_tokens_cache_read=optional_sum(
+                self.input_tokens_cache_read, other.input_tokens_cache_read
+            ),
+            reasoning_tokens=optional_sum(
+                self.reasoning_tokens, other.reasoning_tokens
+            ),
+        )
+
 
 StopReason = Literal[
     "stop",

inspect_ai/model/_openai.py
@@ -255,6 +255,8 @@ def openai_completion_params(
                 strict=config.response_schema.strict,
             ),
         )
+    if config.extra_body:
+        params["extra_body"] = config.extra_body
 
     return params
 

inspect_ai/model/_openai_responses.py
@@ -1,3 +1,4 @@
+import json
 from itertools import chain
 from typing import TypedDict, cast
 
@@ -306,6 +307,14 @@ def _openai_input_items_from_chat_message_assistant(
     """
     (output_message_id, tool_message_ids) = _ids_from_assistant_internal(message)
 
+    # we want to prevent yielding output messages in the case where we have an
+    # 'internal' field (so the message came from the model API as opposed to
+    # being user synthesized) AND there is no output_message_id (indicating that
+    # when reading the message from the server we didn't find output). this could
+    # happen e.g. when a react() agent sets the output.completion in response
+    # to a submit() tool call
+    suppress_output_message = message.internal is not None and output_message_id is None
+
     # if we are not storing messages on the server then blank these out
     if not store:
         output_message_id = None
@@ -341,6 +350,9 @@
                 )
             )
         case ContentText(text=text, refusal=refusal):
+            if suppress_output_message:
+                continue
+
             new_content = (
                 ResponseOutputRefusalParam(type="refusal", refusal=text)
                 if refusal
@@ -415,7 +427,7 @@ def _tool_call_items_from_assistant_message(
             type="function_call",
             call_id=call.id,
             name=_responses_tool_alias(call.function),
-            arguments=call.function,
+            arguments=json.dumps(call.arguments),
         )
 
         # add id if available

inspect_ai/model/_providers/anthropic.py
@@ -26,7 +26,6 @@ from anthropic.types import (
     TextBlockParam,
     ThinkingBlock,
     ThinkingBlockParam,
-    ToolBash20250124Param,
     ToolParam,
     ToolResultBlockParam,
     ToolTextEditor20250124Param,
@@ -76,6 +75,7 @@ class AnthropicAPI(ModelAPI):
         base_url: str | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
+        streaming: bool | Literal["auto"] = "auto",
         **model_args: Any,
     ):
         # extract any service prefix from model name
@@ -85,6 +85,9 @@
         else:
             self.service = None
 
+        # record steraming pref
+        self.streaming = streaming
+
         # collect generate model_args (then delete them so we can pass the rest on)
         def collect_model_arg(name: str) -> Any | None:
             nonlocal model_args
@@ -224,8 +227,13 @@
         if self.extra_body is not None:
             request["extra_body"] = self.extra_body
 
-        # make request (stream if we are using reasoning)
-        if self.is_using_thinking(config):
+        # make request (unless overrideen, stream if we are using reasoning)
+        streaming = (
+            self.is_using_thinking(config)
+            if self.streaming == "auto"
+            else self.streaming
+        )
+        if streaming:
             async with self.client.messages.stream(**request) as stream:
                 message = await stream.get_final_message()
         else:
@@ -489,11 +497,7 @@
         self, tool: ToolInfo, config: GenerateConfig
     ) -> Optional["ToolParamDef"]:
         return (
-            (
-                self.computer_use_tool_param(tool)
-                or self.text_editor_tool_param(tool)
-                or self.bash_tool_param(tool)
-            )
+            (self.computer_use_tool_param(tool) or self.text_editor_tool_param(tool))
             if config.internal_tools is not False
             else None
         )
@@ -564,23 +568,10 @@
         else:
             return None
 
-    def bash_tool_param(self, tool: ToolInfo) -> Optional[ToolBash20250124Param]:
-        # check for compatible 'bash' tool
-        if tool.name == "bash_session" and (
-            sorted(tool.parameters.properties.keys()) == sorted(["command", "restart"])
-        ):
-            return ToolBash20250124Param(type="bash_20250124", name="bash")
-        # not a bash tool
-        else:
-            return None
-
 
 # tools can be either a stock tool param or a special Anthropic native use tool param
 ToolParamDef = (
-    ToolParam
-    | BetaToolComputerUse20250124Param
-    | ToolTextEditor20250124Param
-    | ToolBash20250124Param
+    ToolParam | BetaToolComputerUse20250124Param | ToolTextEditor20250124Param
 )
 
 
@@ -589,7 +580,6 @@ def add_cache_control(
     | ToolParam
    | BetaToolComputerUse20250124Param
    | ToolTextEditor20250124Param
-    | ToolBash20250124Param
    | dict[str, Any],
 ) -> None:
     cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
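
A hedged sketch (not from the diff) of how the new streaming override might be passed through as a provider model arg; the model name below is illustrative:

from inspect_ai.model import get_model

# force non-streaming requests even when extended thinking would normally stream
model = get_model("anthropic/claude-3-7-sonnet-latest", streaming=False)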

inspect_ai/model/_providers/openai_o1.py
@@ -211,8 +211,15 @@ class O1PreviewChatAPIHandler(ChatAPIHandler):
         This method has an interdependency with `input_with_tools()` (as that is the
         prompt that asks the model to use the <tool_call>...</tool_call> syntax)
         """
-        # extract tool calls
+        # define regex patterns
+        # NOTE: If you change either of these regex patterns, please update the other
+        # tool_call_regex extracts the JSON content (in curly braces) between tool call tags
         tool_call_regex = rf"<{TOOL_CALL}>\s*(\{{[\s\S]*?\}})\s*</{TOOL_CALL}>"
+        # tool_call_content_regex matches the entire tool call block including tags for extracting
+        # the content outside of the tool call tags
+        tool_call_content_regex = rf"<{TOOL_CALL}>\s*\{{[\s\S]*?\}}\s*</{TOOL_CALL}>"
+
+        # extract tool calls
         tool_calls_content: list[str] = re.findall(tool_call_regex, response)
 
         # if there are tool calls proceed with parsing
@@ -226,7 +233,6 @@
             ]
 
             # find other content that exists outside tool calls
-            tool_call_content_regex = rf"<{TOOL_CALL}>(?:.|\n)*?</{TOOL_CALL}>"
             other_content = re.split(tool_call_content_regex, response, flags=re.DOTALL)
             other_content = [
                 str(content).strip()
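
A quick, hedged illustration of the two patterns (assuming TOOL_CALL is the literal "tool_call" tag referenced in the docstring):

import re

TOOL_CALL = "tool_call"
tool_call_regex = rf"<{TOOL_CALL}>\s*(\{{[\s\S]*?\}})\s*</{TOOL_CALL}>"
tool_call_content_regex = rf"<{TOOL_CALL}>\s*\{{[\s\S]*?\}}\s*</{TOOL_CALL}>"

response = 'Looking... <tool_call> {"name": "bash", "arguments": {"cmd": "ls"}} </tool_call> Done.'
re.findall(tool_call_regex, response)        # ['{"name": "bash", "arguments": {"cmd": "ls"}}']
re.split(tool_call_content_regex, response)  # ['Looking... ', ' Done.']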

inspect_ai/model/_providers/providers.py
@@ -136,10 +136,12 @@ def hf() -> type[ModelAPI]:
 
 @modelapi(name="vllm")
 def vllm() -> type[ModelAPI]:
-    try:
-        from .vllm import VLLMAPI
-    except ImportError:
-        raise pip_dependency_error("vLLM Models", ["vllm"])
+    # Only validate OpenAI compatibility (needed for the API interface)
+    validate_openai_client("vLLM API")
+
+    # Import VLLMAPI without checking for vllm package yet
+    # The actual vllm dependency will only be checked if needed to start a server
+    from .vllm import VLLMAPI
 
     return VLLMAPI
 
@@ -257,6 +259,18 @@ def mockllm() -> type[ModelAPI]:
     return MockLLM
 
 
+@modelapi(name="sglang")
+def sglang() -> type[ModelAPI]:
+    # Only validate OpenAI compatibility (needed for the API interface)
+    validate_openai_client("SGLang API")
+
+    # Import SGLangAPI without checking for sglang package yet
+    # The actual sglang dependency will only be checked if needed to start a server
+    from .sglang import SGLangAPI
+
+    return SGLangAPI
+
+
 @modelapi(name="none")
 def none() -> type[ModelAPI]:
     from .none import NoModel
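
A hedged sketch of selecting the newly registered provider (the model name after the sglang/ prefix is illustrative, and a reachable or locally started server is assumed at generate time):

from inspect_ai.model import get_model

model = get_model("sglang/meta-llama/Llama-3.1-8B-Instruct")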