inspect-ai 0.3.69__py3-none-any.whl → 0.3.71__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. inspect_ai/_cli/eval.py +27 -9
  2. inspect_ai/_display/core/display.py +2 -0
  3. inspect_ai/_display/core/footer.py +13 -3
  4. inspect_ai/_display/plain/display.py +6 -2
  5. inspect_ai/_display/rich/display.py +19 -6
  6. inspect_ai/_display/textual/app.py +9 -3
  7. inspect_ai/_display/textual/display.py +4 -0
  8. inspect_ai/_display/textual/widgets/samples.py +4 -10
  9. inspect_ai/_display/textual/widgets/transcript.py +35 -18
  10. inspect_ai/_eval/eval.py +14 -2
  11. inspect_ai/_eval/evalset.py +6 -1
  12. inspect_ai/_eval/run.py +6 -0
  13. inspect_ai/_eval/task/run.py +49 -23
  14. inspect_ai/_eval/task/task.py +26 -3
  15. inspect_ai/_util/content.py +20 -1
  16. inspect_ai/_util/interrupt.py +6 -0
  17. inspect_ai/_util/logger.py +19 -0
  18. inspect_ai/_util/rich.py +7 -8
  19. inspect_ai/_util/text.py +13 -0
  20. inspect_ai/_util/transcript.py +20 -6
  21. inspect_ai/_util/working.py +50 -0
  22. inspect_ai/_view/www/App.css +6 -0
  23. inspect_ai/_view/www/dist/assets/index.css +171 -99
  24. inspect_ai/_view/www/dist/assets/index.js +5972 -2770
  25. inspect_ai/_view/www/eslint.config.mjs +24 -1
  26. inspect_ai/_view/www/log-schema.json +619 -21
  27. inspect_ai/_view/www/package.json +8 -3
  28. inspect_ai/_view/www/src/App.tsx +2 -2
  29. inspect_ai/_view/www/src/appearance/icons.ts +3 -1
  30. inspect_ai/_view/www/src/components/AnsiDisplay.tsx +4 -3
  31. inspect_ai/_view/www/src/components/Card.tsx +9 -8
  32. inspect_ai/_view/www/src/components/DownloadButton.tsx +2 -1
  33. inspect_ai/_view/www/src/components/EmptyPanel.tsx +2 -2
  34. inspect_ai/_view/www/src/components/ErrorPanel.tsx +4 -3
  35. inspect_ai/_view/www/src/components/ExpandablePanel.tsx +13 -5
  36. inspect_ai/_view/www/src/components/FindBand.tsx +3 -3
  37. inspect_ai/_view/www/src/components/HumanBaselineView.tsx +3 -3
  38. inspect_ai/_view/www/src/components/LabeledValue.tsx +5 -4
  39. inspect_ai/_view/www/src/components/LargeModal.tsx +18 -13
  40. inspect_ai/_view/www/src/components/{LightboxCarousel.css → LightboxCarousel.module.css} +22 -18
  41. inspect_ai/_view/www/src/components/LightboxCarousel.tsx +36 -27
  42. inspect_ai/_view/www/src/components/MessageBand.tsx +2 -1
  43. inspect_ai/_view/www/src/components/NavPills.tsx +9 -8
  44. inspect_ai/_view/www/src/components/ProgressBar.tsx +2 -1
  45. inspect_ai/_view/www/src/components/TabSet.tsx +21 -15
  46. inspect_ai/_view/www/src/index.tsx +2 -2
  47. inspect_ai/_view/www/src/metadata/MetaDataGrid.tsx +11 -9
  48. inspect_ai/_view/www/src/metadata/MetaDataView.tsx +3 -2
  49. inspect_ai/_view/www/src/metadata/MetadataGrid.module.css +1 -0
  50. inspect_ai/_view/www/src/metadata/RenderedContent.tsx +16 -1
  51. inspect_ai/_view/www/src/plan/DatasetDetailView.tsx +3 -2
  52. inspect_ai/_view/www/src/plan/DetailStep.tsx +2 -1
  53. inspect_ai/_view/www/src/plan/PlanCard.tsx +2 -5
  54. inspect_ai/_view/www/src/plan/PlanDetailView.tsx +6 -9
  55. inspect_ai/_view/www/src/plan/ScorerDetailView.tsx +2 -1
  56. inspect_ai/_view/www/src/plan/SolverDetailView.tsx +3 -3
  57. inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +2 -2
  58. inspect_ai/_view/www/src/samples/SampleDialog.tsx +3 -3
  59. inspect_ai/_view/www/src/samples/SampleDisplay.module.css +9 -1
  60. inspect_ai/_view/www/src/samples/SampleDisplay.tsx +30 -3
  61. inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +4 -0
  62. inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +25 -4
  63. inspect_ai/_view/www/src/samples/SamplesTools.tsx +2 -1
  64. inspect_ai/_view/www/src/samples/chat/ChatMessage.tsx +3 -19
  65. inspect_ai/_view/www/src/samples/chat/ChatMessageRenderer.tsx +2 -1
  66. inspect_ai/_view/www/src/samples/chat/ChatMessageRow.tsx +2 -1
  67. inspect_ai/_view/www/src/samples/chat/ChatView.tsx +2 -1
  68. inspect_ai/_view/www/src/samples/chat/ChatViewVirtualList.tsx +22 -7
  69. inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +35 -6
  70. inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +2 -2
  71. inspect_ai/_view/www/src/samples/chat/messages.ts +15 -2
  72. inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +13 -4
  73. inspect_ai/_view/www/src/samples/chat/tools/ToolInput.module.css +2 -2
  74. inspect_ai/_view/www/src/samples/chat/tools/ToolInput.tsx +18 -19
  75. inspect_ai/_view/www/src/samples/chat/tools/ToolOutput.module.css +1 -1
  76. inspect_ai/_view/www/src/samples/chat/tools/ToolOutput.tsx +4 -3
  77. inspect_ai/_view/www/src/samples/chat/tools/ToolTitle.tsx +2 -2
  78. inspect_ai/_view/www/src/samples/error/FlatSampleErrorView.tsx +2 -3
  79. inspect_ai/_view/www/src/samples/error/SampleErrorView.tsx +3 -2
  80. inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +2 -1
  81. inspect_ai/_view/www/src/samples/list/SampleHeader.tsx +2 -1
  82. inspect_ai/_view/www/src/samples/list/SampleList.tsx +57 -45
  83. inspect_ai/_view/www/src/samples/list/SampleRow.tsx +2 -1
  84. inspect_ai/_view/www/src/samples/list/SampleSeparator.tsx +2 -1
  85. inspect_ai/_view/www/src/samples/sample-tools/EpochFilter.tsx +2 -2
  86. inspect_ai/_view/www/src/samples/sample-tools/SelectScorer.tsx +4 -3
  87. inspect_ai/_view/www/src/samples/sample-tools/SortFilter.tsx +2 -5
  88. inspect_ai/_view/www/src/samples/sample-tools/sample-filter/SampleFilter.tsx +2 -2
  89. inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +2 -1
  90. inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +2 -2
  91. inspect_ai/_view/www/src/samples/transcript/ApprovalEventView.tsx +2 -1
  92. inspect_ai/_view/www/src/samples/transcript/ErrorEventView.tsx +2 -1
  93. inspect_ai/_view/www/src/samples/transcript/InfoEventView.tsx +2 -1
  94. inspect_ai/_view/www/src/samples/transcript/InputEventView.tsx +2 -1
  95. inspect_ai/_view/www/src/samples/transcript/LoggerEventView.module.css +4 -0
  96. inspect_ai/_view/www/src/samples/transcript/LoggerEventView.tsx +12 -2
  97. inspect_ai/_view/www/src/samples/transcript/ModelEventView.module.css +1 -1
  98. inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +25 -28
  99. inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.tsx +2 -1
  100. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +9 -4
  101. inspect_ai/_view/www/src/samples/transcript/SampleTranscript.tsx +2 -2
  102. inspect_ai/_view/www/src/samples/transcript/SandboxEventView.module.css +32 -0
  103. inspect_ai/_view/www/src/samples/transcript/SandboxEventView.tsx +153 -0
  104. inspect_ai/_view/www/src/samples/transcript/ScoreEventView.tsx +2 -2
  105. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +12 -5
  106. inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.tsx +18 -14
  107. inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +5 -5
  108. inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +53 -16
  109. inspect_ai/_view/www/src/samples/transcript/event/EventNav.tsx +2 -1
  110. inspect_ai/_view/www/src/samples/transcript/event/EventNavs.tsx +2 -1
  111. inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +6 -3
  112. inspect_ai/_view/www/src/samples/transcript/event/EventRow.tsx +3 -2
  113. inspect_ai/_view/www/src/samples/transcript/event/EventSection.tsx +2 -2
  114. inspect_ai/_view/www/src/samples/transcript/event/EventTimingPanel.module.css +28 -0
  115. inspect_ai/_view/www/src/samples/transcript/event/EventTimingPanel.tsx +115 -0
  116. inspect_ai/_view/www/src/samples/transcript/event/utils.ts +29 -0
  117. inspect_ai/_view/www/src/samples/transcript/state/StateDiffView.tsx +2 -1
  118. inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +3 -3
  119. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +11 -8
  120. inspect_ai/_view/www/src/samples/transcript/types.ts +3 -1
  121. inspect_ai/_view/www/src/types/log.d.ts +312 -137
  122. inspect_ai/_view/www/src/usage/ModelTokenTable.tsx +6 -10
  123. inspect_ai/_view/www/src/usage/ModelUsagePanel.module.css +4 -0
  124. inspect_ai/_view/www/src/usage/ModelUsagePanel.tsx +32 -9
  125. inspect_ai/_view/www/src/usage/TokenTable.tsx +4 -6
  126. inspect_ai/_view/www/src/usage/UsageCard.tsx +2 -1
  127. inspect_ai/_view/www/src/utils/format.ts +8 -5
  128. inspect_ai/_view/www/src/utils/json.ts +24 -0
  129. inspect_ai/_view/www/src/workspace/WorkSpace.tsx +6 -5
  130. inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +18 -8
  131. inspect_ai/_view/www/src/workspace/error/TaskErrorPanel.tsx +2 -1
  132. inspect_ai/_view/www/src/workspace/navbar/Navbar.tsx +2 -1
  133. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +3 -3
  134. inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +4 -3
  135. inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +5 -4
  136. inspect_ai/_view/www/src/workspace/navbar/StatusPanel.tsx +5 -8
  137. inspect_ai/_view/www/src/workspace/sidebar/EvalStatus.tsx +5 -4
  138. inspect_ai/_view/www/src/workspace/sidebar/LogDirectoryTitleView.tsx +2 -1
  139. inspect_ai/_view/www/src/workspace/sidebar/Sidebar.tsx +2 -1
  140. inspect_ai/_view/www/src/workspace/sidebar/SidebarLogEntry.tsx +2 -2
  141. inspect_ai/_view/www/src/workspace/sidebar/SidebarScoreView.tsx +2 -1
  142. inspect_ai/_view/www/src/workspace/sidebar/SidebarScoresView.tsx +2 -2
  143. inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -2
  144. inspect_ai/_view/www/src/workspace/tabs/JsonTab.tsx +2 -5
  145. inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +12 -11
  146. inspect_ai/_view/www/yarn.lock +241 -5
  147. inspect_ai/log/__init__.py +2 -0
  148. inspect_ai/log/_condense.py +4 -0
  149. inspect_ai/log/_log.py +72 -12
  150. inspect_ai/log/_recorders/eval.py +6 -1
  151. inspect_ai/log/_samples.py +5 -1
  152. inspect_ai/log/_transcript.py +89 -2
  153. inspect_ai/model/__init__.py +2 -0
  154. inspect_ai/model/_call_tools.py +8 -1
  155. inspect_ai/model/_chat_message.py +22 -7
  156. inspect_ai/model/_conversation.py +11 -9
  157. inspect_ai/model/_generate_config.py +25 -4
  158. inspect_ai/model/_model.py +164 -72
  159. inspect_ai/model/_model_call.py +10 -3
  160. inspect_ai/model/_model_output.py +3 -0
  161. inspect_ai/model/_openai.py +106 -40
  162. inspect_ai/model/_providers/anthropic.py +145 -26
  163. inspect_ai/model/_providers/bedrock.py +7 -0
  164. inspect_ai/model/_providers/cloudflare.py +20 -7
  165. inspect_ai/model/_providers/google.py +29 -8
  166. inspect_ai/model/_providers/groq.py +66 -27
  167. inspect_ai/model/_providers/hf.py +6 -0
  168. inspect_ai/model/_providers/mistral.py +78 -51
  169. inspect_ai/model/_providers/openai.py +66 -4
  170. inspect_ai/model/_providers/openai_o1.py +10 -0
  171. inspect_ai/model/_providers/providers.py +2 -2
  172. inspect_ai/model/_providers/util/tracker.py +92 -0
  173. inspect_ai/model/_providers/vllm.py +13 -5
  174. inspect_ai/model/_reasoning.py +15 -2
  175. inspect_ai/scorer/_model.py +23 -19
  176. inspect_ai/solver/_basic_agent.py +1 -3
  177. inspect_ai/solver/_bridge/patch.py +0 -2
  178. inspect_ai/solver/_human_agent/agent.py +14 -10
  179. inspect_ai/solver/_human_agent/commands/__init__.py +7 -3
  180. inspect_ai/solver/_human_agent/commands/submit.py +76 -30
  181. inspect_ai/solver/_limit.py +4 -4
  182. inspect_ai/solver/_plan.py +0 -3
  183. inspect_ai/solver/_task_state.py +7 -0
  184. inspect_ai/tool/__init__.py +2 -0
  185. inspect_ai/tool/_tool.py +3 -1
  186. inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +1 -1
  187. inspect_ai/tool/_tools/_web_browser/_resources/.pylintrc +8 -0
  188. inspect_ai/tool/_tools/_web_browser/_resources/.vscode/launch.json +24 -0
  189. inspect_ai/tool/_tools/_web_browser/_resources/.vscode/settings.json +25 -0
  190. inspect_ai/tool/_tools/_web_browser/_resources/Dockerfile +5 -6
  191. inspect_ai/tool/_tools/_web_browser/_resources/README.md +10 -11
  192. inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree.py +71 -0
  193. inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree_node.py +323 -0
  194. inspect_ai/tool/_tools/_web_browser/_resources/cdp/__init__.py +5 -0
  195. inspect_ai/tool/_tools/_web_browser/_resources/cdp/a11y.py +279 -0
  196. inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom.py +9 -0
  197. inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom_snapshot.py +293 -0
  198. inspect_ai/tool/_tools/_web_browser/_resources/cdp/page.py +94 -0
  199. inspect_ai/tool/_tools/_web_browser/_resources/constants.py +2 -0
  200. inspect_ai/tool/_tools/_web_browser/_resources/images/usage_diagram.svg +2 -0
  201. inspect_ai/tool/_tools/_web_browser/_resources/playwright_browser.py +50 -0
  202. inspect_ai/tool/_tools/_web_browser/_resources/playwright_crawler.py +31 -359
  203. inspect_ai/tool/_tools/_web_browser/_resources/playwright_page_crawler.py +280 -0
  204. inspect_ai/tool/_tools/_web_browser/_resources/pyproject.toml +65 -0
  205. inspect_ai/tool/_tools/_web_browser/_resources/rectangle.py +64 -0
  206. inspect_ai/tool/_tools/_web_browser/_resources/rpc_client_helpers.py +146 -0
  207. inspect_ai/tool/_tools/_web_browser/_resources/scale_factor.py +64 -0
  208. inspect_ai/tool/_tools/_web_browser/_resources/test_accessibility_tree_node.py +180 -0
  209. inspect_ai/tool/_tools/_web_browser/_resources/test_playwright_crawler.py +15 -9
  210. inspect_ai/tool/_tools/_web_browser/_resources/test_rectangle.py +15 -0
  211. inspect_ai/tool/_tools/_web_browser/_resources/test_web_client.py +44 -0
  212. inspect_ai/tool/_tools/_web_browser/_resources/web_browser_rpc_types.py +39 -0
  213. inspect_ai/tool/_tools/_web_browser/_resources/web_client.py +198 -48
  214. inspect_ai/tool/_tools/_web_browser/_resources/web_client_new_session.py +26 -25
  215. inspect_ai/tool/_tools/_web_browser/_resources/web_server.py +178 -39
  216. inspect_ai/tool/_tools/_web_browser/_web_browser.py +38 -19
  217. inspect_ai/tool/_tools/_web_search.py +3 -3
  218. inspect_ai/util/__init__.py +2 -1
  219. inspect_ai/util/_concurrency.py +14 -8
  220. inspect_ai/util/_display.py +12 -0
  221. inspect_ai/util/_sandbox/context.py +15 -0
  222. inspect_ai/util/_sandbox/docker/docker.py +7 -5
  223. inspect_ai/util/_sandbox/environment.py +32 -1
  224. inspect_ai/util/_sandbox/events.py +183 -0
  225. inspect_ai/util/_sandbox/local.py +3 -3
  226. inspect_ai/util/_sandbox/self_check.py +131 -43
  227. inspect_ai/util/_subtask.py +11 -0
  228. {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/METADATA +3 -3
  229. {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/RECORD +233 -211
  230. {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/WHEEL +1 -1
  231. inspect_ai/_view/www/src/components/VirtualList.module.css +0 -19
  232. inspect_ai/_view/www/src/components/VirtualList.tsx +0 -292
  233. inspect_ai/tool/_tools/_web_browser/_resources/accessibility_node.py +0 -312
  234. inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py +0 -275
  235. inspect_ai/tool/_tools/_web_browser/_resources/images/usage_diagram.png +0 -0
  236. inspect_ai/tool/_tools/_web_browser/_resources/test_accessibility_node.py +0 -176
  237. inspect_ai/tool/_tools/_web_browser/_resources/test_dm_env_servicer.py +0 -135
  238. inspect_ai/tool/_tools/_web_browser/_resources/test_web_environment.py +0 -71
  239. inspect_ai/tool/_tools/_web_browser/_resources/web_environment.py +0 -184
  240. {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/LICENSE +0 -0
  241. {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/entry_points.txt +0 -0
  242. {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,7 @@ import functools
 import gc
 import json
 import os
+import time
 from dataclasses import dataclass
 from queue import Empty, Queue
 from threading import Thread
@@ -220,6 +221,7 @@ class HuggingFaceAPI(ModelAPI):
                 output_tokens=response.output_tokens,
                 total_tokens=response.total_tokens,
             ),
+            time=response.time,
         )

     @override
@@ -377,6 +379,7 @@ class GenerateOutput:
     output_tokens: int
     total_tokens: int
     logprobs: torch.Tensor | None
+    time: float


 @dataclass
@@ -432,6 +435,7 @@ def process_batches() -> None:

         try:
             # capture the generator and decoder functions
+            start_time = time.monotonic()
             first_input = inputs[0][0]
             device = first_input.device
             tokenizer = first_input.tokenizer
@@ -467,6 +471,7 @@ def process_batches() -> None:
             outputs = decoder(sequences=generated_tokens)

             # call back futures
+            total_time = time.monotonic() - start_time
             for i, output in enumerate(outputs):
                 future = inputs[i][1]
                 input_tokens = input_ids.size(dim=1)
@@ -483,6 +488,7 @@ def process_batches() -> None:
                         output_tokens=output_tokens,
                         total_tokens=input_tokens + output_tokens,
                         logprobs=logprobs[i] if logprobs is not None else None,
+                        time=total_time,
                     ),
                 )

@@ -61,6 +61,7 @@ from .._model_output import (
     StopReason,
 )
 from .util import environment_prerequisite_error, model_base_url
+from .util.tracker import HttpxTimeTracker

 AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
 AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"
@@ -125,57 +126,83 @@ class MistralAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
-        # build request
-        request: dict[str, Any] = dict(
-            model=self.model_name,
-            messages=await mistral_chat_messages(input),
-            tools=mistral_chat_tools(tools) if len(tools) > 0 else None,
-            tool_choice=(
-                mistral_chat_tool_choice(tool_choice) if len(tools) > 0 else None
-            ),
-        )
-        if config.temperature is not None:
-            request["temperature"] = config.temperature
-        if config.top_p is not None:
-            request["top_p"] = config.top_p
-        if config.max_tokens is not None:
-            request["max_tokens"] = config.max_tokens
-        if config.seed is not None:
-            request["random_seed"] = config.seed
-
-        # send request
-        try:
-            with Mistral(
-                api_key=self.api_key,
-                timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT)
-                * 1000,
-                **self.model_args,
-            ) as client:
-                response = await client.chat.complete_async(**request)
-        except SDKError as ex:
-            if ex.status_code == 400:
-                return self.handle_bad_request(ex), mistral_model_call(request, None)
-            else:
-                raise ex
-
-        if response is None:
-            raise RuntimeError("Mistral model did not return a response from generate.")
-
-        # return model output (w/ tool calls if they exist)
-        choices = completion_choices_from_response(response, tools)
-        return ModelOutput(
-            model=response.model,
-            choices=choices,
-            usage=ModelUsage(
-                input_tokens=response.usage.prompt_tokens,
-                output_tokens=(
-                    response.usage.completion_tokens
-                    if response.usage.completion_tokens
-                    else response.usage.total_tokens - response.usage.prompt_tokens
+        # create client
+        with Mistral(
+            api_key=self.api_key,
+            timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT) * 1000,
+            **self.model_args,
+        ) as client:
+            # create time tracker
+            time_tracker = HttpxTimeTracker(client.sdk_configuration.async_client)
+
+            # build request
+            request_id = time_tracker.start_request()
+            request: dict[str, Any] = dict(
+                model=self.model_name,
+                messages=await mistral_chat_messages(input),
+                tools=mistral_chat_tools(tools) if len(tools) > 0 else None,
+                tool_choice=(
+                    mistral_chat_tool_choice(tool_choice) if len(tools) > 0 else None
                 ),
-                total_tokens=response.usage.total_tokens,
-            ),
-        ), mistral_model_call(request, response)
+                http_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
+            )
+            if config.temperature is not None:
+                request["temperature"] = config.temperature
+            if config.top_p is not None:
+                request["top_p"] = config.top_p
+            if config.max_tokens is not None:
+                request["max_tokens"] = config.max_tokens
+            if config.seed is not None:
+                request["random_seed"] = config.seed
+
+            # prepare response for inclusion in model call
+            response: dict[str, Any] = {}
+
+            def model_call() -> ModelCall:
+                req = request.copy()
+                req.update(
+                    messages=[message.model_dump() for message in req["messages"]]
+                )
+                if req.get("tools", None) is not None:
+                    req["tools"] = [tool.model_dump() for tool in req["tools"]]
+
+                return ModelCall.create(
+                    request=req,
+                    response=response,
+                    time=time_tracker.end_request(request_id),
+                )
+
+            # send request
+            try:
+                completion = await client.chat.complete_async(**request)
+                response = completion.model_dump()
+            except SDKError as ex:
+                if ex.status_code == 400:
+                    return self.handle_bad_request(ex), model_call()
+                else:
+                    raise ex
+
+            if completion is None:
+                raise RuntimeError(
+                    "Mistral model did not return a response from generate."
+                )
+
+            # return model output (w/ tool calls if they exist)
+            choices = completion_choices_from_response(completion, tools)
+            return ModelOutput(
+                model=completion.model,
+                choices=choices,
+                usage=ModelUsage(
+                    input_tokens=completion.usage.prompt_tokens,
+                    output_tokens=(
+                        completion.usage.completion_tokens
+                        if completion.usage.completion_tokens
+                        else completion.usage.total_tokens
+                        - completion.usage.prompt_tokens
+                    ),
+                    total_tokens=completion.usage.total_tokens,
+                ),
+            ), model_call()

     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
@@ -207,7 +234,7 @@ def mistral_model_call(
     request.update(messages=[message.model_dump() for message in request["messages"]])
     if request.get("tools", None) is not None:
         request["tools"] = [tool.model_dump() for tool in request["tools"]]
-    return ModelCall(
+    return ModelCall.create(
        request=request, response=response.model_dump() if response else {}
    )

@@ -1,8 +1,12 @@
 import os
+import socket
 from logging import getLogger
 from typing import Any

+import httpx
 from openai import (
+    DEFAULT_CONNECTION_LIMITS,
+    DEFAULT_TIMEOUT,
     APIConnectionError,
     APITimeoutError,
     AsyncAzureOpenAI,
@@ -21,6 +25,7 @@ from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.model._providers.util.tracker import HttpxTimeTracker
 from inspect_ai.tool import ToolChoice, ToolInfo

 from .._chat_message import ChatMessage
@@ -101,6 +106,9 @@ class OpenAIAPI(ModelAPI):
             ],
         )

+        # create async http client
+        http_client = OpenAIAsyncHttpxClient()
+
         # azure client
         if self.is_azure():
             # resolve base_url
@@ -125,6 +133,7 @@ class OpenAIAPI(ModelAPI):
                 max_retries=(
                     config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
                 ),
+                http_client=http_client,
                 **model_args,
             )
         else:
@@ -134,9 +143,13 @@ class OpenAIAPI(ModelAPI):
                 max_retries=(
                     config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
                 ),
+                http_client=http_client,
                 **model_args,
             )

+        # create time tracker
+        self._time_tracker = HttpxTimeTracker(self.client._client)
+
     def is_azure(self) -> bool:
         return self.service == "azure"

@@ -172,6 +185,9 @@ class OpenAIAPI(ModelAPI):
             **self.completion_params(config, False),
         )

+        # allocate request_id (so we can see it from ModelCall)
+        request_id = self._time_tracker.start_request()
+
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
         response: dict[str, Any] = {}
@@ -181,6 +197,7 @@ class OpenAIAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=image_url_filter,
+                time=self._time_tracker.end_request(request_id),
             )

         # unlike text models, vision models require a max_tokens (and set it to a very low
@@ -199,6 +216,7 @@ class OpenAIAPI(ModelAPI):
                 tool_choice=openai_chat_tool_choice(tool_choice)
                 if len(tools) > 0
                 else NOT_GIVEN,
+                extra_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
                 **self.completion_params(config, len(tools) > 0),
             )

@@ -222,6 +240,16 @@ class OpenAIAPI(ModelAPI):
                 ModelUsage(
                     input_tokens=completion.usage.prompt_tokens,
                     output_tokens=completion.usage.completion_tokens,
+                    input_tokens_cache_read=(
+                        completion.usage.prompt_tokens_details.cached_tokens
+                        if completion.usage.prompt_tokens_details is not None
+                        else None  # openai only have cache read stats/pricing.
+                    ),
+                    reasoning_tokens=(
+                        completion.usage.completion_tokens_details.reasoning_tokens
+                        if completion.usage.completion_tokens_details is not None
+                        else None
+                    ),
                     total_tokens=completion.usage.total_tokens,
                 )
                 if completion.usage
@@ -241,10 +269,8 @@ class OpenAIAPI(ModelAPI):
     def is_rate_limit(self, ex: BaseException) -> bool:
         if isinstance(ex, RateLimitError):
             # Do not retry on these rate limit errors
-            if (
-                "Request too large" not in ex.message
-                and "You exceeded your current quota" not in ex.message
-            ):
+            # The quota exceeded one is related to monthly account quotas.
+            if "You exceeded your current quota" not in ex.message:
                 return True
         elif isinstance(
             ex, (APIConnectionError | APITimeoutError | InternalServerError)
@@ -333,3 +359,39 @@ class OpenAIAPI(ModelAPI):
             )
         else:
             return e
+
+
+class OpenAIAsyncHttpxClient(httpx.AsyncClient):
+    """Custom async client that deals better with long running Async requests.
+
+    Based on Anthropic DefaultAsyncHttpClient implementation that they
+    released along with Claude 3.7 as well as the OpenAI DefaultAsyncHttpxClient
+
+    """
+
+    def __init__(self, **kwargs: Any) -> None:
+        # This is based on the openai DefaultAsyncHttpxClient:
+        # https://github.com/openai/openai-python/commit/347363ed67a6a1611346427bb9ebe4becce53f7e
+        kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
+        kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
+        kwargs.setdefault("follow_redirects", True)
+
+        # This is based on the anthrpopic changes for claude 3.7:
+        # https://github.com/anthropics/anthropic-sdk-python/commit/c5387e69e799f14e44006ea4e54fdf32f2f74393#diff-3acba71f89118b06b03f2ba9f782c49ceed5bb9f68d62727d929f1841b61d12bR1387-R1403

+        # set socket options to deal with long running reasoning requests
+        socket_options = [
+            (socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
+            (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 60),
+            (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5),
+        ]
+        TCP_KEEPIDLE = getattr(socket, "TCP_KEEPIDLE", None)
+        if TCP_KEEPIDLE is not None:
+            socket_options.append((socket.IPPROTO_TCP, TCP_KEEPIDLE, 60))
+
+        kwargs["transport"] = httpx.AsyncHTTPTransport(
+            limits=DEFAULT_CONNECTION_LIMITS,
+            socket_options=socket_options,
+        )
+
+        super().__init__(**kwargs)
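
(Not part of the diff: a minimal sketch of how the custom client above would be wired in, mirroring the http_client= arguments passed to AsyncAzureOpenAI/AsyncOpenAI in the earlier hunks; the api_key value is a placeholder and OpenAIAsyncHttpxClient is assumed to be in scope.)

    from openai import AsyncOpenAI

    # custom transport enables TCP keep-alive for long-running requests
    client = AsyncOpenAI(
        api_key="...",  # placeholder
        http_client=OpenAIAsyncHttpxClient(),
    )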
@@ -69,6 +69,16 @@ async def generate_o1(
             usage=ModelUsage(
                 input_tokens=completion.usage.prompt_tokens,
                 output_tokens=completion.usage.completion_tokens,
+                input_tokens_cache_read=(
+                    completion.usage.prompt_tokens_details.cached_tokens
+                    if completion.usage.prompt_tokens_details is not None
+                    else None  # openai only have cache read stats/pricing.
+                ),
+                reasoning_tokens=(
+                    completion.usage.completion_tokens_details.reasoning_tokens
+                    if completion.usage.completion_tokens_details is not None
+                    else None
+                ),
                 total_tokens=completion.usage.total_tokens,
             )
             if completion.usage
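
(Not part of the diff: an illustrative ModelUsage built the way the two hunks above populate it; the token counts are made-up numbers.)

    from inspect_ai.model import ModelUsage

    usage = ModelUsage(
        input_tokens=1200,
        output_tokens=350,
        total_tokens=1550,
        input_tokens_cache_read=800,  # from prompt_tokens_details.cached_tokens
        reasoning_tokens=128,  # from completion_tokens_details.reasoning_tokens
    )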
@@ -48,7 +48,7 @@ def openai() -> type[ModelAPI]:
 def anthropic() -> type[ModelAPI]:
     FEATURE = "Anthropic API"
     PACKAGE = "anthropic"
-    MIN_VERSION = "0.29.0"
+    MIN_VERSION = "0.47.1"

     # verify we have the package
     try:
@@ -148,7 +148,7 @@ def cf() -> type[ModelAPI]:
 def mistral() -> type[ModelAPI]:
     FEATURE = "Mistral API"
     PACKAGE = "mistralai"
-    MIN_VERSION = "1.2.0"
+    MIN_VERSION = "1.5.0"

     # verify we have the package
     try:
@@ -0,0 +1,92 @@
+import re
+import time
+from typing import Any, cast
+
+import httpx
+from shortuuid import uuid
+
+
+class HttpTimeTracker:
+    def __init__(self) -> None:
+        # track request start times
+        self._requests: dict[str, float] = {}
+
+    def start_request(self) -> str:
+        request_id = uuid()
+        self._requests[request_id] = time.monotonic()
+        return request_id
+
+    def end_request(self, request_id: str) -> float:
+        # read the request time if (if available) and purge from dict
+        request_time = self._requests.pop(request_id, None)
+        if request_time is None:
+            raise RuntimeError(f"request_id not registered: {request_id}")
+
+        # return elapsed time
+        return time.monotonic() - request_time
+
+    def update_request_time(self, request_id: str) -> None:
+        request_time = self._requests.get(request_id, None)
+        if not request_time:
+            raise RuntimeError(f"No request registered for request_id: {request_id}")
+
+        # update the request time
+        self._requests[request_id] = time.monotonic()
+
+
+class BotoTimeTracker(HttpTimeTracker):
+    def __init__(self, session: Any) -> None:
+        from aiobotocore.session import AioSession
+
+        super().__init__()
+
+        # register hook
+        session = cast(AioSession, session._session)
+        session.register(
+            "before-send.bedrock-runtime.Converse", self.converse_before_send
+        )
+
+    def converse_before_send(self, **kwargs: Any) -> None:
+        user_agent = kwargs["request"].headers["User-Agent"].decode()
+        match = re.search(rf"{self.USER_AGENT_PREFIX}(\w+)", user_agent)
+        if match:
+            request_id = match.group(1)
+            self.update_request_time(request_id)
+
+    def user_agent_extra(self, request_id: str) -> str:
+        return f"{self.USER_AGENT_PREFIX}{request_id}"
+
+    USER_AGENT_PREFIX = "ins/rid#"
+
+
+class HttpxTimeTracker(HttpTimeTracker):
+    """Class which tracks the duration of successful (200 status) http requests.
+
+    A special header is injected into requests which is then read from
+    an httpx 'request' event hook -- this creates a record of when the request
+    started. Note that with retries a single request id could be started
+    several times; our request hook makes sure we always track the time of
+    the last request.
+
+    To determine the total time, we also install an httpx response hook. In
+    this hook we look for 200 responses which have a registered request id.
+    When we find one, we update the end time of the request.
+
+    There is an 'end_request()' method which gets the total requeset time
+    for a request_id and then purges the request_id from our tracking (so
+    the dict doesn't grow unbounded)
+    """
+
+    REQUEST_ID_HEADER = "x-irid"
+
+    def __init__(self, client: httpx.AsyncClient):
+        super().__init__()
+
+        # install httpx request hook
+        client.event_hooks["request"].append(self.request_hook)
+
+    async def request_hook(self, request: httpx.Request) -> None:
+        # update the last request time for this request id (as there could be retries)
+        request_id = request.headers.get(self.REQUEST_ID_HEADER, None)
+        if request_id:
+            self.update_request_time(request_id)
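
(Not part of the diff: a minimal sketch of the pattern the mistral.py and openai.py hunks use with this tracker; the URL is a placeholder and HttpxTimeTracker from the new tracker.py above is assumed to be in scope.)

    import httpx

    client = httpx.AsyncClient()
    tracker = HttpxTimeTracker(client)  # installs the request event hook

    async def timed_get() -> float:
        request_id = tracker.start_request()
        await client.get(
            "https://example.com",  # placeholder
            headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
        )
        # elapsed seconds since the most recent send of this request id
        return tracker.end_request(request_id)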
@@ -2,6 +2,7 @@ import asyncio
 import functools
 import gc
 import os
+import time
 from dataclasses import dataclass
 from queue import Empty, Queue
 from threading import Thread
@@ -48,7 +49,8 @@ class GenerateOutput:
     output_tokens: int
     total_tokens: int
     stop_reason: StopReason
-    logprobs: Logprobs | None = None
+    logprobs: Logprobs | None
+    time: float


 class VLLMAPI(ModelAPI):
@@ -258,6 +260,7 @@ class VLLMAPI(ModelAPI):
         ]

         # TODO: what's the best way to calculate token usage for num_choices > 1
+        total_time = responses[0].time
         input_tokens = responses[0].input_tokens
         output_tokens = sum(response.output_tokens for response in responses)
         total_tokens = input_tokens + output_tokens
@@ -270,6 +273,7 @@ class VLLMAPI(ModelAPI):
                 output_tokens=output_tokens,
                 total_tokens=total_tokens,
             ),
+            time=total_time,
         )


@@ -356,7 +360,7 @@ def get_stop_reason(finish_reason: str | None) -> StopReason:


 def post_process_output(
-    output: RequestOutput, i: int, num_top_logprobs: int | None
+    output: RequestOutput, i: int, num_top_logprobs: int | None, total_time: float
 ) -> GenerateOutput:
     completion = output.outputs[i]
     output_text: str = completion.text
@@ -377,14 +381,15 @@ def post_process_output(
         total_tokens=total_tokens,
         stop_reason=get_stop_reason(completion.finish_reason),
         logprobs=extract_logprobs(completion, num_top_logprobs),
+        time=total_time,
     )


 def post_process_outputs(
-    output: RequestOutput, num_top_logprobs: int | None
+    output: RequestOutput, num_top_logprobs: int | None, total_time: float
 ) -> list[GenerateOutput]:
     return [
-        post_process_output(output, i, num_top_logprobs)
+        post_process_output(output, i, num_top_logprobs, total_time)
         for i in range(len(output.outputs))
     ]

@@ -412,6 +417,7 @@ def process_batches() -> None:
             continue

         try:
+            start_time = time.monotonic()
             first_input = inputs[0][0]
             generator = first_input.generator
             num_top_logprobs = first_input.num_top_logprobs
@@ -419,6 +425,7 @@ def process_batches() -> None:
             # generate
             outputs = generator([input[0].input for input in inputs])

+            total_time = time.monotonic() - start_time
             for i, output in enumerate(outputs):
                 future = inputs[i][1]

@@ -426,7 +433,8 @@ def process_batches() -> None:
                 # down to this point, so we can mark the future as done in a thread safe manner.
                 # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
                 loop.call_soon_threadsafe(
-                    future.set_result, post_process_outputs(output, num_top_logprobs)
+                    future.set_result,
+                    post_process_outputs(output, num_top_logprobs, total_time),
                 )

         except Exception as e:
@@ -5,13 +5,26 @@ from typing import NamedTuple
 class ContentWithReasoning(NamedTuple):
     content: str
     reasoning: str
+    signature: str | None = None
+    redacted: bool = False


 def parse_content_with_reasoning(content: str) -> ContentWithReasoning | None:
-    match = re.match(r"\s*<think>(.*?)</think>(.*)", content, re.DOTALL)
+    # Match <think> tag with optional attributes
+    pattern = r'\s*<think(?:\s+signature="([^"]*)")?(?:\s+redacted="(true)")?\s*>(.*?)</think>(.*)'
+    match = re.match(pattern, content, re.DOTALL)
+
     if match:
+        signature = match.group(1)  # This will be None if not present
+        redacted_value = match.group(2)  # This will be "true" or None
+        reasoning = match.group(3).strip()
+        content_text = match.group(4).strip()
+
         return ContentWithReasoning(
-            content=match.group(2).strip(), reasoning=match.group(1).strip()
+            content=content_text,
+            reasoning=reasoning,
+            signature=signature,
+            redacted=redacted_value == "true",
         )
     else:
         return None
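
(Not part of the diff: an illustration of what the updated parser returns for a <think> tag carrying a signature attribute, derived from the regex above.)

    result = parse_content_with_reasoning(
        '<think signature="sig-abc">Check both cases first.</think>The answer is 42.'
    )
    # ContentWithReasoning(
    #     content="The answer is 42.",
    #     reasoning="Check both cases first.",
    #     signature="sig-abc",
    #     redacted=False,
    # )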
@@ -274,25 +274,29 @@ def chat_history(state: TaskState) -> str:

     # begin history with text of first message (it will come right after
     # 'Task' or 'Question' in the template)
-    history: list[str] = [messages[0].text]
-
-    # for subsequent messages present with e.g. Assistant: {message.text}
-    for message in messages[1:]:
-        if isinstance(message, ChatMessageUser):
-            history.append(f"User: {message.text}")
-        elif isinstance(message, ChatMessageAssistant):
-            assistant_message = [message.text] if message.text else []
-            if message.tool_calls:
-                assistant_message.extend(
-                    [
-                        format_function_call(tool_call.function, tool_call.arguments)
-                        for tool_call in message.tool_calls
-                    ]
+    history: list[str] = []
+    if len(messages) > 0:
+        history.append(messages[0].text)
+
+        # for subsequent messages present with e.g. Assistant: {message.text}
+        for message in messages[1:]:
+            if isinstance(message, ChatMessageUser):
+                history.append(f"User: {message.text}")
+            elif isinstance(message, ChatMessageAssistant):
+                assistant_message = [message.text] if message.text else []
+                if message.tool_calls:
+                    assistant_message.extend(
+                        [
+                            format_function_call(
+                                tool_call.function, tool_call.arguments
+                            )
+                            for tool_call in message.tool_calls
+                        ]
+                    )
+                history.append("Assistant: " + "\n\n".join(assistant_message))
+            elif isinstance(message, ChatMessageTool):
+                history.append(
+                    f"Tool ({message.function}): {message.tool_error or ''}{message.text}"
                 )
-            history.append("Assistant: " + "\n\n".join(assistant_message))
-        elif isinstance(message, ChatMessageTool):
-            history.append(
-                f"Tool ({message.function}): {message.tool_error or ''}{message.text}"
-            )

     return "\n\n".join(history)
@@ -24,7 +24,7 @@ logger = getLogger(__name__)

 DEFAULT_SYSTEM_MESSAGE = """
 You are a helpful assistant attempting to submit the correct answer. You have
-several functions available to help with finding the answer. Each message may
+several functions available to help with finding the answer. Each message
 may perform one function call. You will see the result of the function right
 after sending the message. If you need to perform multiple actions, you can
 always send more messages with subsequent function calls. Do some reasoning
@@ -206,13 +206,11 @@ def basic_agent(
                 # exit if we are at max_attempts
                 attempts += 1
                 if attempts >= max_attempts:
-                    state.completed = True
                     break

                 # exit if the submission is successful
                 answer_scores = await score(state)
                 if score_value_fn(answer_scores[0].value) == 1.0:
-                    state.completed = True
                     break

                 # otherwise notify the model that it was incorrect and continue
@@ -72,8 +72,6 @@ def init_openai_request_patch() -> None:
         _patch_enabled.get()
         # completions request
         and options.url == "/chat/completions"
-        # call to openai not another service (e.g. TogetherAI)
-        and self.base_url == "https://api.openai.com/v1/"
     ):
         # must also be an explicit request for an inspect model
         json_data = cast(dict[str, Any], options.json_data)