inspect-ai 0.3.59__py3-none-any.whl → 0.3.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. inspect_ai/_cli/eval.py +0 -8
  2. inspect_ai/_display/textual/widgets/samples.py +1 -1
  3. inspect_ai/_eval/eval.py +10 -1
  4. inspect_ai/_eval/loader.py +79 -19
  5. inspect_ai/_eval/registry.py +6 -0
  6. inspect_ai/_eval/score.py +2 -1
  7. inspect_ai/_eval/task/generate.py +41 -35
  8. inspect_ai/_eval/task/results.py +6 -5
  9. inspect_ai/_eval/task/run.py +21 -15
  10. inspect_ai/_util/hooks.py +17 -7
  11. inspect_ai/_view/www/dist/assets/index.js +262 -303
  12. inspect_ai/_view/www/package.json +1 -1
  13. inspect_ai/_view/www/src/App.mjs +6 -6
  14. inspect_ai/_view/www/src/Types.mjs +1 -1
  15. inspect_ai/_view/www/src/api/Types.ts +133 -0
  16. inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
  17. inspect_ai/_view/www/src/api/api-http.ts +219 -0
  18. inspect_ai/_view/www/src/api/api-shared.ts +47 -0
  19. inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
  20. inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
  21. inspect_ai/_view/www/src/api/index.ts +51 -0
  22. inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
  23. inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
  24. inspect_ai/_view/www/src/index.js +2 -2
  25. inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
  26. inspect_ai/_view/www/src/navbar/Navbar.mjs +1 -1
  27. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +1 -1
  28. inspect_ai/_view/www/src/samples/SampleList.mjs +1 -1
  29. inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
  30. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +14 -14
  31. inspect_ai/_view/www/src/samples/SamplesTab.mjs +10 -10
  32. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
  33. inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +1 -3
  34. inspect_ai/_view/www/src/utils/vscode.ts +36 -0
  35. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
  36. inspect_ai/approval/_human/manager.py +1 -1
  37. inspect_ai/model/_call_tools.py +55 -0
  38. inspect_ai/model/_chat_message.py +2 -2
  39. inspect_ai/model/_conversation.py +1 -4
  40. inspect_ai/model/_generate_config.py +2 -8
  41. inspect_ai/model/_model.py +90 -25
  42. inspect_ai/model/_model_output.py +15 -0
  43. inspect_ai/model/_openai.py +383 -0
  44. inspect_ai/model/_providers/anthropic.py +52 -14
  45. inspect_ai/model/_providers/azureai.py +1 -1
  46. inspect_ai/model/_providers/goodfire.py +248 -0
  47. inspect_ai/model/_providers/groq.py +7 -3
  48. inspect_ai/model/_providers/hf.py +6 -0
  49. inspect_ai/model/_providers/mistral.py +2 -1
  50. inspect_ai/model/_providers/openai.py +36 -202
  51. inspect_ai/model/_providers/openai_o1.py +2 -4
  52. inspect_ai/model/_providers/providers.py +22 -0
  53. inspect_ai/model/_providers/together.py +4 -4
  54. inspect_ai/model/_providers/util/__init__.py +2 -3
  55. inspect_ai/model/_providers/util/hf_handler.py +1 -1
  56. inspect_ai/model/_providers/util/llama31.py +1 -1
  57. inspect_ai/model/_providers/util/util.py +0 -76
  58. inspect_ai/scorer/_metric.py +3 -0
  59. inspect_ai/scorer/_scorer.py +2 -1
  60. inspect_ai/solver/__init__.py +4 -0
  61. inspect_ai/solver/_basic_agent.py +65 -55
  62. inspect_ai/solver/_bridge/__init__.py +3 -0
  63. inspect_ai/solver/_bridge/bridge.py +100 -0
  64. inspect_ai/solver/_bridge/patch.py +170 -0
  65. inspect_ai/{util → solver}/_limit.py +13 -0
  66. inspect_ai/solver/_solver.py +6 -0
  67. inspect_ai/solver/_task_state.py +37 -7
  68. inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -1
  69. inspect_ai/tool/beta/_computer/_resources/Dockerfile +1 -3
  70. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +1 -1
  71. inspect_ai/tool/beta/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +10 -0
  72. inspect_ai/util/__init__.py +0 -2
  73. inspect_ai/util/_display.py +5 -0
  74. inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
  75. inspect_ai/util/_sandbox/self_check.py +51 -28
  76. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/METADATA +3 -2
  77. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/RECORD +81 -76
  78. inspect_ai/_view/www/src/api/Types.mjs +0 -117
  79. inspect_ai/_view/www/src/api/api-http.mjs +0 -300
  80. inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
  81. inspect_ai/_view/www/src/api/index.mjs +0 -49
  82. inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
  83. inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
  84. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +0 -10
  85. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/LICENSE +0 -0
  86. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/WHEEL +0 -0
  87. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/entry_points.txt +0 -0
  88. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,3 @@
