inspect-ai 0.3.59__py3-none-any.whl → 0.3.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. inspect_ai/_cli/eval.py +0 -7
  2. inspect_ai/_display/textual/widgets/samples.py +1 -1
  3. inspect_ai/_eval/eval.py +10 -1
  4. inspect_ai/_eval/loader.py +79 -19
  5. inspect_ai/_eval/registry.py +6 -0
  6. inspect_ai/_eval/score.py +2 -1
  7. inspect_ai/_eval/task/results.py +6 -5
  8. inspect_ai/_eval/task/run.py +11 -11
  9. inspect_ai/_view/www/dist/assets/index.js +262 -303
  10. inspect_ai/_view/www/src/App.mjs +6 -6
  11. inspect_ai/_view/www/src/Types.mjs +1 -1
  12. inspect_ai/_view/www/src/api/Types.ts +133 -0
  13. inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
  14. inspect_ai/_view/www/src/api/api-http.ts +219 -0
  15. inspect_ai/_view/www/src/api/api-shared.ts +47 -0
  16. inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
  17. inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
  18. inspect_ai/_view/www/src/api/index.ts +51 -0
  19. inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
  20. inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
  21. inspect_ai/_view/www/src/index.js +2 -2
  22. inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
  23. inspect_ai/_view/www/src/navbar/Navbar.mjs +1 -1
  24. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +1 -1
  25. inspect_ai/_view/www/src/samples/SampleList.mjs +1 -1
  26. inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
  27. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +14 -14
  28. inspect_ai/_view/www/src/samples/SamplesTab.mjs +10 -10
  29. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
  30. inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +1 -3
  31. inspect_ai/_view/www/src/utils/vscode.ts +36 -0
  32. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
  33. inspect_ai/approval/_human/manager.py +1 -1
  34. inspect_ai/model/_call_tools.py +55 -0
  35. inspect_ai/model/_conversation.py +1 -4
  36. inspect_ai/model/_generate_config.py +2 -8
  37. inspect_ai/model/_model_output.py +15 -0
  38. inspect_ai/model/_openai.py +383 -0
  39. inspect_ai/model/_providers/anthropic.py +52 -11
  40. inspect_ai/model/_providers/azureai.py +1 -1
  41. inspect_ai/model/_providers/goodfire.py +248 -0
  42. inspect_ai/model/_providers/groq.py +7 -3
  43. inspect_ai/model/_providers/hf.py +6 -0
  44. inspect_ai/model/_providers/mistral.py +2 -1
  45. inspect_ai/model/_providers/openai.py +36 -202
  46. inspect_ai/model/_providers/openai_o1.py +2 -4
  47. inspect_ai/model/_providers/providers.py +22 -0
  48. inspect_ai/model/_providers/together.py +4 -4
  49. inspect_ai/model/_providers/util/__init__.py +2 -3
  50. inspect_ai/model/_providers/util/hf_handler.py +1 -1
  51. inspect_ai/model/_providers/util/llama31.py +1 -1
  52. inspect_ai/model/_providers/util/util.py +0 -76
  53. inspect_ai/scorer/_metric.py +3 -0
  54. inspect_ai/scorer/_scorer.py +2 -1
  55. inspect_ai/solver/__init__.py +2 -0
  56. inspect_ai/solver/_basic_agent.py +1 -1
  57. inspect_ai/solver/_bridge/__init__.py +3 -0
  58. inspect_ai/solver/_bridge/bridge.py +100 -0
  59. inspect_ai/solver/_bridge/patch.py +170 -0
  60. inspect_ai/solver/_solver.py +6 -0
  61. inspect_ai/util/_display.py +5 -0
  62. inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
  63. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
  64. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +68 -63
  65. inspect_ai/_view/www/src/api/Types.mjs +0 -117
  66. inspect_ai/_view/www/src/api/api-http.mjs +0 -300
  67. inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
  68. inspect_ai/_view/www/src/api/index.mjs +0 -49
  69. inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
  70. inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
  71. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
  72. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
  73. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
  74. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/together.py
@@ -24,13 +24,13 @@ from .._model_output import (
     ModelOutput,
     ModelUsage,
     StopReason,
+    as_stop_reason,
 )
+from .._openai import chat_message_assistant_from_openai
 from .openai import (
     OpenAIAPI,
-    chat_message_assistant,
 )
 from .util import (
-    as_stop_reason,
     chat_api_input,
     chat_api_request,
     environment_prerequisite_error,
@@ -68,7 +68,7 @@ def chat_choices_from_response_together(
         logprobs_models.append(Logprobs(content=logprobs_sequence))
     return [
         ChatCompletionChoice(
-            message=chat_message_assistant(choice.message, tools),
+            message=chat_message_assistant_from_openai(choice.message, tools),
             stop_reason=as_stop_reason(choice.finish_reason),
             logprobs=logprobs,
         )
@@ -99,7 +99,7 @@ class TogetherAIAPI(OpenAIAPI):
 
     # Together uses a default of 512 so we bump it up
     @override
-    def max_tokens(self) -> int:
+    def max_tokens(self) -> int | None:
         return DEFAULT_MAX_TOKENS
 
     @override
inspect_ai/model/_providers/util/__init__.py
@@ -1,3 +1,5 @@
+from ..._call_tools import parse_tool_call, tool_parse_error_message
+from ..._model_output import as_stop_reason
 from .chatapi import (
     ChatAPIHandler,
     ChatAPIMessage,
@@ -8,11 +10,8 @@ from .chatapi import (
 from .hf_handler import HFHandler
 from .llama31 import Llama31Handler
 from .util import (
-    as_stop_reason,
     environment_prerequisite_error,
     model_base_url,
-    parse_tool_call,
-    tool_parse_error_message,
 )
 
 __all__ = [
inspect_ai/model/_providers/util/hf_handler.py
@@ -8,9 +8,9 @@ from typing_extensions import override
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_info import ToolInfo
 
+from ..._call_tools import parse_tool_call, tool_parse_error_message
 from ..._chat_message import ChatMessageAssistant
 from .chatapi import ChatAPIHandler
-from .util import parse_tool_call, tool_parse_error_message
 
 logger = getLogger(__name__)
 
inspect_ai/model/_providers/util/llama31.py
@@ -9,6 +9,7 @@ from typing_extensions import override
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_info import ToolInfo
 
+from ..._call_tools import parse_tool_call, tool_parse_error_message
 from ..._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -16,7 +17,6 @@ from ..._chat_message import (
     ChatMessageTool,
 )
 from .chatapi import ChatAPIHandler, ChatAPIMessage
-from .util import parse_tool_call, tool_parse_error_message
 
 logger = getLogger(__name__)
 
inspect_ai/model/_providers/util/util.py
@@ -1,34 +1,11 @@
-import json
 import os
 from logging import getLogger
-from typing import Any
-
-import yaml
 
 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai.tool._tool_call import ToolCall
-from inspect_ai.tool._tool_info import ToolInfo
-
-from ..._model_output import StopReason
 
 logger = getLogger(__name__)
 
 
-def as_stop_reason(reason: str | None) -> StopReason:
-    """Encode common reason strings into standard StopReason."""
-    match reason:
-        case "stop" | "eos":
-            return "stop"
-        case "length":
-            return "max_tokens"
-        case "tool_calls" | "function_call":
-            return "tool_calls"
-        case "content_filter" | "model_length" | "max_tokens":
-            return reason
-        case _:
-            return "unknown"
-
-
 def model_base_url(base_url: str | None, env_vars: str | list[str]) -> str | None:
     if base_url:
         return base_url
@@ -44,59 +21,6 @@ def model_base_url(base_url: str | None, env_vars: str | list[str]) -> str | None:
     return os.getenv("INSPECT_EVAL_MODEL_BASE_URL", None)
 
 
-def tool_parse_error_message(arguments: str, ex: Exception) -> str:
-    return f"Error parsing the following tool call arguments:\n\n{arguments}\n\nError details: {ex}"
-
-
-def parse_tool_call(
-    id: str, function: str, arguments: str, tools: list[ToolInfo]
-) -> ToolCall:
-    error: str | None = None
-    arguments_dict: dict[str, Any] = {}
-
-    def report_parse_error(ex: Exception) -> None:
-        nonlocal error
-        error = tool_parse_error_message(arguments, ex)
-        logger.info(error)
-
-    # if the arguments is a dict, then handle it with a plain json.loads
-    arguments = arguments.strip()
-    if arguments.startswith("{"):
-        try:
-            arguments_dict = json.loads(arguments)
-        except json.JSONDecodeError as ex:
-            report_parse_error(ex)
-
-    # otherwise parse it as yaml (which will pickup unquoted strings, numbers, and true/false)
-    # and then create a dict that maps it to the first function argument
-    else:
-        tool_info = next(
-            (
-                tool
-                for tool in tools
-                if tool.name == function and len(tool.parameters.properties) > 0
-            ),
-            None,
-        )
-        if tool_info:
-            param_names = list(tool_info.parameters.properties.keys())
-            try:
-                value = yaml.safe_load(arguments)
-                arguments_dict[param_names[0]] = value
-            except yaml.error.YAMLError:
-                # If the yaml parser fails, we treat it as a string argument.
-                arguments_dict[param_names[0]] = arguments
-
-    # return ToolCall with error payload
-    return ToolCall(
-        id=id,
-        function=function,
-        arguments=arguments_dict,
-        type="function",
-        parse_error=error,
-    )
-
-
 def environment_prerequisite_error(
     client: str, env_vars: str | list[str]
 ) -> PrerequisiteError:
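The helpers removed here now live in inspect_ai/model/_call_tools.py (see the import changes above; that file gains 55 lines in this release). For reference, the relocated parser accepts either a JSON object or a bare scalar as tool-call arguments, mapping a scalar onto the tool's first declared parameter. Below is a standalone sketch of just that fallback logic, with plain values standing in for the library's ToolInfo machinery (parse_arguments and first_param are illustrative names, not the library API):

import json
from typing import Any

import yaml


def parse_arguments(arguments: str, first_param: str) -> dict[str, Any]:
    # arguments that look like a JSON object parse directly to a dict
    arguments = arguments.strip()
    if arguments.startswith("{"):
        return json.loads(arguments)
    # bare scalars ("Paris", "42", "true") are YAML-parsed and assigned
    # to the tool's first declared parameter; unparseable input falls
    # back to a plain string
    try:
        return {first_param: yaml.safe_load(arguments)}
    except yaml.error.YAMLError:
        return {first_param: arguments}


print(parse_arguments('{"city": "Paris"}', "city"))  # {'city': 'Paris'}
print(parse_arguments("Paris", "city"))              # {'city': 'Paris'}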
inspect_ai/scorer/_metric.py
@@ -125,6 +125,9 @@ class SampleScore(BaseModel):
     sample_id: str | int | None = Field(default=None)
     """A sample id"""
 
+    scorer: str | None = Field(default=None)
+    """Registry name of scorer that created this score."""
+
 
 ValueToFloat = Callable[[Value], float]
 """Function used by metrics to translate from a Score value to a float value."""
inspect_ai/scorer/_scorer.py
@@ -1,3 +1,4 @@
+from functools import wraps
 from typing import (
     Any,
     Callable,
@@ -100,7 +101,6 @@ def scorer(
 
     Returns:
         Scorer with registry attributes.
-
    """

    def wrapper(scorer_type: Callable[P, Scorer]) -> Callable[P, Scorer]:
@@ -110,6 +110,7 @@
         )
 
         # wrap instantiations of scorer so they carry registry info and metrics
+        @wraps(scorer_type)
         def scorer_wrapper(*args: P.args, **kwargs: P.kwargs) -> Scorer:
             scorer = scorer_type(*args, **kwargs)
 
inspect_ai/solver/__init__.py
@@ -1,6 +1,7 @@
 from inspect_ai._util.deprecation import relocated_module_attribute
 
 from ._basic_agent import basic_agent
+from ._bridge import bridge
 from ._chain import chain
 from ._critique import self_critique
 from ._fork import fork
@@ -14,6 +15,7 @@ from ._use_tools import use_tools
 
 __all__ = [
     "basic_agent",
+    "bridge",
     "human_agent",
     "chain",
     "fork",
inspect_ai/solver/_basic_agent.py
@@ -119,7 +119,7 @@ def basic_agent(
     # resolve tools
     if tools is None:
         tools = []
-    tools = tools if isinstance(tools, Solver) else use_tools(tools)
+    tools = tools if isinstance(tools, Solver) else use_tools(tools, append=True)
 
     # resolve score_value function
     score_value_fn = score_value or value_to_float()
inspect_ai/solver/_bridge/__init__.py
@@ -0,0 +1,3 @@
+from .bridge import bridge
+
+__all__ = ["bridge"]
inspect_ai/solver/_bridge/bridge.py
@@ -0,0 +1,100 @@
+from typing import Any, Awaitable, Callable
+
+from jsonschema import Draft7Validator
+from pydantic import BaseModel, Field, ValidationError
+from pydantic_core import to_json
+
+from inspect_ai._util._async import is_callable_coroutine
+from inspect_ai.model._chat_message import ChatMessage, ChatMessageUser
+from inspect_ai.model._providers.providers import validate_openai_client
+from inspect_ai.scorer._metric import Score
+
+from .._solver import Generate, Solver, solver
+from .._task_state import TaskState
+
+
+@solver
+def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Solver:
+    """Bridge an external agent into an Inspect Solver.
+
+    See documentation at https://inspect.ai-safety-institute.org.uk/agent-bridge.html
+
+    Args:
+        agent: Callable which takes a sample `dict` and returns a result `dict`.
+
+    Returns:
+        Standard Inspect solver.
+    """
+    validate_openai_client("Solver bridge()")
+
+    from openai.types.chat import ChatCompletionMessageParam
+
+    from inspect_ai.model._openai import (
+        chat_messages_from_openai,
+        openai_chat_messages,
+    )
+
+    from .patch import openai_request_to_inspect_model
+
+    class BridgeSample(BaseModel):
+        sample_id: str
+        epoch: int
+        input: list[ChatCompletionMessageParam]
+        metadata: dict[str, Any]
+        target: list[str]
+
+    class BridgeResult(BaseModel):
+        output: str
+        messages: list[ChatCompletionMessageParam] | None = Field(default=None)
+        scores: dict[str, Score] | None = Field(default=None)
+
+    result_schema = BridgeResult.model_json_schema()
+    result_validator = Draft7Validator(result_schema)
+
+    # validate that the agent is an async function
+    if not is_callable_coroutine(agent):
+        raise TypeError(f"'{agent.__name__}' is not declared as an async callable.")
+
+    async def solve(state: TaskState, generate: Generate) -> TaskState:
+        # resolve input to array
+        input: list[ChatMessage] = (
+            [ChatMessageUser(content=state.input)]
+            if isinstance(state.input, str)
+            else state.input
+        )
+
+        # create sample
+        sample = BridgeSample(
+            sample_id=str(state.sample_id),
+            epoch=state.epoch,
+            input=await openai_chat_messages(input, state.model.name),
+            metadata=state.metadata,
+            target=list(state.target),
+        )
+
+        # run target function
+        async with openai_request_to_inspect_model():
+            # call the function
+            result_dict = await agent(sample.model_dump())
+            try:
+                result = BridgeResult.model_validate(result_dict)
+            except ValidationError:
+                # if we fail to validate provide a better human readable error
+                errors = list(result_validator.iter_errors(result_dict))
+                message = "\n".join(
+                    ["Result returned from bridged solver is not valid:"]
+                    + [f" - {error.message}" for error in errors]
+                    + ["", to_json(result_dict, indent=2).decode()]
+                )
+                raise ValueError(message)
+
+        # update and return state
+        state.output.completion = result.output
+        if result.messages is not None:
+            state.messages = chat_messages_from_openai(result.messages)
+        if result.scores is not None:
+            state.scores = result.scores
+
+        return state
+
+    return solve
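The docstring above gives the agent contract only in outline, but the BridgeSample and BridgeResult models pin it down: the agent coroutine receives a dict with sample_id, epoch, input (OpenAI-format chat messages), metadata, and target, and must return a dict with at least an output string (optionally also messages and scores). A minimal sketch of wiring a custom agent into a task via the new solver (my_agent and my_task are hypothetical names):

from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.scorer import match
from inspect_ai.solver import bridge


async def my_agent(sample: dict) -> dict:
    # sample["input"] is a list of OpenAI-format chat messages
    question = sample["input"][-1]["content"]
    # ... hand off to any external agent framework here ...
    return {"output": f"echo: {question}"}


@task
def my_task():
    return Task(
        dataset=[Sample(input="Say hello", target="hello")],
        solver=bridge(my_agent),
        scorer=match(),
    )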
inspect_ai/solver/_bridge/patch.py
@@ -0,0 +1,170 @@
+import contextlib
+import re
+from contextvars import ContextVar
+from functools import wraps
+from time import time
+from typing import Any, AsyncGenerator, Optional, Type, cast
+
+from openai._base_client import AsyncAPIClient, _AsyncStreamT
+from openai._models import FinalRequestOptions
+from openai._types import ResponseT
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionMessageParam,
+    ChatCompletionToolParam,
+)
+from shortuuid import uuid
+
+from inspect_ai.model._generate_config import GenerateConfig
+from inspect_ai.model._model import get_model
+from inspect_ai.model._openai import (
+    chat_messages_from_openai,
+    openai_chat_choices,
+    openai_completion_usage,
+)
+from inspect_ai.solver._task_state import sample_state
+from inspect_ai.tool._tool_info import ToolInfo
+from inspect_ai.tool._tool_params import ToolParams
+
+
+@contextlib.asynccontextmanager
+async def openai_request_to_inspect_model() -> AsyncGenerator[None, None]:
+    # ensure one time init
+    init_openai_request_patch()
+
+    # set the patch enabled for this context and child coroutines
+    token = _patch_enabled.set(True)
+    try:
+        yield
+    finally:
+        _patch_enabled.reset(token)
+
+
+_patch_initialised: bool = False
+
+_patch_enabled: ContextVar[bool] = ContextVar(
+    "openai_request_patch_enabled", default=False
+)
+
+
+def init_openai_request_patch() -> None:
+    global _patch_initialised
+    if not _patch_initialised:
+        # get reference to original method
+        original_request = getattr(AsyncAPIClient, "request")
+        if original_request is None:
+            raise RuntimeError("Couldn't find 'request' method on AsyncAPIClient")
+
+        @wraps(original_request)
+        async def patched_request(
+            self: AsyncAPIClient,
+            cast_to: Type[ResponseT],
+            options: FinalRequestOptions,
+            *,
+            stream: bool = False,
+            stream_cls: type[_AsyncStreamT] | None = None,
+            remaining_retries: Optional[int] = None,
+        ) -> Any:
+            # we have patched the underlying request method so now need to figure out when to
+            # patch and when to stand down
+            if (
+                # enabled for this coroutine
+                _patch_enabled.get()
+                # completions request
+                and options.url == "/chat/completions"
+                # call to openai not another service (e.g. TogetherAI)
+                and self.base_url == "https://api.openai.com/v1/"
+            ):
+                # must also be an explicit request for an inspect model
+                json_data = cast(dict[str, Any], options.json_data)
+                model_name = str(json_data["model"])
+                if re.match(r"^inspect/?", model_name):
+                    return await inspect_model_request(model_name, options)
+
+            # otherwise just delegate
+            return await original_request(
+                self,
+                cast_to,
+                options,
+                stream=stream,
+                stream_cls=stream_cls,
+                remaining_retries=remaining_retries,
+            )
+
+        setattr(AsyncAPIClient, "request", patched_request)
+
+
+async def inspect_model_request(
+    model_name: str, options: FinalRequestOptions
+) -> ChatCompletion:
+    # convert openai messages to inspect messages
+    json_data = cast(dict[str, Any], options.json_data)
+    messages: list[ChatCompletionMessageParam] = json_data["messages"]
+    input = chat_messages_from_openai(messages)
+
+    # convert openai tools to inspect tools
+    tools: list[ChatCompletionToolParam] = json_data.get("tools", [])
+    inspect_tools: list[ToolInfo] = []
+    for tool in tools:
+        function = tool["function"].copy()
+        inspect_tools.append(
+            ToolInfo(
+                name=function["name"],
+                description=function["description"],
+                parameters=ToolParams.model_validate(function["parameters"]),
+            )
+        )
+
+    # resolve model
+    if model_name == "inspect":
+        model = get_model()
+    else:
+        model = get_model(model_name.removeprefix("inspect/"))
+
+    output = await model.generate(
+        input=input,
+        tools=inspect_tools,
+        config=generate_config_from_openai(options),
+    )
+
+    # if we are using the "default" inspect model for the task, update state.messages
+    if model_name == "inspect":
+        state = sample_state()
+        if state:
+            state.messages = input + [output.choices[0].message]
+
+    # inspect completion to openai completion
+    return ChatCompletion(
+        id=uuid(),
+        created=int(time()),
+        object="chat.completion",
+        choices=openai_chat_choices(output.choices),
+        model=model_name,
+        usage=openai_completion_usage(output.usage) if output.usage else None,
+    )
+
+
+def generate_config_from_openai(options: FinalRequestOptions) -> GenerateConfig:
+    # get options dict
+    json_data = cast(dict[str, Any], options.json_data)
+
+    config = GenerateConfig()
+    config.max_tokens = json_data.get(
+        "max_completion_tokens", json_data.get("max_tokens", None)
+    )
+    config.top_p = json_data.get("top_p", None)
+    config.temperature = json_data.get("temperature", None)
+    stop = json_data.get("stop", None)
+    if stop:
+        config.stop_seqs = [stop] if isinstance(stop, str) else stop
+    config.frequency_penalty = json_data.get("frequency_penalty", None)
+    config.presence_penalty = json_data.get("presence_penalty", None)
+    config.seed = json_data.get("seed", None)
+    config.num_choices = json_data.get("n", None)
+    config.logprobs = json_data.get("logprobs", None)
+    config.top_logprobs = json_data.get("top_logprobs", None)
+    config.logit_bias = json_data.get("logit_bias", None)
+    config.parallel_tool_calls = json_data.get("parallel_tool_calls", None)
+    config.reasoning_effort = json_data.get("reasoning_effort", None)
+
+    return config
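This patch is what makes bridged agents model-agnostic: within the openai_request_to_inspect_model() context, any chat completion request addressed to api.openai.com whose model name matches ^inspect/? is diverted to Inspect's own model API rather than sent over the wire. So an agent bridged via bridge() can use the stock OpenAI client unchanged; a minimal sketch (the api_key placeholder is arbitrary, since the intercepted request never leaves the process):

from openai import AsyncOpenAI


async def my_agent(sample: dict) -> dict:
    client = AsyncOpenAI(api_key="unused")
    completion = await client.chat.completions.create(
        # "inspect" targets the task's default model; "inspect/<provider>/<model>"
        # resolves an explicit model via get_model()
        model="inspect",
        messages=sample["input"],
    )
    return {"output": completion.choices[0].message.content or ""}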
inspect_ai/solver/_solver.py
@@ -180,6 +180,7 @@ def solver(
             solver_type, name if name else getattr(solver_type, "__name__")
         )
 
+        @wraps(solver_type)
         def solver_wrapper(*args: P.args, **kwargs: P.kwargs) -> Solver:
             solver = solver_type(*args, **kwargs)
 
@@ -193,6 +194,7 @@
             if inspect.isclass(type(solver)):
                 original_call = solver.__call__
 
+                @wraps(original_call)
                 async def call_with_state(
                     state: TaskState, generate: Generate
                 ) -> TaskState:
@@ -225,6 +227,10 @@
 
            return registered_solver
 
+        # functools.wraps overrides the return type annotation of the inner function, so
+        # we explicitly set it again
+        solver_wrapper.__annotations__["return"] = Solver
+
         return solver_register(cast(Callable[P, Solver], solver_wrapper), solver_name)
 
    # for decorators with an explicit name, one more wrapper for the name
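The comment added in the last hunk describes real functools behavior worth spelling out: @wraps copies __annotations__ wholesale from the wrapped function, which clobbers the wrapper's own declared return type. A standalone demonstration (make_greeting and wrapper are illustrative names):

from functools import wraps


def make_greeting(name: str) -> str:
    return f"hello {name}"


@wraps(make_greeting)
def wrapper(*args, **kwargs) -> bytes:
    return make_greeting(*args, **kwargs).encode()


# @wraps copied make_greeting's annotations, replacing the wrapper's
# declared return type of bytes with str
print(wrapper.__annotations__["return"])  # <class 'str'>

# the fix applied in the diff: explicitly restore the intended annotation
wrapper.__annotations__["return"] = bytes
print(wrapper.__annotations__["return"])  # <class 'bytes'>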
inspect_ai/util/_display.py
@@ -49,3 +49,8 @@ def display_type() -> DisplayType:
         return _display_type
     else:
         return init_display_type()
+
+
+def display_type_initialized() -> bool:
+    global _display_type
+    return _display_type is not None
inspect_ai/util/_sandbox/docker/prereqs.py
@@ -57,7 +57,7 @@ async def validate_docker_compose(
     version: str = DOCKER_COMPOSE_REQUIRED_VERSION,
 ) -> None:
     def parse_version(stdout: str) -> semver.Version:
-        version = json.loads(stdout)["version"].removeprefix("v")
+        version = json.loads(stdout)["version"].removeprefix("v").split("+")[0]
         return semver.Version.parse(version)
 
     await validate_version(
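The one-line change above makes parse_version drop any "+build-metadata" suffix (along with the existing "v" prefix) before handing the string to semver. A quick illustration of the revised parser with a hypothetical docker compose version output:

import json

import semver


def parse_version(stdout: str) -> semver.Version:
    # strip the "v" prefix and any "+<build metadata>" suffix before parsing
    version = json.loads(stdout)["version"].removeprefix("v").split("+")[0]
    return semver.Version.parse(version)


print(parse_version('{"version": "v2.32.4+ds1"}'))  # 2.32.4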
{inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: inspect_ai
-Version: 0.3.59
+Version: 0.3.60
 Summary: Framework for large language model evaluations
 Author: UK AI Safety Institute
 License: MIT License
@@ -54,6 +54,7 @@ Requires-Dist: aioboto3; extra == "dev"
 Requires-Dist: azure-ai-inference; extra == "dev"
 Requires-Dist: google-cloud-aiplatform; extra == "dev"
 Requires-Dist: google-generativeai; extra == "dev"
+Requires-Dist: goodfire; extra == "dev"
 Requires-Dist: groq; extra == "dev"
 Requires-Dist: ipython; extra == "dev"
 Requires-Dist: mistralai; extra == "dev"
@@ -67,7 +68,7 @@ Requires-Dist: pytest-asyncio; extra == "dev"
 Requires-Dist: pytest-cov; extra == "dev"
 Requires-Dist: pytest-dotenv; extra == "dev"
 Requires-Dist: pytest-xdist; extra == "dev"
-Requires-Dist: ruff==0.9.2; extra == "dev"
+Requires-Dist: ruff==0.9.3; extra == "dev"
 Requires-Dist: textual-dev>=0.86.2; extra == "dev"
 Requires-Dist: types-PyYAML; extra == "dev"
 Requires-Dist: types-beautifulsoup4; extra == "dev"