inspect-ai 0.3.92__py3-none-any.whl → 0.3.94__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. inspect_ai/_cli/eval.py +27 -0
  2. inspect_ai/_display/textual/widgets/samples.py +3 -3
  3. inspect_ai/_display/textual/widgets/transcript.py +3 -29
  4. inspect_ai/_eval/eval.py +19 -2
  5. inspect_ai/_eval/evalset.py +4 -1
  6. inspect_ai/_eval/run.py +41 -0
  7. inspect_ai/_eval/task/generate.py +38 -44
  8. inspect_ai/_eval/task/log.py +26 -28
  9. inspect_ai/_eval/task/run.py +23 -27
  10. inspect_ai/_util/answer.py +26 -0
  11. inspect_ai/_util/constants.py +0 -1
  12. inspect_ai/_util/local_server.py +398 -0
  13. inspect_ai/_util/working.py +10 -4
  14. inspect_ai/_view/www/dist/assets/index.css +173 -159
  15. inspect_ai/_view/www/dist/assets/index.js +1417 -1142
  16. inspect_ai/_view/www/log-schema.json +379 -3
  17. inspect_ai/_view/www/package.json +1 -1
  18. inspect_ai/_view/www/src/@types/log.d.ts +93 -14
  19. inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +2 -2
  20. inspect_ai/_view/www/src/app/content/MetaDataView.module.css +1 -1
  21. inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +1 -1
  22. inspect_ai/_view/www/src/app/content/RenderedContent.tsx +1 -1
  23. inspect_ai/_view/www/src/app/log-view/LogView.tsx +11 -0
  24. inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +2 -9
  25. inspect_ai/_view/www/src/app/log-view/tabs/ModelsTab.tsx +51 -0
  26. inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.module.css +6 -0
  27. inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.tsx +143 -0
  28. inspect_ai/_view/www/src/app/plan/ModelCard.tsx +1 -2
  29. inspect_ai/_view/www/src/app/plan/PlanCard.tsx +29 -7
  30. inspect_ai/_view/www/src/app/plan/PlanDetailView.module.css +1 -1
  31. inspect_ai/_view/www/src/app/plan/PlanDetailView.tsx +1 -198
  32. inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -1
  33. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
  34. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
  35. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
  36. inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
  37. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
  38. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
  39. inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
  40. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
  41. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
  42. inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
  43. inspect_ai/_view/www/src/app/usage/ModelUsagePanel.tsx +3 -2
  44. inspect_ai/_view/www/src/app/usage/TokenTable.module.css +4 -1
  45. inspect_ai/_view/www/src/app/usage/TokenTable.tsx +2 -2
  46. inspect_ai/_view/www/src/app/usage/UsageCard.module.css +8 -3
  47. inspect_ai/_view/www/src/app/usage/UsageCard.tsx +1 -35
  48. inspect_ai/_view/www/src/components/Card.css +0 -1
  49. inspect_ai/_view/www/src/constants.ts +2 -0
  50. inspect_ai/_view/www/src/utils/numeric.ts +17 -0
  51. inspect_ai/agent/_agent.py +3 -3
  52. inspect_ai/agent/_as_solver.py +22 -12
  53. inspect_ai/agent/_as_tool.py +20 -6
  54. inspect_ai/agent/_handoff.py +12 -1
  55. inspect_ai/agent/_react.py +4 -3
  56. inspect_ai/agent/_run.py +16 -3
  57. inspect_ai/agent/_types.py +9 -0
  58. inspect_ai/dataset/_dataset.py +6 -3
  59. inspect_ai/log/__init__.py +14 -0
  60. inspect_ai/log/_convert.py +4 -9
  61. inspect_ai/log/_file.py +56 -0
  62. inspect_ai/log/_log.py +99 -0
  63. inspect_ai/log/_recorders/__init__.py +2 -0
  64. inspect_ai/log/_recorders/buffer/database.py +12 -11
  65. inspect_ai/log/_recorders/buffer/filestore.py +2 -2
  66. inspect_ai/log/_recorders/buffer/types.py +2 -2
  67. inspect_ai/log/_recorders/eval.py +20 -65
  68. inspect_ai/log/_recorders/file.py +28 -6
  69. inspect_ai/log/_recorders/recorder.py +7 -0
  70. inspect_ai/log/_recorders/types.py +1 -23
  71. inspect_ai/log/_samples.py +14 -25
  72. inspect_ai/log/_transcript.py +84 -36
  73. inspect_ai/log/_tree.py +118 -0
  74. inspect_ai/log/_util.py +52 -0
  75. inspect_ai/model/__init__.py +5 -1
  76. inspect_ai/model/_call_tools.py +72 -44
  77. inspect_ai/model/_generate_config.py +14 -8
  78. inspect_ai/model/_model.py +66 -88
  79. inspect_ai/model/_model_output.py +25 -0
  80. inspect_ai/model/_openai.py +2 -0
  81. inspect_ai/model/_providers/anthropic.py +13 -23
  82. inspect_ai/model/_providers/hf.py +27 -1
  83. inspect_ai/model/_providers/openai_o1.py +8 -2
  84. inspect_ai/model/_providers/providers.py +18 -4
  85. inspect_ai/model/_providers/sglang.py +247 -0
  86. inspect_ai/model/_providers/vllm.py +211 -400
  87. inspect_ai/scorer/_choice.py +1 -2
  88. inspect_ai/solver/__init__.py +7 -2
  89. inspect_ai/solver/_basic_agent.py +3 -10
  90. inspect_ai/solver/_chain.py +1 -1
  91. inspect_ai/solver/_fork.py +1 -1
  92. inspect_ai/solver/_multiple_choice.py +5 -22
  93. inspect_ai/solver/_plan.py +2 -2
  94. inspect_ai/solver/_task_state.py +26 -88
  95. inspect_ai/solver/_transcript.py +6 -7
  96. inspect_ai/tool/_json_rpc_helpers.py +45 -17
  97. inspect_ai/tool/_mcp/_mcp.py +8 -5
  98. inspect_ai/tool/_mcp/_sandbox.py +8 -2
  99. inspect_ai/tool/_mcp/server.py +3 -1
  100. inspect_ai/tool/_tool_call.py +4 -1
  101. inspect_ai/tool/_tool_support_helpers.py +51 -12
  102. inspect_ai/tool/_tools/_bash_session.py +190 -68
  103. inspect_ai/tool/_tools/_computer/_computer.py +25 -1
  104. inspect_ai/tool/_tools/_execute.py +4 -1
  105. inspect_ai/tool/_tools/_text_editor.py +4 -3
  106. inspect_ai/tool/_tools/_web_browser/_web_browser.py +10 -3
  107. inspect_ai/util/__init__.py +16 -0
  108. inspect_ai/util/_anyio.py +11 -0
  109. inspect_ai/util/_collect.py +50 -0
  110. inspect_ai/util/_limit.py +393 -0
  111. inspect_ai/util/_limited_conversation.py +57 -0
  112. inspect_ai/util/_span.py +58 -0
  113. inspect_ai/util/_subtask.py +27 -42
  114. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/METADATA +1 -1
  115. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/RECORD +120 -134
  116. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/WHEEL +1 -1
  117. inspect_ai/_display/core/group.py +0 -79
  118. inspect_ai/solver/_limit.py +0 -39
  119. inspect_ai/tool/_tools/_computer/_resources/Dockerfile +0 -102
  120. inspect_ai/tool/_tools/_computer/_resources/README.md +0 -30
  121. inspect_ai/tool/_tools/_computer/_resources/entrypoint/entrypoint.sh +0 -18
  122. inspect_ai/tool/_tools/_computer/_resources/entrypoint/novnc_startup.sh +0 -20
  123. inspect_ai/tool/_tools/_computer/_resources/entrypoint/x11vnc_startup.sh +0 -48
  124. inspect_ai/tool/_tools/_computer/_resources/entrypoint/xfce_startup.sh +0 -13
  125. inspect_ai/tool/_tools/_computer/_resources/entrypoint/xvfb_startup.sh +0 -48
  126. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/globalStorage/state.vscdb +0 -0
  127. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +0 -9
  128. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-panel.xml +0 -61
  129. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +0 -10
  130. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +0 -91
  131. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +0 -10
  132. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Terminal.desktop +0 -10
  133. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +0 -10
  134. inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +0 -8
  135. inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +0 -12
  136. inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +0 -78
  137. inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +0 -22
  138. inspect_ai/tool/_tools/_computer/_resources/tool/_logger.py +0 -22
  139. inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +0 -42
  140. inspect_ai/tool/_tools/_computer/_resources/tool/_tool_result.py +0 -33
  141. inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +0 -341
  142. inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +0 -141
  143. inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +0 -65
  144. inspect_ai/tool/_tools/_computer/_resources/tool/requirements.txt +0 -0
  145. inspect_ai/tool/_tools/_computer/test_args.py +0 -151
  146. /inspect_ai/{tool/_tools/_computer/_resources/tool/__init__.py → _view/www/src/app/log-view/tabs/ModelsTab.module.css} +0 -0
  147. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/entry_points.txt +0 -0
  148. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/licenses/LICENSE +0 -0
  149. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/top_level.txt +0 -0

inspect_ai/model/_model.py
@@ -19,6 +19,7 @@ from typing import (
     cast,
 )
 
+from pydantic import BaseModel
 from pydantic_core import to_jsonable_python
 from tenacity import (
     RetryCallState,
@@ -57,6 +58,11 @@ from inspect_ai.tool._tool import ToolSource
 from inspect_ai.tool._tool_call import ToolCallModelInputHints
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
+from inspect_ai.util._limit import (
+    check_message_limit,
+    check_token_limit,
+    record_model_usage,
+)
 
 from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
 from ._call_tools import (
@@ -355,11 +361,15 @@ class Model:
         Returns:
            ModelOutput
         """
-        # if we are the default model then enforce message limit if it
-        # exists (raise an exception if it is exceeded)
+        # if we are the default model then update the displayed message count
         is_active_model = self == active_model()
         if is_active_model:
-            handle_sample_message_limit(input)
+            set_total_messages(input)
+
+        # check message limit, raise exception if we're already at the limit to prevent
+        # a wasteful generate()
+        conversation_length = len(input) if isinstance(input, list) else 1
+        check_message_limit(conversation_length, raise_for_equal=True)
 
         # base config for this model
         base_config = self.config
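Note: check_message_limit() and check_token_limit() come from the new inspect_ai/util/_limit.py module (file 110 in the list above). A minimal sketch of the semantics this hunk relies on, with an explicit limit parameter added purely for illustration (the real function reads the active limit from context state instead):

    # Sketch only: the real check_message_limit() lives in inspect_ai/util/_limit.py
    # and looks up the active limit itself; here the limit is passed in for clarity.
    class LimitExceededError(Exception):
        pass

    def check_message_limit(count: int, *, limit: int | None, raise_for_equal: bool) -> None:
        if limit is None:
            return
        # raise_for_equal=True treats "already at the limit" as exceeded, so Model.generate()
        # can bail out before paying for a completion whose reply could not be used anyway
        exceeded = count >= limit if raise_for_equal else count > limit
        if exceeded:
            raise LimitExceededError(f"message limit of {limit} exceeded (messages: {count})")

    try:
        check_message_limit(10, limit=10, raise_for_equal=True)
    except LimitExceededError as ex:
        print(ex)  # message limit of 10 exceeded (messages: 10)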
@@ -393,36 +403,32 @@ class Model:
         start_time = datetime.now()
         working_start = sample_working_time()
         async with self._connection_concurrency(config):
-            from inspect_ai.log._samples import track_active_sample_retries
-
             # generate
-            with track_active_sample_retries():
-                output = await self._generate(
-                    input=input,
-                    tools=tools,
-                    tool_choice=tool_choice,
-                    config=config,
-                    cache=cache,
-                )
+            output, event = await self._generate(
+                input=input,
+                tools=tools,
+                tool_choice=tool_choice,
+                config=config,
+                cache=cache,
+            )
 
             # update the most recent ModelEvent with the actual start/completed
             # times as well as a computation of working time (events are
             # created _after_ the call to _generate, potentially in response
             # to retries, so they need their timestamp updated so it accurately
             # reflects the full start/end time which we know here)
-            from inspect_ai.log._transcript import ModelEvent, transcript
-
-            last_model_event = transcript().find_last_event(ModelEvent)
-            if last_model_event:
-                last_model_event.timestamp = start_time
-                last_model_event.working_start = working_start
-                completed = datetime.now()
-                last_model_event.completed = completed
-                last_model_event.working_time = (
-                    output.time
-                    if output.time is not None
-                    else (completed - start_time).total_seconds()
-                )
+            from inspect_ai.log._transcript import ModelEvent
+
+            assert isinstance(event, ModelEvent)
+            event.timestamp = start_time
+            event.working_start = working_start
+            completed = datetime.now()
+            event.completed = completed
+            event.working_time = (
+                output.time
+                if output.time is not None
+                else (completed - start_time).total_seconds()
+            )
 
             # return output
             return output
@@ -483,9 +489,12 @@ class Model:
         tool_choice: ToolChoice | None,
         config: GenerateConfig,
         cache: bool | CachePolicy = False,
-    ) -> ModelOutput:
+    ) -> tuple[ModelOutput, BaseModel]:
+        from inspect_ai.log._samples import track_active_model_event
+        from inspect_ai.log._transcript import ModelEvent
+
         # default to 'auto' for tool_choice (same as underlying model apis)
-        tool_choice = tool_choice if tool_choice else "auto"
+        tool_choice = tool_choice if tool_choice is not None else "auto"
 
         # resolve top level tool source
         if isinstance(tools, ToolSource):
@@ -572,7 +581,10 @@ class Model:
             stop=stop,
             before_sleep=functools.partial(log_model_retry, self.api.model_name),
         )
-        async def generate() -> ModelOutput:
+        async def generate() -> tuple[ModelOutput, BaseModel]:
+            # type-checker can't see that we made sure tool_choice is not none in the outer frame
+            assert tool_choice is not None
+
             check_sample_interrupt()
 
             cache_entry: CacheEntry | None
@@ -593,7 +605,7 @@ class Model:
                 )
                 existing = cache_fetch(cache_entry)
                 if isinstance(existing, ModelOutput):
-                    self._record_model_interaction(
+                    _, event = self._record_model_interaction(
                         input=input,
                         tools=tools_info,
                         tool_choice=tool_choice,
@@ -602,7 +614,7 @@ class Model:
                         output=existing,
                         call=None,
                     )
-                    return existing
+                    return existing, event
                 else:
                     cache_entry = None
 
@@ -611,7 +623,7 @@ class Model:
 
             # record the interaction before the call to generate
             # (we'll update it with the results once we have them)
-            complete = self._record_model_interaction(
+            complete, event = self._record_model_interaction(
                 input=input,
                 tools=tools_info,
                 tool_choice=tool_choice,
@@ -622,12 +634,14 @@ class Model:
             with trace_action(logger, "Model", f"generate ({str(self)})"):
                 time_start = time.monotonic()
                 try:
-                    result = await self.api.generate(
-                        input=input,
-                        tools=tools_info,
-                        tool_choice=tool_choice,
-                        config=config,
-                    )
+                    assert isinstance(event, ModelEvent)
+                    with track_active_model_event(event):
+                        result = await self.api.generate(
+                            input=input,
+                            tools=tools_info,
+                            tool_choice=tool_choice,
+                            config=config,
+                        )
                 finally:
                     time_elapsed = time.monotonic() - time_start
 
@@ -666,7 +680,7 @@ class Model:
             # record usage
             if output.usage:
                 # record usage
-                record_model_usage(f"{self}", output.usage)
+                record_and_check_model_usage(f"{self}", output.usage)
 
                 # send telemetry if its hooked up
                 await send_telemetry(
@@ -677,18 +691,18 @@ class Model:
             if cache and cache_entry:
                 cache_store(entry=cache_entry, output=output)
 
-            return output
+            return output, event
 
         # call the model (this will do retries, etc., so report waiting time
         # as elapsed time - actual time for successful model call)
         time_start = time.monotonic()
-        model_output = await generate()
+        model_output, event = await generate()
         total_time = time.monotonic() - time_start
         if model_output.time:
             report_sample_waiting_time(total_time - model_output.time)
 
         # return results
-        return model_output
+        return model_output, event
 
     def should_retry(self, ex: BaseException) -> bool:
         if isinstance(ex, Exception):
@@ -760,7 +774,7 @@ class Model:
         cache: Literal["read", "write"] | None,
         output: ModelOutput | None = None,
         call: ModelCall | None = None,
-    ) -> Callable[[ModelOutput | Exception, ModelCall | None], None]:
+    ) -> tuple[Callable[[ModelOutput | Exception, ModelCall | None], None], BaseModel]:
         from inspect_ai.log._transcript import ModelEvent, transcript
 
         # create event and add it to the transcript
@@ -800,7 +814,7 @@ class Model:
         if output:
             complete(output, call)
 
-        return complete
+        return complete, event
 
 
 class ModelName:
@@ -1423,20 +1437,10 @@ _model_roles: ContextVar[dict[str, Model]] = ContextVar("model_roles", default={
 
 
 # shared contexts for asyncio tasks
-def handle_sample_message_limit(input: str | list[ChatMessage]) -> None:
-    from inspect_ai.log._samples import (
-        active_sample_message_limit,
-        set_active_sample_total_messages,
-    )
-    from inspect_ai.solver._limit import SampleLimitExceededError
+def set_total_messages(input: str | list[ChatMessage]) -> None:
+    from inspect_ai.log._samples import set_active_sample_total_messages
 
     total_messages = 1 if isinstance(input, str) else len(input)
-    message_limit = active_sample_message_limit()
-    if message_limit is not None:
-        if total_messages >= message_limit:
-            raise SampleLimitExceededError(
-                "message", value=total_messages, limit=message_limit
-            )
 
     # set total messages
     set_active_sample_total_messages(total_messages)
@@ -1450,16 +1454,13 @@ def init_sample_model_usage() -> None:
     sample_model_usage_context_var.set({})
 
 
-def record_model_usage(model: str, usage: ModelUsage) -> None:
-    from inspect_ai.log._samples import (
-        active_sample_token_limit,
-        set_active_sample_total_tokens,
-    )
-    from inspect_ai.solver._limit import SampleLimitExceededError
+def record_and_check_model_usage(model: str, usage: ModelUsage) -> None:
+    from inspect_ai.log._samples import set_active_sample_total_tokens
 
     # record usage
     set_model_usage(model, usage, sample_model_usage_context_var.get(None))
     set_model_usage(model, usage, model_usage_context_var.get(None))
+    record_model_usage(usage)
 
     # compute total tokens
     total_tokens = sample_total_tokens()
@@ -1467,38 +1468,15 @@ def record_model_usage(model: str, usage: ModelUsage) -> None:
     # update active sample
     set_active_sample_total_tokens(total_tokens)
 
-    # check for token limit overflow and raise
-    token_limit = active_sample_token_limit()
-    if token_limit is not None:
-        if total_tokens > token_limit:
-            raise SampleLimitExceededError(
-                "token", value=total_tokens, limit=token_limit
-            )
+    check_token_limit()
 
 
 def set_model_usage(
     model: str, usage: ModelUsage, model_usage: dict[str, ModelUsage] | None
 ) -> None:
     if model_usage is not None:
-        total_usage: ModelUsage | None = model_usage.get(model, None)
-        if not total_usage:
-            total_usage = ModelUsage()
-        total_usage.input_tokens += usage.input_tokens
-        total_usage.output_tokens += usage.output_tokens
-        total_usage.total_tokens += usage.total_tokens
-        if usage.input_tokens_cache_write is not None:
-            if total_usage.input_tokens_cache_write is None:
-                total_usage.input_tokens_cache_write = 0
-            total_usage.input_tokens_cache_write += usage.input_tokens_cache_write
-        if usage.input_tokens_cache_read is not None:
-            if total_usage.input_tokens_cache_read is None:
-                total_usage.input_tokens_cache_read = 0
-            total_usage.input_tokens_cache_read += usage.input_tokens_cache_read
-        if usage.reasoning_tokens is not None:
-            if total_usage.reasoning_tokens is None:
-                total_usage.reasoning_tokens = 0
-            total_usage.reasoning_tokens += usage.reasoning_tokens
-
+        total_usage = model_usage.get(model, ModelUsage())
+        total_usage += usage
         model_usage[model] = total_usage
 
 

inspect_ai/model/_model_output.py
@@ -30,6 +30,31 @@ class ModelUsage(BaseModel):
     reasoning_tokens: int | None = Field(default=None)
     """Number of tokens used for reasoning."""
 
+    def __add__(self, other: "ModelUsage") -> "ModelUsage":
+        def optional_sum(a: int | None, b: int | None) -> int | None:
+            if a is not None and b is not None:
+                return a + b
+            if a is not None:
+                return a
+            if b is not None:
+                return b
+            return None
+
+        return ModelUsage(
+            input_tokens=self.input_tokens + other.input_tokens,
+            output_tokens=self.output_tokens + other.output_tokens,
+            total_tokens=self.total_tokens + other.total_tokens,
+            input_tokens_cache_write=optional_sum(
+                self.input_tokens_cache_write, other.input_tokens_cache_write
+            ),
+            input_tokens_cache_read=optional_sum(
+                self.input_tokens_cache_read, other.input_tokens_cache_read
+            ),
+            reasoning_tokens=optional_sum(
+                self.reasoning_tokens, other.reasoning_tokens
+            ),
+        )
+
 
 StopReason = Literal[
     "stop",

inspect_ai/model/_openai.py
@@ -255,6 +255,8 @@ def openai_completion_params(
                 strict=config.response_schema.strict,
             ),
         )
+    if config.extra_body:
+        params["extra_body"] = config.extra_body
 
     return params
 
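This reads the new extra_body option on GenerateConfig (see the _generate_config.py entry in the file list) and forwards it verbatim to OpenAI-compatible endpoints. A hedged sketch of how it might be used; the body key shown is purely illustrative:

    from inspect_ai.model import GenerateConfig, get_model

    # forward a provider-specific request body parameter (illustrative key/value)
    model = get_model(
        "openai/gpt-4o-mini",
        config=GenerateConfig(extra_body={"my_vendor_param": "value"}),
    )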

inspect_ai/model/_providers/anthropic.py
@@ -26,7 +26,6 @@ from anthropic.types import (
     TextBlockParam,
     ThinkingBlock,
     ThinkingBlockParam,
-    ToolBash20250124Param,
     ToolParam,
     ToolResultBlockParam,
     ToolTextEditor20250124Param,
@@ -76,6 +75,7 @@ class AnthropicAPI(ModelAPI):
         base_url: str | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
+        streaming: bool | Literal["auto"] = "auto",
         **model_args: Any,
     ):
         # extract any service prefix from model name
@@ -85,6 +85,9 @@ class AnthropicAPI(ModelAPI):
         else:
             self.service = None
 
+        # record streaming pref
+        self.streaming = streaming
+
         # collect generate model_args (then delete them so we can pass the rest on)
         def collect_model_arg(name: str) -> Any | None:
             nonlocal model_args
@@ -224,8 +227,13 @@ class AnthropicAPI(ModelAPI):
         if self.extra_body is not None:
             request["extra_body"] = self.extra_body
 
-        # make request (stream if we are using reasoning)
-        if self.is_using_thinking(config):
+        # make request (unless overridden, stream if we are using reasoning)
+        streaming = (
+            self.is_using_thinking(config)
+            if self.streaming == "auto"
+            else self.streaming
+        )
+        if streaming:
             async with self.client.messages.stream(**request) as stream:
                 message = await stream.get_final_message()
         else:
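Because streaming is now an __init__ parameter, it can be supplied as a model argument to pin streaming on or off instead of the "auto" behavior (stream only when extended thinking is enabled). A hedged example; the model name is illustrative:

    from inspect_ai.model import get_model

    # model args are forwarded to AnthropicAPI.__init__ (same mechanism as -M on the CLI)
    model = get_model("anthropic/claude-3-7-sonnet-latest", streaming=False)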
@@ -489,11 +497,7 @@ class AnthropicAPI(ModelAPI):
         self, tool: ToolInfo, config: GenerateConfig
     ) -> Optional["ToolParamDef"]:
         return (
-            (
-                self.computer_use_tool_param(tool)
-                or self.text_editor_tool_param(tool)
-                or self.bash_tool_param(tool)
-            )
+            (self.computer_use_tool_param(tool) or self.text_editor_tool_param(tool))
             if config.internal_tools is not False
             else None
         )
@@ -564,23 +568,10 @@ class AnthropicAPI(ModelAPI):
         else:
             return None
 
-    def bash_tool_param(self, tool: ToolInfo) -> Optional[ToolBash20250124Param]:
-        # check for compatible 'bash' tool
-        if tool.name == "bash_session" and (
-            sorted(tool.parameters.properties.keys()) == sorted(["command", "restart"])
-        ):
-            return ToolBash20250124Param(type="bash_20250124", name="bash")
-        # not a bash tool
-        else:
-            return None
-
 
 # tools can be either a stock tool param or a special Anthropic native use tool param
 ToolParamDef = (
-    ToolParam
-    | BetaToolComputerUse20250124Param
-    | ToolTextEditor20250124Param
-    | ToolBash20250124Param
+    ToolParam | BetaToolComputerUse20250124Param | ToolTextEditor20250124Param
 )
 
 
@@ -589,7 +580,6 @@ def add_cache_control(
     | ToolParam
     | BetaToolComputerUse20250124Param
     | ToolTextEditor20250124Param
-    | ToolBash20250124Param
     | dict[str, Any],
 ) -> None:
     cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}

inspect_ai/model/_providers/hf.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import concurrent
 import concurrent.futures
 import copy
@@ -26,7 +28,12 @@ from transformers import ( # type: ignore
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai._util.content import ContentText
+from inspect_ai._util.content import (
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import ToolChoice, ToolInfo
 
@@ -85,6 +92,7 @@ class HuggingFaceAPI(ModelAPI):
         self.batch_size = collect_model_arg("batch_size")
         self.chat_template = collect_model_arg("chat_template")
         self.tokenizer_call_args = collect_model_arg("tokenizer_call_args")
+        self.enable_thinking = collect_model_arg("enable_thinking")
         if self.tokenizer_call_args is None:
             self.tokenizer_call_args = {}
 
@@ -263,6 +271,7 @@ class HuggingFaceAPI(ModelAPI):
         elif "qwen" in self.model_name.lower():
             hf_messages = inspect_tools_to_string(hf_messages)
 
+        hf_messages = message_content_to_string(hf_messages)
         # apply chat template
         if self.tokenizer.chat_template is not None:
             chat = self.tokenizer.apply_chat_template(
@@ -270,6 +279,7 @@ class HuggingFaceAPI(ModelAPI):
                 add_generation_prompt=True,
                 tokenize=False,
                 tools=tools_list if len(tools_list) > 0 else None,
+                enable_thinking=self.enable_thinking,  # not all models use this, check if it is supported
             )
         else:
             chat = ""
@@ -279,6 +289,22 @@ class HuggingFaceAPI(ModelAPI):
         return cast(str, chat)
 
 
+def message_content_to_string(messages: list[ChatMessage]) -> list[ChatMessage]:
+    """Convert list of content in `ChatMessageAssistant`, `ChatMessageUser` or `ChatMessageSystem` to a string."""
+    for message in messages:
+        if isinstance(message.content, list):
+            is_multimodal = any(
+                isinstance(item, ContentAudio | ContentImage | ContentVideo)
+                for item in message.content
+            )
+            if is_multimodal:
+                raise NotImplementedError(
+                    "HuggingFace provider does not support multimodal content, please provide text inputs only."
+                )
+            message.content = message.text
+    return messages
+
+
 def shorten_tool_id(messages: list[ChatMessage]) -> list[ChatMessage]:
     """Shorten the tool_call_id in the messages to the last 9 characters for Mistral."""
     for i, message in enumerate(messages):

inspect_ai/model/_providers/openai_o1.py
@@ -211,8 +211,15 @@ class O1PreviewChatAPIHandler(ChatAPIHandler):
         This method has an interdependency with `input_with_tools()` (as that is the
         prompt that asks the model to use the <tool_call>...</tool_call> syntax)
         """
-        # extract tool calls
+        # define regex patterns
+        # NOTE: If you change either of these regex patterns, please update the other
+        # tool_call_regex extracts the JSON content (in curly braces) between tool call tags
         tool_call_regex = rf"<{TOOL_CALL}>\s*(\{{[\s\S]*?\}})\s*</{TOOL_CALL}>"
+        # tool_call_content_regex matches the entire tool call block including tags for extracting
+        # the content outside of the tool call tags
+        tool_call_content_regex = rf"<{TOOL_CALL}>\s*\{{[\s\S]*?\}}\s*</{TOOL_CALL}>"
+
+        # extract tool calls
         tool_calls_content: list[str] = re.findall(tool_call_regex, response)
 
         # if there are tool calls proceed with parsing
@@ -226,7 +233,6 @@ class O1PreviewChatAPIHandler(ChatAPIHandler):
             ]
 
             # find other content that exists outside tool calls
-            tool_call_content_regex = rf"<{TOOL_CALL}>(?:.|\n)*?</{TOOL_CALL}>"
             other_content = re.split(tool_call_content_regex, response, flags=re.DOTALL)
             other_content = [
                 str(content).strip()
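The two patterns are meant to stay in sync: tool_call_regex captures the JSON payload inside the tags, while tool_call_content_regex matches the whole tagged block so the surrounding assistant text can be recovered with re.split(). A standalone illustration, assuming TOOL_CALL is the literal tag name "tool_call":

    import re

    TOOL_CALL = "tool_call"  # assumption: the tag name used in the o1 handler prompt

    tool_call_regex = rf"<{TOOL_CALL}>\s*(\{{[\s\S]*?\}})\s*</{TOOL_CALL}>"
    tool_call_content_regex = rf"<{TOOL_CALL}>\s*\{{[\s\S]*?\}}\s*</{TOOL_CALL}>"

    response = (
        "I'll list the files.\n"
        '<tool_call>\n{"name": "bash", "arguments": {"cmd": "ls"}}\n</tool_call>\n'
        "Let me know if you need anything else."
    )

    # JSON payloads inside the tags
    print(re.findall(tool_call_regex, response))
    # ['{"name": "bash", "arguments": {"cmd": "ls"}}']

    # assistant text outside the tags
    print([s.strip() for s in re.split(tool_call_content_regex, response) if s.strip()])
    # ["I'll list the files.", 'Let me know if you need anything else.']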

inspect_ai/model/_providers/providers.py
@@ -136,10 +136,12 @@ def hf() -> type[ModelAPI]:
 
 @modelapi(name="vllm")
 def vllm() -> type[ModelAPI]:
-    try:
-        from .vllm import VLLMAPI
-    except ImportError:
-        raise pip_dependency_error("vLLM Models", ["vllm"])
+    # Only validate OpenAI compatibility (needed for the API interface)
+    validate_openai_client("vLLM API")
+
+    # Import VLLMAPI without checking for vllm package yet
+    # The actual vllm dependency will only be checked if needed to start a server
+    from .vllm import VLLMAPI
 
     return VLLMAPI
 
@@ -257,6 +259,18 @@ def mockllm() -> type[ModelAPI]:
     return MockLLM
 
 
+@modelapi(name="sglang")
+def sglang() -> type[ModelAPI]:
+    # Only validate OpenAI compatibility (needed for the API interface)
+    validate_openai_client("SGLang API")
+
+    # Import SGLangAPI without checking for sglang package yet
+    # The actual sglang dependency will only be checked if needed to start a server
+    from .sglang import SGLangAPI
+
+    return SGLangAPI
+
+
 @modelapi(name="none")
 def none() -> type[ModelAPI]:
     from .none import NoModel
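Both the new sglang provider and the reworked vllm provider defer their heavy server dependency: as the comments above note, the package is only required if inspect has to start a local server itself. A hedged usage sketch; the model names and port are illustrative, and base_url/api_key are the standard model arguments for pointing a provider at an already-running OpenAI-compatible server:

    from inspect_ai.model import get_model

    # talk to an already-running OpenAI-compatible server (no sglang/vllm install needed)
    remote = get_model(
        "sglang/meta-llama/Llama-3.1-8B-Instruct",
        base_url="http://localhost:30000/v1",
        api_key="local",
    )

    # with no server configured, the provider will try to start one locally,
    # which is the point at which the sglang (or vllm) package itself is required
    local = get_model("vllm/meta-llama/Llama-3.1-8B-Instruct")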