1
- import json
2
1
  import os
3
2
  from logging import getLogger
4
3
  from typing import Any
@@ -15,51 +14,39 @@ from openai import (
15
14
  from openai._types import NOT_GIVEN
16
15
  from openai.types.chat import (
17
16
  ChatCompletion,
18
- ChatCompletionAssistantMessageParam,
19
- ChatCompletionContentPartImageParam,
20
- ChatCompletionContentPartInputAudioParam,
21
- ChatCompletionContentPartParam,
22
- ChatCompletionContentPartTextParam,
23
- ChatCompletionDeveloperMessageParam,
24
- ChatCompletionMessage,
25
- ChatCompletionMessageParam,
26
- ChatCompletionMessageToolCallParam,
27
- ChatCompletionNamedToolChoiceParam,
28
- ChatCompletionSystemMessageParam,
29
- ChatCompletionToolChoiceOptionParam,
30
- ChatCompletionToolMessageParam,
31
- ChatCompletionToolParam,
32
- ChatCompletionUserMessageParam,
33
17
  )
34
- from openai.types.shared_params.function_definition import FunctionDefinition
35
18
  from typing_extensions import override
36
19
 
37
20
  from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
38
- from inspect_ai._util.content import Content
39
21
  from inspect_ai._util.error import PrerequisiteError
40
- from inspect_ai._util.images import file_as_data_uri
41
22
  from inspect_ai._util.logger import warn_once
42
- from inspect_ai._util.url import is_http_url
43
- from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
23
+ from inspect_ai.model._openai import chat_choices_from_openai
24
+ from inspect_ai.tool import ToolChoice, ToolInfo
44
25
 
45
- from .._chat_message import ChatMessage, ChatMessageAssistant
26
+ from .._chat_message import ChatMessage
46
27
  from .._generate_config import GenerateConfig
47
28
  from .._image import image_url_filter
48
29
  from .._model import ModelAPI
49
30
  from .._model_call import ModelCall
50
31
  from .._model_output import (
51
32
  ChatCompletionChoice,
52
- Logprobs,
53
33
  ModelOutput,
54
34
  ModelUsage,
55
35
  StopReason,
56
36
  )
37
+ from .._openai import (
38
+ is_o1,
39
+ is_o1_full,
40
+ is_o1_mini,
41
+ is_o1_preview,
42
+ openai_chat_messages,
43
+ openai_chat_tool_choice,
44
+ openai_chat_tools,
45
+ )
57
46
  from .openai_o1 import generate_o1
58
47
  from .util import (
59
- as_stop_reason,
60
48
  environment_prerequisite_error,
61
49
  model_base_url,
62
- parse_tool_call,
63
50
  )
64
51
 
65
52
  logger = getLogger(__name__)
@@ -87,20 +74,22 @@ class OpenAIAPI(ModelAPI):
87
74
  config=config,
88
75
  )
89
76
 
