inspect-ai 0.3.92__py3-none-any.whl → 0.3.94__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- inspect_ai/_cli/eval.py +27 -0
- inspect_ai/_display/textual/widgets/samples.py +3 -3
- inspect_ai/_display/textual/widgets/transcript.py +3 -29
- inspect_ai/_eval/eval.py +19 -2
- inspect_ai/_eval/evalset.py +4 -1
- inspect_ai/_eval/run.py +41 -0
- inspect_ai/_eval/task/generate.py +38 -44
- inspect_ai/_eval/task/log.py +26 -28
- inspect_ai/_eval/task/run.py +23 -27
- inspect_ai/_util/answer.py +26 -0
- inspect_ai/_util/constants.py +0 -1
- inspect_ai/_util/local_server.py +398 -0
- inspect_ai/_util/working.py +10 -4
- inspect_ai/_view/www/dist/assets/index.css +173 -159
- inspect_ai/_view/www/dist/assets/index.js +1417 -1142
- inspect_ai/_view/www/log-schema.json +379 -3
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/@types/log.d.ts +93 -14
- inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +2 -2
- inspect_ai/_view/www/src/app/content/MetaDataView.module.css +1 -1
- inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +1 -1
- inspect_ai/_view/www/src/app/content/RenderedContent.tsx +1 -1
- inspect_ai/_view/www/src/app/log-view/LogView.tsx +11 -0
- inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +2 -9
- inspect_ai/_view/www/src/app/log-view/tabs/ModelsTab.tsx +51 -0
- inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.module.css +6 -0
- inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.tsx +143 -0
- inspect_ai/_view/www/src/app/plan/ModelCard.tsx +1 -2
- inspect_ai/_view/www/src/app/plan/PlanCard.tsx +29 -7
- inspect_ai/_view/www/src/app/plan/PlanDetailView.module.css +1 -1
- inspect_ai/_view/www/src/app/plan/PlanDetailView.tsx +1 -198
- inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
- inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
- inspect_ai/_view/www/src/app/usage/ModelUsagePanel.tsx +3 -2
- inspect_ai/_view/www/src/app/usage/TokenTable.module.css +4 -1
- inspect_ai/_view/www/src/app/usage/TokenTable.tsx +2 -2
- inspect_ai/_view/www/src/app/usage/UsageCard.module.css +8 -3
- inspect_ai/_view/www/src/app/usage/UsageCard.tsx +1 -35
- inspect_ai/_view/www/src/components/Card.css +0 -1
- inspect_ai/_view/www/src/constants.ts +2 -0
- inspect_ai/_view/www/src/utils/numeric.ts +17 -0
- inspect_ai/agent/_agent.py +3 -3
- inspect_ai/agent/_as_solver.py +22 -12
- inspect_ai/agent/_as_tool.py +20 -6
- inspect_ai/agent/_handoff.py +12 -1
- inspect_ai/agent/_react.py +4 -3
- inspect_ai/agent/_run.py +16 -3
- inspect_ai/agent/_types.py +9 -0
- inspect_ai/dataset/_dataset.py +6 -3
- inspect_ai/log/__init__.py +14 -0
- inspect_ai/log/_convert.py +4 -9
- inspect_ai/log/_file.py +56 -0
- inspect_ai/log/_log.py +99 -0
- inspect_ai/log/_recorders/__init__.py +2 -0
- inspect_ai/log/_recorders/buffer/database.py +12 -11
- inspect_ai/log/_recorders/buffer/filestore.py +2 -2
- inspect_ai/log/_recorders/buffer/types.py +2 -2
- inspect_ai/log/_recorders/eval.py +20 -65
- inspect_ai/log/_recorders/file.py +28 -6
- inspect_ai/log/_recorders/recorder.py +7 -0
- inspect_ai/log/_recorders/types.py +1 -23
- inspect_ai/log/_samples.py +14 -25
- inspect_ai/log/_transcript.py +84 -36
- inspect_ai/log/_tree.py +118 -0
- inspect_ai/log/_util.py +52 -0
- inspect_ai/model/__init__.py +5 -1
- inspect_ai/model/_call_tools.py +72 -44
- inspect_ai/model/_generate_config.py +14 -8
- inspect_ai/model/_model.py +66 -88
- inspect_ai/model/_model_output.py +25 -0
- inspect_ai/model/_openai.py +2 -0
- inspect_ai/model/_providers/anthropic.py +13 -23
- inspect_ai/model/_providers/hf.py +27 -1
- inspect_ai/model/_providers/openai_o1.py +8 -2
- inspect_ai/model/_providers/providers.py +18 -4
- inspect_ai/model/_providers/sglang.py +247 -0
- inspect_ai/model/_providers/vllm.py +211 -400
- inspect_ai/scorer/_choice.py +1 -2
- inspect_ai/solver/__init__.py +7 -2
- inspect_ai/solver/_basic_agent.py +3 -10
- inspect_ai/solver/_chain.py +1 -1
- inspect_ai/solver/_fork.py +1 -1
- inspect_ai/solver/_multiple_choice.py +5 -22
- inspect_ai/solver/_plan.py +2 -2
- inspect_ai/solver/_task_state.py +26 -88
- inspect_ai/solver/_transcript.py +6 -7
- inspect_ai/tool/_json_rpc_helpers.py +45 -17
- inspect_ai/tool/_mcp/_mcp.py +8 -5
- inspect_ai/tool/_mcp/_sandbox.py +8 -2
- inspect_ai/tool/_mcp/server.py +3 -1
- inspect_ai/tool/_tool_call.py +4 -1
- inspect_ai/tool/_tool_support_helpers.py +51 -12
- inspect_ai/tool/_tools/_bash_session.py +190 -68
- inspect_ai/tool/_tools/_computer/_computer.py +25 -1
- inspect_ai/tool/_tools/_execute.py +4 -1
- inspect_ai/tool/_tools/_text_editor.py +4 -3
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +10 -3
- inspect_ai/util/__init__.py +16 -0
- inspect_ai/util/_anyio.py +11 -0
- inspect_ai/util/_collect.py +50 -0
- inspect_ai/util/_limit.py +393 -0
- inspect_ai/util/_limited_conversation.py +57 -0
- inspect_ai/util/_span.py +58 -0
- inspect_ai/util/_subtask.py +27 -42
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/RECORD +120 -134
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/WHEEL +1 -1
- inspect_ai/_display/core/group.py +0 -79
- inspect_ai/solver/_limit.py +0 -39
- inspect_ai/tool/_tools/_computer/_resources/Dockerfile +0 -102
- inspect_ai/tool/_tools/_computer/_resources/README.md +0 -30
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/entrypoint.sh +0 -18
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/novnc_startup.sh +0 -20
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/x11vnc_startup.sh +0 -48
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/xfce_startup.sh +0 -13
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/xvfb_startup.sh +0 -48
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/globalStorage/state.vscdb +0 -0
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +0 -9
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-panel.xml +0 -61
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +0 -91
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Terminal.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +0 -8
- inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +0 -12
- inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +0 -78
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +0 -22
- inspect_ai/tool/_tools/_computer/_resources/tool/_logger.py +0 -22
- inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +0 -42
- inspect_ai/tool/_tools/_computer/_resources/tool/_tool_result.py +0 -33
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +0 -341
- inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +0 -141
- inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +0 -65
- inspect_ai/tool/_tools/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/tool/_tools/_computer/test_args.py +0 -151
- /inspect_ai/{tool/_tools/_computer/_resources/tool/__init__.py → _view/www/src/app/log-view/tabs/ModelsTab.module.css} +0 -0
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/top_level.txt +0 -0
inspect_ai/model/_model.py
CHANGED
@@ -19,6 +19,7 @@ from typing import (
     cast,
 )
 
+from pydantic import BaseModel
 from pydantic_core import to_jsonable_python
 from tenacity import (
     RetryCallState,
@@ -57,6 +58,11 @@ from inspect_ai.tool._tool import ToolSource
 from inspect_ai.tool._tool_call import ToolCallModelInputHints
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
+from inspect_ai.util._limit import (
+    check_message_limit,
+    check_token_limit,
+    record_model_usage,
+)
 
 from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
 from ._call_tools import (
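The generate path now delegates limit bookkeeping to the new inspect_ai/util/_limit.py module (see the check_message_limit, check_token_limit, and record_model_usage imports above). That module is not included in this diff; the sketch below only illustrates the general shape of context-variable based helpers of this kind. Every name, signature, and behavior in it is an assumption for illustration, not the inspect_ai implementation.

# Minimal sketch only -- NOT the inspect_ai implementation.
from contextvars import ContextVar
from dataclasses import dataclass


class LimitExceededError(Exception):
    def __init__(self, kind: str, value: int, limit: int) -> None:
        super().__init__(f"{kind} limit exceeded: {value} (limit {limit})")


@dataclass
class _Limits:
    message_limit: int | None = None
    token_limit: int | None = None
    tokens_used: int = 0


_limits: ContextVar[_Limits] = ContextVar("limits", default=_Limits())


def check_message_limit(count: int, raise_for_equal: bool = False) -> None:
    # raise when the count exceeds the limit (or equals it, if requested)
    limit = _limits.get().message_limit
    if limit is None:
        return
    if count > limit or (raise_for_equal and count == limit):
        raise LimitExceededError("message", count, limit)


def record_model_usage(total_tokens: int) -> None:
    # accumulate usage for the current context (the real helper takes a ModelUsage)
    _limits.get().tokens_used += total_tokens


def check_token_limit() -> None:
    # raise once accumulated usage passes the configured token limit
    state = _limits.get()
    if state.token_limit is not None and state.tokens_used > state.token_limit:
        raise LimitExceededError("token", state.tokens_used, state.token_limit)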
@@ -355,11 +361,15 @@ class Model:
         Returns:
            ModelOutput
         """
-        # if we are the default model then
-        # exists (raise an exception if it is exceeded)
+        # if we are the default model then update the displayed message count
         is_active_model = self == active_model()
         if is_active_model:
-            …
+            set_total_messages(input)
+
+        # check message limit, raise exception if we're already at the limit to prevent
+        # a wasteful generate()
+        conversation_length = len(input) if isinstance(input, list) else 1
+        check_message_limit(conversation_length, raise_for_equal=True)
 
         # base config for this model
         base_config = self.config
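The new pre-check passes raise_for_equal=True so that a conversation which has already reached the message limit fails before another generate() call is paid for. A toy, self-contained illustration of that semantics (not inspect_ai code):

# Toy illustration only: fail fast when the conversation is already at the limit.
def pre_generate_check(conversation: list[str], message_limit: int | None) -> None:
    length = len(conversation)
    # generate() would append at least one more message, so being *at* the limit
    # already means the call would be wasted -- raise before making it
    if message_limit is not None and length >= message_limit:
        raise RuntimeError(f"message limit of {message_limit} reached ({length} messages)")


# with message_limit=2 this raises before any model call is made:
# pre_generate_check(["user: hi", "assistant: hello"], message_limit=2)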
@@ -393,36 +403,32 @@ class Model:
             start_time = datetime.now()
             working_start = sample_working_time()
             async with self._connection_concurrency(config):
-                from inspect_ai.log._samples import track_active_sample_retries
-
                 # generate
-                …
-                )
+                output, event = await self._generate(
+                    input=input,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    config=config,
+                    cache=cache,
+                )
 
                 # update the most recent ModelEvent with the actual start/completed
                 # times as well as a computation of working time (events are
                 # created _after_ the call to _generate, potentially in response
                 # to retries, so they need their timestamp updated so it accurately
                 # reflects the full start/end time which we know here)
-                from inspect_ai.log._transcript import ModelEvent
-                …
-                )
+                from inspect_ai.log._transcript import ModelEvent
+
+                assert isinstance(event, ModelEvent)
+                event.timestamp = start_time
+                event.working_start = working_start
+                completed = datetime.now()
+                event.completed = completed
+                event.working_time = (
+                    output.time
+                    if output.time is not None
+                    else (completed - start_time).total_seconds()
+                )
 
                 # return output
                 return output
@@ -483,9 +489,12 @@ class Model:
         tool_choice: ToolChoice | None,
         config: GenerateConfig,
         cache: bool | CachePolicy = False,
-    ) -> ModelOutput:
+    ) -> tuple[ModelOutput, BaseModel]:
+        from inspect_ai.log._samples import track_active_model_event
+        from inspect_ai.log._transcript import ModelEvent
+
         # default to 'auto' for tool_choice (same as underlying model apis)
-        tool_choice = tool_choice if tool_choice else "auto"
+        tool_choice = tool_choice if tool_choice is not None else "auto"
 
         # resolve top level tool source
         if isinstance(tools, ToolSource):
@@ -572,7 +581,10 @@ class Model:
             stop=stop,
             before_sleep=functools.partial(log_model_retry, self.api.model_name),
         )
-        async def generate() -> ModelOutput:
+        async def generate() -> tuple[ModelOutput, BaseModel]:
+            # type-checker can't see that we made sure tool_choice is not none in the outer frame
+            assert tool_choice is not None
+
             check_sample_interrupt()
 
             cache_entry: CacheEntry | None
@@ -593,7 +605,7 @@ class Model:
                 )
                 existing = cache_fetch(cache_entry)
                 if isinstance(existing, ModelOutput):
-                    self._record_model_interaction(
+                    _, event = self._record_model_interaction(
                         input=input,
                         tools=tools_info,
                         tool_choice=tool_choice,
@@ -602,7 +614,7 @@ class Model:
                         output=existing,
                         call=None,
                     )
-                    return existing
+                    return existing, event
             else:
                 cache_entry = None
 
@@ -611,7 +623,7 @@ class Model:
 
             # record the interaction before the call to generate
             # (we'll update it with the results once we have them)
-            complete = self._record_model_interaction(
+            complete, event = self._record_model_interaction(
                 input=input,
                 tools=tools_info,
                 tool_choice=tool_choice,
@@ -622,12 +634,14 @@ class Model:
             with trace_action(logger, "Model", f"generate ({str(self)})"):
                 time_start = time.monotonic()
                 try:
-                    …
+                    assert isinstance(event, ModelEvent)
+                    with track_active_model_event(event):
+                        result = await self.api.generate(
+                            input=input,
+                            tools=tools_info,
+                            tool_choice=tool_choice,
+                            config=config,
+                        )
                 finally:
                     time_elapsed = time.monotonic() - time_start
 
@@ -666,7 +680,7 @@ class Model:
             # record usage
             if output.usage:
                 # record usage
-                …
+                record_and_check_model_usage(f"{self}", output.usage)
 
                 # send telemetry if its hooked up
                 await send_telemetry(
@@ -677,18 +691,18 @@ class Model:
             if cache and cache_entry:
                 cache_store(entry=cache_entry, output=output)
 
-            return output
+            return output, event
 
         # call the model (this will so retries, etc., so report waiting time
         # as elapsed time - actual time for successful model call)
         time_start = time.monotonic()
-        model_output = await generate()
+        model_output, event = await generate()
         total_time = time.monotonic() - time_start
         if model_output.time:
            report_sample_waiting_time(total_time - model_output.time)
 
         # return results
-        return model_output
+        return model_output, event
 
     def should_retry(self, ex: BaseException) -> bool:
         if isinstance(ex, Exception):
@@ -760,7 +774,7 @@ class Model:
         cache: Literal["read", "write"] | None,
         output: ModelOutput | None = None,
         call: ModelCall | None = None,
-    ) -> Callable[[ModelOutput | Exception, ModelCall | None], None]:
+    ) -> tuple[Callable[[ModelOutput | Exception, ModelCall | None], None], BaseModel]:
        from inspect_ai.log._transcript import ModelEvent, transcript
 
         # create event and add it to the transcript
@@ -800,7 +814,7 @@ class Model:
         if output:
             complete(output, call)
 
-        return complete
+        return complete, event
 
 
 class ModelName:
@@ -1423,20 +1437,10 @@ _model_roles: ContextVar[dict[str, Model]] = ContextVar("model_roles", default={
 
 
 # shared contexts for asyncio tasks
-def …
-    from inspect_ai.log._samples import (
-        active_sample_message_limit,
-        set_active_sample_total_messages,
-    )
-    from inspect_ai.solver._limit import SampleLimitExceededError
+def set_total_messages(input: str | list[ChatMessage]) -> None:
+    from inspect_ai.log._samples import set_active_sample_total_messages
 
     total_messages = 1 if isinstance(input, str) else len(input)
-    message_limit = active_sample_message_limit()
-    if message_limit is not None:
-        if total_messages >= message_limit:
-            raise SampleLimitExceededError(
-                "message", value=total_messages, limit=message_limit
-            )
 
     # set total messages
     set_active_sample_total_messages(total_messages)
@@ -1450,16 +1454,13 @@ def init_sample_model_usage() -> None:
     sample_model_usage_context_var.set({})
 
 
-def record_model_usage(model: str, usage: ModelUsage) -> None:
-    from inspect_ai.log._samples import (
-        active_sample_token_limit,
-        set_active_sample_total_tokens,
-    )
-    from inspect_ai.solver._limit import SampleLimitExceededError
+def record_and_check_model_usage(model: str, usage: ModelUsage) -> None:
+    from inspect_ai.log._samples import set_active_sample_total_tokens
 
     # record usage
     set_model_usage(model, usage, sample_model_usage_context_var.get(None))
     set_model_usage(model, usage, model_usage_context_var.get(None))
+    record_model_usage(usage)
 
     # compute total tokens
     total_tokens = sample_total_tokens()
@@ -1467,38 +1468,15 @@ def record_model_usage(model: str, usage: ModelUsage) -> None:
     # update active sample
     set_active_sample_total_tokens(total_tokens)
 
-
-    token_limit = active_sample_token_limit()
-    if token_limit is not None:
-        if total_tokens > token_limit:
-            raise SampleLimitExceededError(
-                "token", value=total_tokens, limit=token_limit
-            )
+    check_token_limit()
 
 
 def set_model_usage(
     model: str, usage: ModelUsage, model_usage: dict[str, ModelUsage] | None
 ) -> None:
     if model_usage is not None:
-        total_usage …
-
-        total_usage = ModelUsage()
-        total_usage.input_tokens += usage.input_tokens
-        total_usage.output_tokens += usage.output_tokens
-        total_usage.total_tokens += usage.total_tokens
-        if usage.input_tokens_cache_write is not None:
-            if total_usage.input_tokens_cache_write is None:
-                total_usage.input_tokens_cache_write = 0
-            total_usage.input_tokens_cache_write += usage.input_tokens_cache_write
-        if usage.input_tokens_cache_read is not None:
-            if total_usage.input_tokens_cache_read is None:
-                total_usage.input_tokens_cache_read = 0
-            total_usage.input_tokens_cache_read += usage.input_tokens_cache_read
-        if usage.reasoning_tokens is not None:
-            if total_usage.reasoning_tokens is None:
-                total_usage.reasoning_tokens = 0
-            total_usage.reasoning_tokens += usage.reasoning_tokens
-
+        total_usage = model_usage.get(model, ModelUsage())
+        total_usage += usage
         model_usage[model] = total_usage

inspect_ai/model/_model_output.py
CHANGED
@@ -30,6 +30,31 @@ class ModelUsage(BaseModel):
     reasoning_tokens: int | None = Field(default=None)
     """Number of tokens used for reasoning."""
 
+    def __add__(self, other: "ModelUsage") -> "ModelUsage":
+        def optional_sum(a: int | None, b: int | None) -> int | None:
+            if a is not None and b is not None:
+                return a + b
+            if a is not None:
+                return a
+            if b is not None:
+                return b
+            return None
+
+        return ModelUsage(
+            input_tokens=self.input_tokens + other.input_tokens,
+            output_tokens=self.output_tokens + other.output_tokens,
+            total_tokens=self.total_tokens + other.total_tokens,
+            input_tokens_cache_write=optional_sum(
+                self.input_tokens_cache_write, other.input_tokens_cache_write
+            ),
+            input_tokens_cache_read=optional_sum(
+                self.input_tokens_cache_read, other.input_tokens_cache_read
+            ),
+            reasoning_tokens=optional_sum(
+                self.reasoning_tokens, other.reasoning_tokens
+            ),
+        )
+
 
 StopReason = Literal[
     "stop",
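ModelUsage.__add__ is what lets the per-model accumulation in set_model_usage() above collapse to a dict lookup plus +=. A small usage sketch (values invented; assumes ModelUsage is importable from inspect_ai.model):

from inspect_ai.model import ModelUsage  # assumed export location

# two usage records, e.g. from successive generate() calls
first = ModelUsage(input_tokens=100, output_tokens=20, total_tokens=120)
second = ModelUsage(
    input_tokens=50, output_tokens=10, total_tokens=60, reasoning_tokens=5
)

# optional fields (cache/reasoning tokens) stay None unless either side sets them
combined = first + second
assert combined.total_tokens == 180
assert combined.reasoning_tokens == 5

# the accumulation pattern now used by set_model_usage()
usage_by_model: dict[str, ModelUsage] = {}
for model, usage in [("model-a", first), ("model-a", second)]:
    total = usage_by_model.get(model, ModelUsage())
    total += usage
    usage_by_model[model] = total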
inspect_ai/model/_providers/anthropic.py
CHANGED
@@ -26,7 +26,6 @@ from anthropic.types import (
|
|
26
26
|
TextBlockParam,
|
27
27
|
ThinkingBlock,
|
28
28
|
ThinkingBlockParam,
|
29
|
-
ToolBash20250124Param,
|
30
29
|
ToolParam,
|
31
30
|
ToolResultBlockParam,
|
32
31
|
ToolTextEditor20250124Param,
|
@@ -76,6 +75,7 @@ class AnthropicAPI(ModelAPI):
|
|
76
75
|
base_url: str | None = None,
|
77
76
|
api_key: str | None = None,
|
78
77
|
config: GenerateConfig = GenerateConfig(),
|
78
|
+
streaming: bool | Literal["auto"] = "auto",
|
79
79
|
**model_args: Any,
|
80
80
|
):
|
81
81
|
# extract any service prefix from model name
|
@@ -85,6 +85,9 @@ class AnthropicAPI(ModelAPI):
|
|
85
85
|
else:
|
86
86
|
self.service = None
|
87
87
|
|
88
|
+
# record steraming pref
|
89
|
+
self.streaming = streaming
|
90
|
+
|
88
91
|
# collect generate model_args (then delete them so we can pass the rest on)
|
89
92
|
def collect_model_arg(name: str) -> Any | None:
|
90
93
|
nonlocal model_args
|
@@ -224,8 +227,13 @@ class AnthropicAPI(ModelAPI):
|
|
224
227
|
if self.extra_body is not None:
|
225
228
|
request["extra_body"] = self.extra_body
|
226
229
|
|
227
|
-
# make request (stream if we are using reasoning)
|
228
|
-
|
230
|
+
# make request (unless overrideen, stream if we are using reasoning)
|
231
|
+
streaming = (
|
232
|
+
self.is_using_thinking(config)
|
233
|
+
if self.streaming == "auto"
|
234
|
+
else self.streaming
|
235
|
+
)
|
236
|
+
if streaming:
|
229
237
|
async with self.client.messages.stream(**request) as stream:
|
230
238
|
message = await stream.get_final_message()
|
231
239
|
else:
|
@@ -489,11 +497,7 @@ class AnthropicAPI(ModelAPI):
|
|
489
497
|
self, tool: ToolInfo, config: GenerateConfig
|
490
498
|
) -> Optional["ToolParamDef"]:
|
491
499
|
return (
|
492
|
-
(
|
493
|
-
self.computer_use_tool_param(tool)
|
494
|
-
or self.text_editor_tool_param(tool)
|
495
|
-
or self.bash_tool_param(tool)
|
496
|
-
)
|
500
|
+
(self.computer_use_tool_param(tool) or self.text_editor_tool_param(tool))
|
497
501
|
if config.internal_tools is not False
|
498
502
|
else None
|
499
503
|
)
|
@@ -564,23 +568,10 @@ class AnthropicAPI(ModelAPI):
|
|
564
568
|
else:
|
565
569
|
return None
|
566
570
|
|
567
|
-
def bash_tool_param(self, tool: ToolInfo) -> Optional[ToolBash20250124Param]:
|
568
|
-
# check for compatible 'bash' tool
|
569
|
-
if tool.name == "bash_session" and (
|
570
|
-
sorted(tool.parameters.properties.keys()) == sorted(["command", "restart"])
|
571
|
-
):
|
572
|
-
return ToolBash20250124Param(type="bash_20250124", name="bash")
|
573
|
-
# not a bash tool
|
574
|
-
else:
|
575
|
-
return None
|
576
|
-
|
577
571
|
|
578
572
|
# tools can be either a stock tool param or a special Anthropic native use tool param
|
579
573
|
ToolParamDef = (
|
580
|
-
ToolParam
|
581
|
-
| BetaToolComputerUse20250124Param
|
582
|
-
| ToolTextEditor20250124Param
|
583
|
-
| ToolBash20250124Param
|
574
|
+
ToolParam | BetaToolComputerUse20250124Param | ToolTextEditor20250124Param
|
584
575
|
)
|
585
576
|
|
586
577
|
|
@@ -589,7 +580,6 @@ def add_cache_control(
|
|
589
580
|
| ToolParam
|
590
581
|
| BetaToolComputerUse20250124Param
|
591
582
|
| ToolTextEditor20250124Param
|
592
|
-
| ToolBash20250124Param
|
593
583
|
| dict[str, Any],
|
594
584
|
) -> None:
|
595
585
|
cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import concurrent
|
2
4
|
import concurrent.futures
|
3
5
|
import copy
|
@@ -26,7 +28,12 @@ from transformers import ( # type: ignore
|
|
26
28
|
from typing_extensions import override
|
27
29
|
|
28
30
|
from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
|
29
|
-
from inspect_ai._util.content import
|
31
|
+
from inspect_ai._util.content import (
|
32
|
+
ContentAudio,
|
33
|
+
ContentImage,
|
34
|
+
ContentText,
|
35
|
+
ContentVideo,
|
36
|
+
)
|
30
37
|
from inspect_ai._util.trace import trace_action
|
31
38
|
from inspect_ai.tool import ToolChoice, ToolInfo
|
32
39
|
|
@@ -85,6 +92,7 @@ class HuggingFaceAPI(ModelAPI):
|
|
85
92
|
self.batch_size = collect_model_arg("batch_size")
|
86
93
|
self.chat_template = collect_model_arg("chat_template")
|
87
94
|
self.tokenizer_call_args = collect_model_arg("tokenizer_call_args")
|
95
|
+
self.enable_thinking = collect_model_arg("enable_thinking")
|
88
96
|
if self.tokenizer_call_args is None:
|
89
97
|
self.tokenizer_call_args = {}
|
90
98
|
|
@@ -263,6 +271,7 @@ class HuggingFaceAPI(ModelAPI):
|
|
263
271
|
elif "qwen" in self.model_name.lower():
|
264
272
|
hf_messages = inspect_tools_to_string(hf_messages)
|
265
273
|
|
274
|
+
hf_messages = message_content_to_string(hf_messages)
|
266
275
|
# apply chat template
|
267
276
|
if self.tokenizer.chat_template is not None:
|
268
277
|
chat = self.tokenizer.apply_chat_template(
|
@@ -270,6 +279,7 @@ class HuggingFaceAPI(ModelAPI):
|
|
270
279
|
add_generation_prompt=True,
|
271
280
|
tokenize=False,
|
272
281
|
tools=tools_list if len(tools_list) > 0 else None,
|
282
|
+
enable_thinking=self.enable_thinking, # not all models use this, check if it is supported
|
273
283
|
)
|
274
284
|
else:
|
275
285
|
chat = ""
|
@@ -279,6 +289,22 @@ class HuggingFaceAPI(ModelAPI):
|
|
279
289
|
return cast(str, chat)
|
280
290
|
|
281
291
|
|
292
|
+
def message_content_to_string(messages: list[ChatMessage]) -> list[ChatMessage]:
|
293
|
+
"""Convert list of content in `ChatMessageAssistant`, `ChatMessageUser` or `ChatMessageSystem` to a string."""
|
294
|
+
for message in messages:
|
295
|
+
if isinstance(message.content, list):
|
296
|
+
is_multimodal = any(
|
297
|
+
isinstance(item, ContentAudio | ContentImage | ContentVideo)
|
298
|
+
for item in message.content
|
299
|
+
)
|
300
|
+
if is_multimodal:
|
301
|
+
raise NotImplementedError(
|
302
|
+
"HuggingFace provider does not support multimodal content, please provide text inputs only."
|
303
|
+
)
|
304
|
+
message.content = message.text
|
305
|
+
return messages
|
306
|
+
|
307
|
+
|
282
308
|
def shorten_tool_id(messages: list[ChatMessage]) -> list[ChatMessage]:
|
283
309
|
"""Shorten the tool_call_id in the messages to the last 9 characters for Mistral."""
|
284
310
|
for i, message in enumerate(messages):
|
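Both provider options added above (Anthropic streaming and Hugging Face enable_thinking) are ordinary constructor arguments collected from model args, so they can presumably be supplied when the model is created. A sketch, with illustrative model names and assuming keyword arguments are passed through to the provider by get_model():

from inspect_ai.model import get_model

# Anthropic: force streaming on or off (the default "auto" streams only when
# extended thinking is enabled)
claude = get_model("anthropic/claude-3-7-sonnet-latest", streaming=False)

# Hugging Face: forward enable_thinking to tokenizer.apply_chat_template()
# (only honored by chat templates that support it, e.g. some Qwen models)
qwen = get_model("hf/Qwen/Qwen3-4B", enable_thinking=True)

# CLI equivalent for model args (flag form assumed):
#   inspect eval task.py --model hf/Qwen/Qwen3-4B -M enable_thinking=true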
inspect_ai/model/_providers/openai_o1.py
CHANGED
@@ -211,8 +211,15 @@ class O1PreviewChatAPIHandler(ChatAPIHandler):
         This method has an interdependency with `input_with_tools()` (as that is the
         prompt that asks the model to use the <tool_call>...</tool_call> syntax)
         """
-        # …
+        # define regex patterns
+        # NOTE: If you change either of these regex patterns, please update the other
+        # tool_call_regex extracts the JSON content (in curly braces) between tool call tags
         tool_call_regex = rf"<{TOOL_CALL}>\s*(\{{[\s\S]*?\}})\s*</{TOOL_CALL}>"
+        # tool_call_content_regex matches the entire tool call block including tags for extracting
+        # the content outside of the tool call tags
+        tool_call_content_regex = rf"<{TOOL_CALL}>\s*\{{[\s\S]*?\}}\s*</{TOOL_CALL}>"
+
+        # extract tool calls
         tool_calls_content: list[str] = re.findall(tool_call_regex, response)
 
         # if there are tool calls proceed with parsing
@@ -226,7 +233,6 @@ class O1PreviewChatAPIHandler(ChatAPIHandler):
         ]
 
         # find other content that exists outside tool calls
-        tool_call_content_regex = rf"<{TOOL_CALL}>(?:.|\n)*?</{TOOL_CALL}>"
         other_content = re.split(tool_call_content_regex, response, flags=re.DOTALL)
         other_content = [
             str(content).strip()

inspect_ai/model/_providers/providers.py
CHANGED
@@ -136,10 +136,12 @@ def hf() -> type[ModelAPI]:
 
 @modelapi(name="vllm")
 def vllm() -> type[ModelAPI]:
-    …
+    # Only validate OpenAI compatibility (needed for the API interface)
+    validate_openai_client("vLLM API")
+
+    # Import VLLMAPI without checking for vllm package yet
+    # The actual vllm dependency will only be checked if needed to start a server
+    from .vllm import VLLMAPI
 
     return VLLMAPI
 
@@ -257,6 +259,18 @@ def mockllm() -> type[ModelAPI]:
     return MockLLM
 
 
+@modelapi(name="sglang")
+def sglang() -> type[ModelAPI]:
+    # Only validate OpenAI compatibility (needed for the API interface)
+    validate_openai_client("SGLang API")
+
+    # Import SGLangAPI without checking for sglang package yet
+    # The actual sglang dependency will only be checked if needed to start a server
+    from .sglang import SGLangAPI
+
+    return SGLangAPI
+
+
 @modelapi(name="none")
 def none() -> type[ModelAPI]:
     from .none import NoModel