90
- # pull out azure model_arg
91
- AZURE_MODEL_ARG = "azure"
92
- is_azure = False
93
- if AZURE_MODEL_ARG in model_args:
94
- is_azure = model_args.get(AZURE_MODEL_ARG, False)
95
- del model_args[AZURE_MODEL_ARG]
77
+ # extract any service prefix from model name
78
+ parts = model_name.split("/")
79
+ if len(parts) > 1:
80
+ self.service: str | None = parts[0]
81
+ model_name = "/".join(parts[1:])
82
+ else:
83
+ self.service = None
96
84
 
97
85
  # resolve api_key
98
86
  if not self.api_key:
99
87
  self.api_key = os.environ.get(
100
88
  AZUREAI_OPENAI_API_KEY, os.environ.get(AZURE_OPENAI_API_KEY, None)
101
89
  )
102
- if self.api_key:
103
- is_azure = True
90
+ # backward compatibility for when env vars determined service
91
+ if self.api_key and (os.environ.get(OPENAI_API_KEY, None) is None):
92
+ self.service = "azure"
104
93
  else:
105
94
  self.api_key = os.environ.get(OPENAI_API_KEY, None)
106
95
  if not self.api_key:
@@ -113,7 +102,7 @@ class OpenAIAPI(ModelAPI):
113
102
  )
114
103
 
115
104
  # azure client
116
- if is_azure:
105
+ if self.is_azure():
117
106
  # resolve base_url
118
107
  base_url = model_base_url(
119
108
  base_url,
@@ -148,17 +137,20 @@ class OpenAIAPI(ModelAPI):
148
137
  **model_args,
149
138
  )
150
139
 
140
+ def is_azure(self) -> bool:
141
+ return self.service == "azure"
142
+
151
143
  def is_o1(self) -> bool:
152
- return self.model_name.startswith("o1")
144
+ return is_o1(self.model_name)
153
145
 
154
146
  def is_o1_full(self) -> bool:
155
- return self.is_o1() and not self.is_o1_mini() and not self.is_o1_preview()
147
+ return is_o1_full(self.model_name)
156
148
 
157
149
  def is_o1_mini(self) -> bool:
158
- return self.model_name.startswith("o1-mini")
150
+ return is_o1_mini(self.model_name)
159
151
 
160
152
  def is_o1_preview(self) -> bool:
161
- return self.model_name.startswith("o1-preview")
153
+ return is_o1_preview(self.model_name)
162
154
 
163
155
  async def generate(
164
156
  self,
@@ -198,9 +190,11 @@ class OpenAIAPI(ModelAPI):
198
190
 
199
191
  # prepare request (we do this so we can log the ModelCall)
200
192
  request = dict(
201
- messages=await as_openai_chat_messages(input, self.is_o1_full()),
202
- tools=chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
203
- tool_choice=chat_tool_choice(tool_choice) if len(tools) > 0 else NOT_GIVEN,
193
+ messages=await openai_chat_messages(input, self.model_name),
194
+ tools=openai_chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
195
+ tool_choice=openai_chat_tool_choice(tool_choice)
196
+ if len(tools) > 0
197
+ else NOT_GIVEN,
204
198
  **self.completion_params(config, len(tools) > 0),
205
199
  )
206
200
 
@@ -237,7 +231,7 @@ class OpenAIAPI(ModelAPI):
237
231
  self, response: ChatCompletion, tools: list[ToolInfo]
238
232
  ) -> list[ChatCompletionChoice]:
239
233
  # adding this as a method so we can override from other classes (e.g together)
240
- return chat_choices_from_response(response, tools)
234
+ return chat_choices_from_openai(response, tools)
241
235
 
242
236
  @override
243
237
  def is_rate_limit(self, ex: BaseException) -> bool:
@@ -327,163 +321,3 @@ class OpenAIAPI(ModelAPI):
327
321
  )
328
322
  else:
329
323
  return e
330
-
331
-
332
- async def as_openai_chat_messages(
333
- messages: list[ChatMessage], o1_full: bool
334
- ) -> list[ChatCompletionMessageParam]:
335
- return [await openai_chat_message(message, o1_full) for message in messages]
336
-
337
-
338
- async def openai_chat_message(
339
- message: ChatMessage, o1_full: bool
340
- ) -> ChatCompletionMessageParam:
341
- if message.role == "system":
342
- if o1_full:
343
- return ChatCompletionDeveloperMessageParam(
344
- role="developer", content=message.text
345
- )
346
- else:
347
- return ChatCompletionSystemMessageParam(
348
- role=message.role, content=message.text
349
- )
350
- elif message.role == "user":
351
- return ChatCompletionUserMessageParam(
352
- role=message.role,
353
- content=(
354
- message.content
355
- if isinstance(message.content, str)
356
- else [
357
- await as_chat_completion_part(content)
358
- for content in message.content
359
- ]
360
- ),
361
- )
362
- elif message.role == "assistant":
363
- if message.tool_calls:
364
- return ChatCompletionAssistantMessageParam(
365
- role=message.role,
366
- content=message.text,
367
- tool_calls=[chat_tool_call(call) for call in message.tool_calls],
368
- )
369
- else:
370
- return ChatCompletionAssistantMessageParam(
371
- role=message.role, content=message.text
372
- )
373
- elif message.role == "tool":
374
- return ChatCompletionToolMessageParam(
375
- role=message.role,
376
- content=(
377
- f"Error: {message.error.message}" if message.error else message.text
378
- ),
379
- tool_call_id=str(message.tool_call_id),
380
- )
381
- else:
382
- raise ValueError(f"Unexpected message role {message.role}")
383
-
384
-
385
- def chat_tool_call(tool_call: ToolCall) -> ChatCompletionMessageToolCallParam:
386
- return ChatCompletionMessageToolCallParam(
387
- id=tool_call.id,
388
- function=dict(
389
- name=tool_call.function, arguments=json.dumps(tool_call.arguments)
390
- ),
391
- type=tool_call.type,
392
- )
393
-
394
-
395
- def chat_tools(tools: list[ToolInfo]) -> list[ChatCompletionToolParam]:
396
- return [chat_tool_param(tool) for tool in tools]
397
-
398
-
399
- def chat_tool_param(tool: ToolInfo) -> ChatCompletionToolParam:
400
- function = FunctionDefinition(
401
- name=tool.name,
402
- description=tool.description,
403
- parameters=tool.parameters.model_dump(exclude_none=True),
404
- )
405
- return ChatCompletionToolParam(type="function", function=function)
406
-
407
-
408
- def chat_tool_choice(tool_choice: ToolChoice) -> ChatCompletionToolChoiceOptionParam:
409
- if isinstance(tool_choice, ToolFunction):
410
- return ChatCompletionNamedToolChoiceParam(
411
- type="function", function=dict(name=tool_choice.name)
412
- )
413
- # openai supports 'any' via the 'required' keyword
414
- elif tool_choice == "any":
415
- return "required"
416
- else:
417
- return tool_choice
418
-
419
-
420
- def chat_tool_calls(
421
- message: ChatCompletionMessage, tools: list[ToolInfo]
422
- ) -> list[ToolCall] | None:
423
- if message.tool_calls:
424
- return [
425
- parse_tool_call(call.id, call.function.name, call.function.arguments, tools)
426
- for call in message.tool_calls
427
- ]
428
- else:
429
- return None
430
-
431
-
432
- def chat_choices_from_response(
433
- response: ChatCompletion, tools: list[ToolInfo]
434
- ) -> list[ChatCompletionChoice]:
435
- choices = list(response.choices)
436
- choices.sort(key=lambda c: c.index)
437
- return [
438
- ChatCompletionChoice(
439
- message=chat_message_assistant(choice.message, tools),
440
- stop_reason=as_stop_reason(choice.finish_reason),
441
- logprobs=(
442
- Logprobs(**choice.logprobs.model_dump())
443
- if choice.logprobs is not None
444
- else None
445
- ),
446
- )
447
- for choice in choices
448
- ]
449
-
450
-
451
- def chat_message_assistant(
452
- message: ChatCompletionMessage, tools: list[ToolInfo]
453
- ) -> ChatMessageAssistant:
454
- return ChatMessageAssistant(
455
- content=message.content or "",
456
- source="generate",
457
- tool_calls=chat_tool_calls(message, tools),
458
- )
459
-
460
-
461
- async def as_chat_completion_part(
462
- content: Content,
463
- ) -> ChatCompletionContentPartParam:
464
- if content.type == "text":
465
- return ChatCompletionContentPartTextParam(type="text", text=content.text)
466
- elif content.type == "image":
467
- # API takes URL or base64 encoded file. If it's a remote file or
468
- # data URL leave it alone, otherwise encode it
469
- image_url = content.image
470
- detail = content.detail
471
-
472
- if not is_http_url(image_url):
473
- image_url = await file_as_data_uri(image_url)
474
-
475
- return ChatCompletionContentPartImageParam(
476
- type="image_url",
477
- image_url=dict(url=image_url, detail=detail),
478
- )
479
- elif content.type == "audio":
480
- audio_data = await file_as_data_uri(content.audio)
481
-
482
- return ChatCompletionContentPartInputAudioParam(
483
- type="input_audio", input_audio=dict(data=audio_data, format=content.format)
484
- )
485
-
486
- else:
487
- raise RuntimeError(
488
- "Video content is not currently supported by Open AI chat models."
489
- )
@@ -24,15 +24,13 @@ from inspect_ai.model import (
24
24
  )
25
25
  from inspect_ai.tool import ToolCall, ToolInfo
26
26
 
27
+ from .._call_tools import parse_tool_call, tool_parse_error_message
27
28
  from .._model_call import ModelCall
28
- from .._model_output import ModelUsage, StopReason
29
+ from .._model_output import ModelUsage, StopReason, as_stop_reason
29
30
  from .._providers.util import (
30
31
  ChatAPIHandler,
31
32
  ChatAPIMessage,
32
- as_stop_reason,
33
33
  chat_api_input,
34
- parse_tool_call,
35
- tool_parse_error_message,
36
34
  )
37
35
 
38
36
  logger = getLogger(__name__)
@@ -239,6 +239,28 @@ def mockllm() -> type[ModelAPI]:
239
239
  return MockLLM
240
240
 
241
241
 
242
+ @modelapi("goodfire")
243
+ def goodfire() -> type[ModelAPI]:
244
+ """Get the Goodfire API provider."""
245
+ FEATURE = "Goodfire API"
246
+ PACKAGE = "goodfire"
247
+ MIN_VERSION = "0.3.4" # Support for newer Llama models and OpenAI compatibility
248
+
249
+ # verify we have the package
250
+ try:
251
+ import goodfire # noqa: F401
252
+ except ImportError:
253
+ raise pip_dependency_error(FEATURE, [PACKAGE])
254
+
255
+ # verify version
256
+ verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
257
+
258
+ # in the clear
259
+ from .goodfire import GoodfireAPI
260
+
261
+ return GoodfireAPI
262
+
263
+
242
264
  def validate_openai_client(feature: str) -> None:
243
265
  FEATURE = feature
244
266
  PACKAGE = "openai"
@@ -24,13 +24,13 @@ from .._model_output import (
24
24
  ModelOutput,
25
25
  ModelUsage,
26
26
  StopReason,
27
+ as_stop_reason,
27
28
  )
29
+ from .._openai import chat_message_assistant_from_openai
28
30
  from .openai import (
29
31
  OpenAIAPI,
30
- chat_message_assistant,
31
32
  )
32
33
  from .util import (
33
- as_stop_reason,
34
34
  chat_api_input,
35
35
  chat_api_request,
36
36
  environment_prerequisite_error,
@@ -68,7 +68,7 @@ def chat_choices_from_response_together(
68
68
  logprobs_models.append(Logprobs(content=logprobs_sequence))
69
69
  return [
70
70
  ChatCompletionChoice(
71
- message=chat_message_assistant(choice.message, tools),
71
+ message=chat_message_assistant_from_openai(choice.message, tools),
72
72
  stop_reason=as_stop_reason(choice.finish_reason),
73
73
  logprobs=logprobs,
74
74
  )
@@ -99,7 +99,7 @@ class TogetherAIAPI(OpenAIAPI):
99
99
 
100
100
  # Together uses a default of 512 so we bump it up
101
101
  @override
102
- def max_tokens(self) -> int:
102
+ def max_tokens(self) -> int | None:
103
103
  return DEFAULT_MAX_TOKENS
104
104
 
105
105
  @override
@@ -1,3 +1,5 @@
1
+ from ..._call_tools import parse_tool_call, tool_parse_error_message
2
+ from ..._model_output import as_stop_reason
1
3
  from .chatapi import (
2
4
  ChatAPIHandler,
3
5
  ChatAPIMessage,
@@ -8,11 +10,8 @@ from .chatapi import (
8
10
  from .hf_handler import HFHandler
9
11
  from .llama31 import Llama31Handler
10
12
  from .util import (
11
- as_stop_reason,
12
13
  environment_prerequisite_error,
13
14
  model_base_url,
14
- parse_tool_call,
15
- tool_parse_error_message,
16
15
  )
17
16
 
18
17
  __all__ = [
@@ -8,9 +8,9 @@ from typing_extensions import override
8
8
  from inspect_ai.tool._tool_call import ToolCall
9
9
  from inspect_ai.tool._tool_info import ToolInfo
10
10
 
11
+ from ..._call_tools import parse_tool_call, tool_parse_error_message
11
12
  from ..._chat_message import ChatMessageAssistant
12
13
  from .chatapi import ChatAPIHandler
13
- from .util import parse_tool_call, tool_parse_error_message
14
14
 
15
15
  logger = getLogger(__name__)
16
16
 
@@ -9,6 +9,7 @@ from typing_extensions import override
9
9
  from inspect_ai.tool._tool_call import ToolCall
10
10
  from inspect_ai.tool._tool_info import ToolInfo
11
11
 
12
+ from ..._call_tools import parse_tool_call, tool_parse_error_message
12
13
  from ..._chat_message import (
13
14
  ChatMessage,
14
15
  ChatMessageAssistant,
@@ -16,7 +17,6 @@ from ..._chat_message import (
16
17
  ChatMessageTool,
17
18
  )
18
19
  from .chatapi import ChatAPIHandler, ChatAPIMessage
19
- from .util import parse_tool_call, tool_parse_error_message
20
20
 
21
21
  logger = getLogger(__name__)
22
22
 
@@ -1,34 +1,11 @@
1
- import json
2
1
  import os
3
2
  from logging import getLogger
4
- from typing import Any
5
-
6
- import yaml
7
3
 
8
4
  from inspect_ai._util.error import PrerequisiteError
9
- from inspect_ai.tool._tool_call import ToolCall
10
- from inspect_ai.tool._tool_info import ToolInfo
11
-
12
- from ..._model_output import StopReason
13
5
 
14
6
  logger = getLogger(__name__)
15
7
 
16
8
 
17
- def as_stop_reason(reason: str | None) -> StopReason:
18
- """Encode common reason strings into standard StopReason."""
19
- match reason:
20
- case "stop" | "eos":
21
- return "stop"
22
- case "length":
23
- return "max_tokens"
24
- case "tool_calls" | "function_call":
25
- return "tool_calls"
26
- case "content_filter" | "model_length" | "max_tokens":
27
- return reason
28
- case _:
29
- return "unknown"
30
-
31
-
32
9
  def model_base_url(base_url: str | None, env_vars: str | list[str]) -> str | None:
33
10
  if base_url:
34
11
  return base_url
@@ -44,59 +21,6 @@ def model_base_url(base_url: str | None, env_vars: str | list[str]) -> str | Non
44
21
  return os.getenv("INSPECT_EVAL_MODEL_BASE_URL", None)
45
22
 
46
23
 
47
- def tool_parse_error_message(arguments: str, ex: Exception) -> str:
48
- return f"Error parsing the following tool call arguments:\n\n{arguments}\n\nError details: {ex}"
49
-
50
-
51
- def parse_tool_call(
52
- id: str, function: str, arguments: str, tools: list[ToolInfo]
53
- ) -> ToolCall:
54
- error: str | None = None
55
- arguments_dict: dict[str, Any] = {}
56
-
57
- def report_parse_error(ex: Exception) -> None:
58
- nonlocal error
59
- error = tool_parse_error_message(arguments, ex)
60
- logger.info(error)
61
-
62
- # if the arguments is a dict, then handle it with a plain json.loads
63
- arguments = arguments.strip()
64
- if arguments.startswith("{"):
65
- try:
66
- arguments_dict = json.loads(arguments)
67
- except json.JSONDecodeError as ex:
68
- report_parse_error(ex)
69
-
70
- # otherwise parse it as yaml (which will pickup unquoted strings, numbers, and true/false)
71
- # and then create a dict that maps it to the first function argument
72
- else:
73
- tool_info = next(
74
- (
75
- tool
76
- for tool in tools
77
- if tool.name == function and len(tool.parameters.properties) > 0
78
- ),
79
- None,
80
- )
81
- if tool_info:
82
- param_names = list(tool_info.parameters.properties.keys())
83
- try:
84
- value = yaml.safe_load(arguments)
85
- arguments_dict[param_names[0]] = value
86
- except yaml.error.YAMLError:
87
- # If the yaml parser fails, we treat it as a string argument.
88
- arguments_dict[param_names[0]] = arguments
89
-
90
- # return ToolCall with error payload
91
- return ToolCall(
92
- id=id,
93
- function=function,
94
- arguments=arguments_dict,
95
- type="function",
96
- parse_error=error,
97
- )
98
-
99
-
100
24
  def environment_prerequisite_error(
101
25
  client: str, env_vars: str | list[str]
102
26
  ) -> PrerequisiteError:
@@ -125,6 +125,9 @@ class SampleScore(BaseModel):
125
125
  sample_id: str | int | None = Field(default=None)
126
126
  """A sample id"""
127
127
 
128
+ scorer: str | None = Field(default=None)
129
+ """Registry name of scorer that created this score."""
130
+
128
131
 
129
132
  ValueToFloat = Callable[[Value], float]
130
133
  """Function used by metrics to translate from a Score value to a float value."""
@@ -1,3 +1,4 @@
1
+ from functools import wraps
1
2
  from typing import (
2
3
  Any,
3
4
  Callable,
@@ -100,7 +101,6 @@ def scorer(
100
101
 
101
102
  Returns:
102
103
  Scorer with registry attributes.
103
-
104
104
  """
105
105
 
106
106
  def wrapper(scorer_type: Callable[P, Scorer]) -> Callable[P, Scorer]:
@@ -110,6 +110,7 @@ def scorer(
110
110
  )
111
111
 
112
112
  # wrap instantiations of scorer so they carry registry info and metrics
113
+ @wraps(scorer_type)
113
114
  def scorer_wrapper(*args: P.args, **kwargs: P.kwargs) -> Scorer:
114
115
  scorer = scorer_type(*args, **kwargs)
115
116
 
@@ -1,10 +1,12 @@
1
1
  from inspect_ai._util.deprecation import relocated_module_attribute
2
2
 
3
3
  from ._basic_agent import basic_agent
4
+ from ._bridge import bridge
4
5
  from ._chain import chain
5
6
  from ._critique import self_critique
6
7
  from ._fork import fork
7
8
  from ._human_agent.agent import human_agent
9
+ from ._limit import SampleLimitExceededError
8
10
  from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
9
11
  from ._plan import Plan, plan
10
12
  from ._prompt import chain_of_thought, prompt_template, system_message, user_message
@@ -14,6 +16,7 @@ from ._use_tools import use_tools
14
16
 
15
17
  __all__ = [
16
18
  "basic_agent",
19
+ "bridge",
17
20
  "human_agent",
18
21
  "chain",
19
22
  "fork",
@@ -35,6 +38,7 @@ __all__ = [
35
38
  "TaskState",
36
39
  "Generate",
37
40
  "MultipleChoiceTemplate",
41
+ "SampleLimitExceededError",
38
42
  ]
39
43
 
40
44