inspect-ai 0.3.69__py3-none-any.whl → 0.3.71__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +27 -9
- inspect_ai/_display/core/display.py +2 -0
- inspect_ai/_display/core/footer.py +13 -3
- inspect_ai/_display/plain/display.py +6 -2
- inspect_ai/_display/rich/display.py +19 -6
- inspect_ai/_display/textual/app.py +9 -3
- inspect_ai/_display/textual/display.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +4 -10
- inspect_ai/_display/textual/widgets/transcript.py +35 -18
- inspect_ai/_eval/eval.py +14 -2
- inspect_ai/_eval/evalset.py +6 -1
- inspect_ai/_eval/run.py +6 -0
- inspect_ai/_eval/task/run.py +49 -23
- inspect_ai/_eval/task/task.py +26 -3
- inspect_ai/_util/content.py +20 -1
- inspect_ai/_util/interrupt.py +6 -0
- inspect_ai/_util/logger.py +19 -0
- inspect_ai/_util/rich.py +7 -8
- inspect_ai/_util/text.py +13 -0
- inspect_ai/_util/transcript.py +20 -6
- inspect_ai/_util/working.py +50 -0
- inspect_ai/_view/www/App.css +6 -0
- inspect_ai/_view/www/dist/assets/index.css +171 -99
- inspect_ai/_view/www/dist/assets/index.js +5972 -2770
- inspect_ai/_view/www/eslint.config.mjs +24 -1
- inspect_ai/_view/www/log-schema.json +619 -21
- inspect_ai/_view/www/package.json +8 -3
- inspect_ai/_view/www/src/App.tsx +2 -2
- inspect_ai/_view/www/src/appearance/icons.ts +3 -1
- inspect_ai/_view/www/src/components/AnsiDisplay.tsx +4 -3
- inspect_ai/_view/www/src/components/Card.tsx +9 -8
- inspect_ai/_view/www/src/components/DownloadButton.tsx +2 -1
- inspect_ai/_view/www/src/components/EmptyPanel.tsx +2 -2
- inspect_ai/_view/www/src/components/ErrorPanel.tsx +4 -3
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +13 -5
- inspect_ai/_view/www/src/components/FindBand.tsx +3 -3
- inspect_ai/_view/www/src/components/HumanBaselineView.tsx +3 -3
- inspect_ai/_view/www/src/components/LabeledValue.tsx +5 -4
- inspect_ai/_view/www/src/components/LargeModal.tsx +18 -13
- inspect_ai/_view/www/src/components/{LightboxCarousel.css → LightboxCarousel.module.css} +22 -18
- inspect_ai/_view/www/src/components/LightboxCarousel.tsx +36 -27
- inspect_ai/_view/www/src/components/MessageBand.tsx +2 -1
- inspect_ai/_view/www/src/components/NavPills.tsx +9 -8
- inspect_ai/_view/www/src/components/ProgressBar.tsx +2 -1
- inspect_ai/_view/www/src/components/TabSet.tsx +21 -15
- inspect_ai/_view/www/src/index.tsx +2 -2
- inspect_ai/_view/www/src/metadata/MetaDataGrid.tsx +11 -9
- inspect_ai/_view/www/src/metadata/MetaDataView.tsx +3 -2
- inspect_ai/_view/www/src/metadata/MetadataGrid.module.css +1 -0
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +16 -1
- inspect_ai/_view/www/src/plan/DatasetDetailView.tsx +3 -2
- inspect_ai/_view/www/src/plan/DetailStep.tsx +2 -1
- inspect_ai/_view/www/src/plan/PlanCard.tsx +2 -5
- inspect_ai/_view/www/src/plan/PlanDetailView.tsx +6 -9
- inspect_ai/_view/www/src/plan/ScorerDetailView.tsx +2 -1
- inspect_ai/_view/www/src/plan/SolverDetailView.tsx +3 -3
- inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +2 -2
- inspect_ai/_view/www/src/samples/SampleDialog.tsx +3 -3
- inspect_ai/_view/www/src/samples/SampleDisplay.module.css +9 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.tsx +30 -3
- inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +4 -0
- inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +25 -4
- inspect_ai/_view/www/src/samples/SamplesTools.tsx +2 -1
- inspect_ai/_view/www/src/samples/chat/ChatMessage.tsx +3 -19
- inspect_ai/_view/www/src/samples/chat/ChatMessageRenderer.tsx +2 -1
- inspect_ai/_view/www/src/samples/chat/ChatMessageRow.tsx +2 -1
- inspect_ai/_view/www/src/samples/chat/ChatView.tsx +2 -1
- inspect_ai/_view/www/src/samples/chat/ChatViewVirtualList.tsx +22 -7
- inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +35 -6
- inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +2 -2
- inspect_ai/_view/www/src/samples/chat/messages.ts +15 -2
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +13 -4
- inspect_ai/_view/www/src/samples/chat/tools/ToolInput.module.css +2 -2
- inspect_ai/_view/www/src/samples/chat/tools/ToolInput.tsx +18 -19
- inspect_ai/_view/www/src/samples/chat/tools/ToolOutput.module.css +1 -1
- inspect_ai/_view/www/src/samples/chat/tools/ToolOutput.tsx +4 -3
- inspect_ai/_view/www/src/samples/chat/tools/ToolTitle.tsx +2 -2
- inspect_ai/_view/www/src/samples/error/FlatSampleErrorView.tsx +2 -3
- inspect_ai/_view/www/src/samples/error/SampleErrorView.tsx +3 -2
- inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +2 -1
- inspect_ai/_view/www/src/samples/list/SampleHeader.tsx +2 -1
- inspect_ai/_view/www/src/samples/list/SampleList.tsx +57 -45
- inspect_ai/_view/www/src/samples/list/SampleRow.tsx +2 -1
- inspect_ai/_view/www/src/samples/list/SampleSeparator.tsx +2 -1
- inspect_ai/_view/www/src/samples/sample-tools/EpochFilter.tsx +2 -2
- inspect_ai/_view/www/src/samples/sample-tools/SelectScorer.tsx +4 -3
- inspect_ai/_view/www/src/samples/sample-tools/SortFilter.tsx +2 -5
- inspect_ai/_view/www/src/samples/sample-tools/sample-filter/SampleFilter.tsx +2 -2
- inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +2 -1
- inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +2 -2
- inspect_ai/_view/www/src/samples/transcript/ApprovalEventView.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/InputEventView.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/LoggerEventView.module.css +4 -0
- inspect_ai/_view/www/src/samples/transcript/LoggerEventView.tsx +12 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.module.css +1 -1
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +25 -28
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +9 -4
- inspect_ai/_view/www/src/samples/transcript/SampleTranscript.tsx +2 -2
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.module.css +32 -0
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.tsx +153 -0
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.tsx +2 -2
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +12 -5
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.tsx +18 -14
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +5 -5
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +53 -16
- inspect_ai/_view/www/src/samples/transcript/event/EventNav.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/event/EventNavs.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +6 -3
- inspect_ai/_view/www/src/samples/transcript/event/EventRow.tsx +3 -2
- inspect_ai/_view/www/src/samples/transcript/event/EventSection.tsx +2 -2
- inspect_ai/_view/www/src/samples/transcript/event/EventTimingPanel.module.css +28 -0
- inspect_ai/_view/www/src/samples/transcript/event/EventTimingPanel.tsx +115 -0
- inspect_ai/_view/www/src/samples/transcript/event/utils.ts +29 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateDiffView.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +3 -3
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +11 -8
- inspect_ai/_view/www/src/samples/transcript/types.ts +3 -1
- inspect_ai/_view/www/src/types/log.d.ts +312 -137
- inspect_ai/_view/www/src/usage/ModelTokenTable.tsx +6 -10
- inspect_ai/_view/www/src/usage/ModelUsagePanel.module.css +4 -0
- inspect_ai/_view/www/src/usage/ModelUsagePanel.tsx +32 -9
- inspect_ai/_view/www/src/usage/TokenTable.tsx +4 -6
- inspect_ai/_view/www/src/usage/UsageCard.tsx +2 -1
- inspect_ai/_view/www/src/utils/format.ts +8 -5
- inspect_ai/_view/www/src/utils/json.ts +24 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.tsx +6 -5
- inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +18 -8
- inspect_ai/_view/www/src/workspace/error/TaskErrorPanel.tsx +2 -1
- inspect_ai/_view/www/src/workspace/navbar/Navbar.tsx +2 -1
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +3 -3
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +4 -3
- inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +5 -4
- inspect_ai/_view/www/src/workspace/navbar/StatusPanel.tsx +5 -8
- inspect_ai/_view/www/src/workspace/sidebar/EvalStatus.tsx +5 -4
- inspect_ai/_view/www/src/workspace/sidebar/LogDirectoryTitleView.tsx +2 -1
- inspect_ai/_view/www/src/workspace/sidebar/Sidebar.tsx +2 -1
- inspect_ai/_view/www/src/workspace/sidebar/SidebarLogEntry.tsx +2 -2
- inspect_ai/_view/www/src/workspace/sidebar/SidebarScoreView.tsx +2 -1
- inspect_ai/_view/www/src/workspace/sidebar/SidebarScoresView.tsx +2 -2
- inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -2
- inspect_ai/_view/www/src/workspace/tabs/JsonTab.tsx +2 -5
- inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +12 -11
- inspect_ai/_view/www/yarn.lock +241 -5
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_condense.py +4 -0
- inspect_ai/log/_log.py +72 -12
- inspect_ai/log/_recorders/eval.py +6 -1
- inspect_ai/log/_samples.py +5 -1
- inspect_ai/log/_transcript.py +89 -2
- inspect_ai/model/__init__.py +2 -0
- inspect_ai/model/_call_tools.py +8 -1
- inspect_ai/model/_chat_message.py +22 -7
- inspect_ai/model/_conversation.py +11 -9
- inspect_ai/model/_generate_config.py +25 -4
- inspect_ai/model/_model.py +164 -72
- inspect_ai/model/_model_call.py +10 -3
- inspect_ai/model/_model_output.py +3 -0
- inspect_ai/model/_openai.py +106 -40
- inspect_ai/model/_providers/anthropic.py +145 -26
- inspect_ai/model/_providers/bedrock.py +7 -0
- inspect_ai/model/_providers/cloudflare.py +20 -7
- inspect_ai/model/_providers/google.py +29 -8
- inspect_ai/model/_providers/groq.py +66 -27
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +78 -51
- inspect_ai/model/_providers/openai.py +66 -4
- inspect_ai/model/_providers/openai_o1.py +10 -0
- inspect_ai/model/_providers/providers.py +2 -2
- inspect_ai/model/_providers/util/tracker.py +92 -0
- inspect_ai/model/_providers/vllm.py +13 -5
- inspect_ai/model/_reasoning.py +15 -2
- inspect_ai/scorer/_model.py +23 -19
- inspect_ai/solver/_basic_agent.py +1 -3
- inspect_ai/solver/_bridge/patch.py +0 -2
- inspect_ai/solver/_human_agent/agent.py +14 -10
- inspect_ai/solver/_human_agent/commands/__init__.py +7 -3
- inspect_ai/solver/_human_agent/commands/submit.py +76 -30
- inspect_ai/solver/_limit.py +4 -4
- inspect_ai/solver/_plan.py +0 -3
- inspect_ai/solver/_task_state.py +7 -0
- inspect_ai/tool/__init__.py +2 -0
- inspect_ai/tool/_tool.py +3 -1
- inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_resources/.pylintrc +8 -0
- inspect_ai/tool/_tools/_web_browser/_resources/.vscode/launch.json +24 -0
- inspect_ai/tool/_tools/_web_browser/_resources/.vscode/settings.json +25 -0
- inspect_ai/tool/_tools/_web_browser/_resources/Dockerfile +5 -6
- inspect_ai/tool/_tools/_web_browser/_resources/README.md +10 -11
- inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree.py +71 -0
- inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree_node.py +323 -0
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/__init__.py +5 -0
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/a11y.py +279 -0
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom.py +9 -0
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom_snapshot.py +293 -0
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/page.py +94 -0
- inspect_ai/tool/_tools/_web_browser/_resources/constants.py +2 -0
- inspect_ai/tool/_tools/_web_browser/_resources/images/usage_diagram.svg +2 -0
- inspect_ai/tool/_tools/_web_browser/_resources/playwright_browser.py +50 -0
- inspect_ai/tool/_tools/_web_browser/_resources/playwright_crawler.py +31 -359
- inspect_ai/tool/_tools/_web_browser/_resources/playwright_page_crawler.py +280 -0
- inspect_ai/tool/_tools/_web_browser/_resources/pyproject.toml +65 -0
- inspect_ai/tool/_tools/_web_browser/_resources/rectangle.py +64 -0
- inspect_ai/tool/_tools/_web_browser/_resources/rpc_client_helpers.py +146 -0
- inspect_ai/tool/_tools/_web_browser/_resources/scale_factor.py +64 -0
- inspect_ai/tool/_tools/_web_browser/_resources/test_accessibility_tree_node.py +180 -0
- inspect_ai/tool/_tools/_web_browser/_resources/test_playwright_crawler.py +15 -9
- inspect_ai/tool/_tools/_web_browser/_resources/test_rectangle.py +15 -0
- inspect_ai/tool/_tools/_web_browser/_resources/test_web_client.py +44 -0
- inspect_ai/tool/_tools/_web_browser/_resources/web_browser_rpc_types.py +39 -0
- inspect_ai/tool/_tools/_web_browser/_resources/web_client.py +198 -48
- inspect_ai/tool/_tools/_web_browser/_resources/web_client_new_session.py +26 -25
- inspect_ai/tool/_tools/_web_browser/_resources/web_server.py +178 -39
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +38 -19
- inspect_ai/tool/_tools/_web_search.py +3 -3
- inspect_ai/util/__init__.py +2 -1
- inspect_ai/util/_concurrency.py +14 -8
- inspect_ai/util/_display.py +12 -0
- inspect_ai/util/_sandbox/context.py +15 -0
- inspect_ai/util/_sandbox/docker/docker.py +7 -5
- inspect_ai/util/_sandbox/environment.py +32 -1
- inspect_ai/util/_sandbox/events.py +183 -0
- inspect_ai/util/_sandbox/local.py +3 -3
- inspect_ai/util/_sandbox/self_check.py +131 -43
- inspect_ai/util/_subtask.py +11 -0
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/METADATA +3 -3
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/RECORD +233 -211
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/WHEEL +1 -1
- inspect_ai/_view/www/src/components/VirtualList.module.css +0 -19
- inspect_ai/_view/www/src/components/VirtualList.tsx +0 -292
- inspect_ai/tool/_tools/_web_browser/_resources/accessibility_node.py +0 -312
- inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py +0 -275
- inspect_ai/tool/_tools/_web_browser/_resources/images/usage_diagram.png +0 -0
- inspect_ai/tool/_tools/_web_browser/_resources/test_accessibility_node.py +0 -176
- inspect_ai/tool/_tools/_web_browser/_resources/test_dm_env_servicer.py +0 -135
- inspect_ai/tool/_tools/_web_browser/_resources/test_web_environment.py +0 -71
- inspect_ai/tool/_tools/_web_browser/_resources/web_environment.py +0 -184
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/hf.py
CHANGED

@@ -4,6 +4,7 @@ import functools
 import gc
 import json
 import os
+import time
 from dataclasses import dataclass
 from queue import Empty, Queue
 from threading import Thread

@@ -220,6 +221,7 @@ class HuggingFaceAPI(ModelAPI):
                output_tokens=response.output_tokens,
                total_tokens=response.total_tokens,
            ),
+           time=response.time,
        )

    @override

@@ -377,6 +379,7 @@ class GenerateOutput:
    output_tokens: int
    total_tokens: int
    logprobs: torch.Tensor | None
+   time: float


 @dataclass

@@ -432,6 +435,7 @@ def process_batches() -> None:

        try:
            # capture the generator and decoder functions
+           start_time = time.monotonic()
            first_input = inputs[0][0]
            device = first_input.device
            tokenizer = first_input.tokenizer

@@ -467,6 +471,7 @@ def process_batches() -> None:
            outputs = decoder(sequences=generated_tokens)

            # call back futures
+           total_time = time.monotonic() - start_time
            for i, output in enumerate(outputs):
                future = inputs[i][1]
                input_tokens = input_ids.size(dim=1)

@@ -483,6 +488,7 @@
                        output_tokens=output_tokens,
                        total_tokens=input_tokens + output_tokens,
                        logprobs=logprobs[i] if logprobs is not None else None,
+                       time=total_time,
                    ),
                )

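The hf.py changes above time each batched generate call with a monotonic clock and attribute the same elapsed time to every output in the batch, which then surfaces as the new GenerateOutput.time field. A minimal sketch of that pattern, with illustrative names rather than the provider's actual helpers:

import time

def timed_batch(generate_batch, batch):
    # monotonic() is unaffected by system clock adjustments, so it is the
    # appropriate clock for measuring elapsed wall time
    start = time.monotonic()
    outputs = generate_batch(batch)
    elapsed = time.monotonic() - start
    # every output in the batch reports the same generation time
    return [(output, elapsed) for output in outputs]
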
inspect_ai/model/_providers/mistral.py
CHANGED

@@ -61,6 +61,7 @@ from .._model_output import (
    StopReason,
 )
 from .util import environment_prerequisite_error, model_base_url
+from .util.tracker import HttpxTimeTracker

 AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
 AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"

@@ -125,57 +126,83 @@ class MistralAPI(ModelAPI):
        tool_choice: ToolChoice,
        config: GenerateConfig,
    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
-        #
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-        # send request
-        try:
-            with Mistral(
-                api_key=self.api_key,
-                timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT)
-                * 1000,
-                **self.model_args,
-            ) as client:
-                response = await client.chat.complete_async(**request)
-        except SDKError as ex:
-            if ex.status_code == 400:
-                return self.handle_bad_request(ex), mistral_model_call(request, None)
-            else:
-                raise ex
-
-        if response is None:
-            raise RuntimeError("Mistral model did not return a response from generate.")
-
-        # return model output (w/ tool calls if they exist)
-        choices = completion_choices_from_response(response, tools)
-        return ModelOutput(
-            model=response.model,
-            choices=choices,
-            usage=ModelUsage(
-                input_tokens=response.usage.prompt_tokens,
-                output_tokens=(
-                    response.usage.completion_tokens
-                    if response.usage.completion_tokens
-                    else response.usage.total_tokens - response.usage.prompt_tokens
+        # create client
+        with Mistral(
+            api_key=self.api_key,
+            timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT) * 1000,
+            **self.model_args,
+        ) as client:
+            # create time tracker
+            time_tracker = HttpxTimeTracker(client.sdk_configuration.async_client)
+
+            # build request
+            request_id = time_tracker.start_request()
+            request: dict[str, Any] = dict(
+                model=self.model_name,
+                messages=await mistral_chat_messages(input),
+                tools=mistral_chat_tools(tools) if len(tools) > 0 else None,
+                tool_choice=(
+                    mistral_chat_tool_choice(tool_choice) if len(tools) > 0 else None
                ),
-
-        )
-
+                http_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
+            )
+            if config.temperature is not None:
+                request["temperature"] = config.temperature
+            if config.top_p is not None:
+                request["top_p"] = config.top_p
+            if config.max_tokens is not None:
+                request["max_tokens"] = config.max_tokens
+            if config.seed is not None:
+                request["random_seed"] = config.seed
+
+            # prepare response for inclusion in model call
+            response: dict[str, Any] = {}
+
+            def model_call() -> ModelCall:
+                req = request.copy()
+                req.update(
+                    messages=[message.model_dump() for message in req["messages"]]
+                )
+                if req.get("tools", None) is not None:
+                    req["tools"] = [tool.model_dump() for tool in req["tools"]]
+
+                return ModelCall.create(
+                    request=req,
+                    response=response,
+                    time=time_tracker.end_request(request_id),
+                )
+
+            # send request
+            try:
+                completion = await client.chat.complete_async(**request)
+                response = completion.model_dump()
+            except SDKError as ex:
+                if ex.status_code == 400:
+                    return self.handle_bad_request(ex), model_call()
+                else:
+                    raise ex
+
+            if completion is None:
+                raise RuntimeError(
+                    "Mistral model did not return a response from generate."
+                )
+
+            # return model output (w/ tool calls if they exist)
+            choices = completion_choices_from_response(completion, tools)
+            return ModelOutput(
+                model=completion.model,
+                choices=choices,
+                usage=ModelUsage(
+                    input_tokens=completion.usage.prompt_tokens,
+                    output_tokens=(
+                        completion.usage.completion_tokens
+                        if completion.usage.completion_tokens
+                        else completion.usage.total_tokens
+                        - completion.usage.prompt_tokens
+                    ),
+                    total_tokens=completion.usage.total_tokens,
+                ),
+            ), model_call()

    @override
    def is_rate_limit(self, ex: BaseException) -> bool:

@@ -207,7 +234,7 @@ def mistral_model_call(
    request.update(messages=[message.model_dump() for message in request["messages"]])
    if request.get("tools", None) is not None:
        request["tools"] = [tool.model_dump() for tool in request["tools"]]
-    return ModelCall(
+    return ModelCall.create(
        request=request, response=response.model_dump() if response else {}
    )

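A notable shape in the Mistral rewrite: instead of building the ModelCall only on the happy path, the provider defines a local model_call() closure over the request dict, a response dict that starts empty, and the tracked request id, so the same call record (with real timing) can be returned from both the success path and the 400 bad-request path. A rough sketch of that pattern with hypothetical names (send and payload stand in for the SDK call and request; ValueError stands in for the provider's SDKError):

def run_request(send, payload, tracker):
    request_id = tracker.start_request()
    response: dict = {}

    def call_record() -> dict:
        # closes over payload/response by reference, so it reflects whatever
        # has been captured by the time it is invoked
        return {
            "request": payload,
            "response": response,
            "time": tracker.end_request(request_id),
        }

    try:
        completion = send(payload)
        response.update(completion)
    except ValueError as ex:
        # the error path still yields a complete record (empty response, real timing)
        return ex, call_record()
    return completion, call_record()
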
inspect_ai/model/_providers/openai.py
CHANGED

@@ -1,8 +1,12 @@
 import os
+import socket
 from logging import getLogger
 from typing import Any

+import httpx
 from openai import (
+    DEFAULT_CONNECTION_LIMITS,
+    DEFAULT_TIMEOUT,
    APIConnectionError,
    APITimeoutError,
    AsyncAzureOpenAI,

@@ -21,6 +25,7 @@ from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.model._providers.util.tracker import HttpxTimeTracker
 from inspect_ai.tool import ToolChoice, ToolInfo

 from .._chat_message import ChatMessage

@@ -101,6 +106,9 @@ class OpenAIAPI(ModelAPI):
            ],
        )

+        # create async http client
+        http_client = OpenAIAsyncHttpxClient()
+
        # azure client
        if self.is_azure():
            # resolve base_url

@@ -125,6 +133,7 @@ class OpenAIAPI(ModelAPI):
                max_retries=(
                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
                ),
+                http_client=http_client,
                **model_args,
            )
        else:

@@ -134,9 +143,13 @@ class OpenAIAPI(ModelAPI):
                max_retries=(
                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
                ),
+                http_client=http_client,
                **model_args,
            )

+        # create time tracker
+        self._time_tracker = HttpxTimeTracker(self.client._client)
+
    def is_azure(self) -> bool:
        return self.service == "azure"

@@ -172,6 +185,9 @@ class OpenAIAPI(ModelAPI):
                **self.completion_params(config, False),
            )

+        # allocate request_id (so we can see it from ModelCall)
+        request_id = self._time_tracker.start_request()
+
        # setup request and response for ModelCall
        request: dict[str, Any] = {}
        response: dict[str, Any] = {}

@@ -181,6 +197,7 @@ class OpenAIAPI(ModelAPI):
                request=request,
                response=response,
                filter=image_url_filter,
+                time=self._time_tracker.end_request(request_id),
            )

        # unlike text models, vision models require a max_tokens (and set it to a very low

@@ -199,6 +216,7 @@ class OpenAIAPI(ModelAPI):
                tool_choice=openai_chat_tool_choice(tool_choice)
                if len(tools) > 0
                else NOT_GIVEN,
+                extra_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
                **self.completion_params(config, len(tools) > 0),
            )

@@ -222,6 +240,16 @@ class OpenAIAPI(ModelAPI):
                ModelUsage(
                    input_tokens=completion.usage.prompt_tokens,
                    output_tokens=completion.usage.completion_tokens,
+                    input_tokens_cache_read=(
+                        completion.usage.prompt_tokens_details.cached_tokens
+                        if completion.usage.prompt_tokens_details is not None
+                        else None  # openai only have cache read stats/pricing.
+                    ),
+                    reasoning_tokens=(
+                        completion.usage.completion_tokens_details.reasoning_tokens
+                        if completion.usage.completion_tokens_details is not None
+                        else None
+                    ),
                    total_tokens=completion.usage.total_tokens,
                )
                if completion.usage

@@ -241,10 +269,8 @@ class OpenAIAPI(ModelAPI):
    def is_rate_limit(self, ex: BaseException) -> bool:
        if isinstance(ex, RateLimitError):
            # Do not retry on these rate limit errors
-
-
-                and "You exceeded your current quota" not in ex.message
-            ):
+            # The quota exceeded one is related to monthly account quotas.
+            if "You exceeded your current quota" not in ex.message:
                return True
        elif isinstance(
            ex, (APIConnectionError | APITimeoutError | InternalServerError)

@@ -333,3 +359,39 @@ class OpenAIAPI(ModelAPI):
            )
        else:
            return e
+
+
+class OpenAIAsyncHttpxClient(httpx.AsyncClient):
+    """Custom async client that deals better with long running Async requests.
+
+    Based on Anthropic DefaultAsyncHttpClient implementation that they
+    released along with Claude 3.7 as well as the OpenAI DefaultAsyncHttpxClient
+
+    """
+
+    def __init__(self, **kwargs: Any) -> None:
+        # This is based on the openai DefaultAsyncHttpxClient:
+        # https://github.com/openai/openai-python/commit/347363ed67a6a1611346427bb9ebe4becce53f7e
+        kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
+        kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
+        kwargs.setdefault("follow_redirects", True)
+
+        # This is based on the anthrpopic changes for claude 3.7:
+        # https://github.com/anthropics/anthropic-sdk-python/commit/c5387e69e799f14e44006ea4e54fdf32f2f74393#diff-3acba71f89118b06b03f2ba9f782c49ceed5bb9f68d62727d929f1841b61d12bR1387-R1403
+
+        # set socket options to deal with long running reasoning requests
+        socket_options = [
+            (socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
+            (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 60),
+            (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5),
+        ]
+        TCP_KEEPIDLE = getattr(socket, "TCP_KEEPIDLE", None)
+        if TCP_KEEPIDLE is not None:
+            socket_options.append((socket.IPPROTO_TCP, TCP_KEEPIDLE, 60))
+
+        kwargs["transport"] = httpx.AsyncHTTPTransport(
+            limits=DEFAULT_CONNECTION_LIMITS,
+            socket_options=socket_options,
+        )
+
+        super().__init__(**kwargs)
inspect_ai/model/_providers/openai_o1.py
CHANGED

@@ -69,6 +69,16 @@ async def generate_o1(
        usage=ModelUsage(
            input_tokens=completion.usage.prompt_tokens,
            output_tokens=completion.usage.completion_tokens,
+            input_tokens_cache_read=(
+                completion.usage.prompt_tokens_details.cached_tokens
+                if completion.usage.prompt_tokens_details is not None
+                else None  # openai only have cache read stats/pricing.
+            ),
+            reasoning_tokens=(
+                completion.usage.completion_tokens_details.reasoning_tokens
+                if completion.usage.completion_tokens_details is not None
+                else None
+            ),
            total_tokens=completion.usage.total_tokens,
        )
        if completion.usage
inspect_ai/model/_providers/providers.py
CHANGED

@@ -48,7 +48,7 @@ def openai() -> type[ModelAPI]:
 def anthropic() -> type[ModelAPI]:
    FEATURE = "Anthropic API"
    PACKAGE = "anthropic"
-    MIN_VERSION = "0.
+    MIN_VERSION = "0.47.1"

    # verify we have the package
    try:

@@ -148,7 +148,7 @@ def cf() -> type[ModelAPI]:
 def mistral() -> type[ModelAPI]:
    FEATURE = "Mistral API"
    PACKAGE = "mistralai"
-    MIN_VERSION = "1.
+    MIN_VERSION = "1.5.0"

    # verify we have the package
    try:
inspect_ai/model/_providers/util/tracker.py
ADDED

@@ -0,0 +1,92 @@
+import re
+import time
+from typing import Any, cast
+
+import httpx
+from shortuuid import uuid
+
+
+class HttpTimeTracker:
+    def __init__(self) -> None:
+        # track request start times
+        self._requests: dict[str, float] = {}
+
+    def start_request(self) -> str:
+        request_id = uuid()
+        self._requests[request_id] = time.monotonic()
+        return request_id
+
+    def end_request(self, request_id: str) -> float:
+        # read the request time if (if available) and purge from dict
+        request_time = self._requests.pop(request_id, None)
+        if request_time is None:
+            raise RuntimeError(f"request_id not registered: {request_id}")
+
+        # return elapsed time
+        return time.monotonic() - request_time
+
+    def update_request_time(self, request_id: str) -> None:
+        request_time = self._requests.get(request_id, None)
+        if not request_time:
+            raise RuntimeError(f"No request registered for request_id: {request_id}")
+
+        # update the request time
+        self._requests[request_id] = time.monotonic()
+
+
+class BotoTimeTracker(HttpTimeTracker):
+    def __init__(self, session: Any) -> None:
+        from aiobotocore.session import AioSession
+
+        super().__init__()
+
+        # register hook
+        session = cast(AioSession, session._session)
+        session.register(
+            "before-send.bedrock-runtime.Converse", self.converse_before_send
+        )
+
+    def converse_before_send(self, **kwargs: Any) -> None:
+        user_agent = kwargs["request"].headers["User-Agent"].decode()
+        match = re.search(rf"{self.USER_AGENT_PREFIX}(\w+)", user_agent)
+        if match:
+            request_id = match.group(1)
+            self.update_request_time(request_id)
+
+    def user_agent_extra(self, request_id: str) -> str:
+        return f"{self.USER_AGENT_PREFIX}{request_id}"
+
+    USER_AGENT_PREFIX = "ins/rid#"
+
+
+class HttpxTimeTracker(HttpTimeTracker):
+    """Class which tracks the duration of successful (200 status) http requests.
+
+    A special header is injected into requests which is then read from
+    an httpx 'request' event hook -- this creates a record of when the request
+    started. Note that with retries a single request id could be started
+    several times; our request hook makes sure we always track the time of
+    the last request.
+
+    To determine the total time, we also install an httpx response hook. In
+    this hook we look for 200 responses which have a registered request id.
+    When we find one, we update the end time of the request.
+
+    There is an 'end_request()' method which gets the total requeset time
+    for a request_id and then purges the request_id from our tracking (so
+    the dict doesn't grow unbounded)
+    """
+
+    REQUEST_ID_HEADER = "x-irid"
+
+    def __init__(self, client: httpx.AsyncClient):
+        super().__init__()
+
+        # install httpx request hook
+        client.event_hooks["request"].append(self.request_hook)
+
+    async def request_hook(self, request: httpx.Request) -> None:
+        # update the last request time for this request id (as there could be retries)
+        request_id = request.headers.get(self.REQUEST_ID_HEADER, None)
+        if request_id:
+            self.update_request_time(request_id)
inspect_ai/model/_providers/vllm.py
CHANGED

@@ -2,6 +2,7 @@ import asyncio
 import functools
 import gc
 import os
+import time
 from dataclasses import dataclass
 from queue import Empty, Queue
 from threading import Thread

@@ -48,7 +49,8 @@ class GenerateOutput:
    output_tokens: int
    total_tokens: int
    stop_reason: StopReason
-    logprobs: Logprobs | None
+    logprobs: Logprobs | None
+    time: float


 class VLLMAPI(ModelAPI):

@@ -258,6 +260,7 @@ class VLLMAPI(ModelAPI):
        ]

        # TODO: what's the best way to calculate token usage for num_choices > 1
+        total_time = responses[0].time
        input_tokens = responses[0].input_tokens
        output_tokens = sum(response.output_tokens for response in responses)
        total_tokens = input_tokens + output_tokens

@@ -270,6 +273,7 @@ class VLLMAPI(ModelAPI):
                output_tokens=output_tokens,
                total_tokens=total_tokens,
            ),
+            time=total_time,
        )


@@ -356,7 +360,7 @@ def get_stop_reason(finish_reason: str | None) -> StopReason:


 def post_process_output(
-    output: RequestOutput, i: int, num_top_logprobs: int | None
+    output: RequestOutput, i: int, num_top_logprobs: int | None, total_time: float
 ) -> GenerateOutput:
    completion = output.outputs[i]
    output_text: str = completion.text

@@ -377,14 +381,15 @@ def post_process_output(
        total_tokens=total_tokens,
        stop_reason=get_stop_reason(completion.finish_reason),
        logprobs=extract_logprobs(completion, num_top_logprobs),
+        time=total_time,
    )


 def post_process_outputs(
-    output: RequestOutput, num_top_logprobs: int | None
+    output: RequestOutput, num_top_logprobs: int | None, total_time: float
 ) -> list[GenerateOutput]:
    return [
-        post_process_output(output, i, num_top_logprobs)
+        post_process_output(output, i, num_top_logprobs, total_time)
        for i in range(len(output.outputs))
    ]

@@ -412,6 +417,7 @@ def process_batches() -> None:
            continue

        try:
+            start_time = time.monotonic()
            first_input = inputs[0][0]
            generator = first_input.generator
            num_top_logprobs = first_input.num_top_logprobs

@@ -419,6 +425,7 @@ def process_batches() -> None:
            # generate
            outputs = generator([input[0].input for input in inputs])

+            total_time = time.monotonic() - start_time
            for i, output in enumerate(outputs):
                future = inputs[i][1]

@@ -426,7 +433,8 @@ def process_batches() -> None:
                # down to this point, so we can mark the future as done in a thread safe manner.
                # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
                loop.call_soon_threadsafe(
-                    future.set_result,
+                    future.set_result,
+                    post_process_outputs(output, num_top_logprobs, total_time),
                )

        except Exception as e:
inspect_ai/model/_reasoning.py
CHANGED

@@ -5,13 +5,26 @@ from typing import NamedTuple
 class ContentWithReasoning(NamedTuple):
    content: str
    reasoning: str
+    signature: str | None = None
+    redacted: bool = False


 def parse_content_with_reasoning(content: str) -> ContentWithReasoning | None:
-
+    # Match <think> tag with optional attributes
+    pattern = r'\s*<think(?:\s+signature="([^"]*)")?(?:\s+redacted="(true)")?\s*>(.*?)</think>(.*)'
+    match = re.match(pattern, content, re.DOTALL)
+
    if match:
+        signature = match.group(1)  # This will be None if not present
+        redacted_value = match.group(2)  # This will be "true" or None
+        reasoning = match.group(3).strip()
+        content_text = match.group(4).strip()
+
        return ContentWithReasoning(
-            content=
+            content=content_text,
+            reasoning=reasoning,
+            signature=signature,
+            redacted=redacted_value == "true",
        )
    else:
        return None
inspect_ai/scorer/_model.py
CHANGED

@@ -274,25 +274,29 @@ def chat_history(state: TaskState) -> str:

    # begin history with text of first message (it will come right after
    # 'Task' or 'Question' in the template)
-    history: list[str] = [
-
-
-
-
-
-
-
-
-            assistant_message.
-
-
-
-
+    history: list[str] = []
+    if len(messages) > 0:
+        history.append(messages[0].text)
+
+    # for subsequent messages present with e.g. Assistant: {message.text}
+    for message in messages[1:]:
+        if isinstance(message, ChatMessageUser):
+            history.append(f"User: {message.text}")
+        elif isinstance(message, ChatMessageAssistant):
+            assistant_message = [message.text] if message.text else []
+            if message.tool_calls:
+                assistant_message.extend(
+                    [
+                        format_function_call(
+                            tool_call.function, tool_call.arguments
+                        )
+                        for tool_call in message.tool_calls
+                    ]
+                )
+            history.append("Assistant: " + "\n\n".join(assistant_message))
+        elif isinstance(message, ChatMessageTool):
+            history.append(
+                f"Tool ({message.function}): {message.tool_error or ''}{message.text}"
            )
-            history.append("Assistant: " + "\n\n".join(assistant_message))
-        elif isinstance(message, ChatMessageTool):
-            history.append(
-                f"Tool ({message.function}): {message.tool_error or ''}{message.text}"
-            )

    return "\n\n".join(history)
inspect_ai/solver/_basic_agent.py
CHANGED

@@ -24,7 +24,7 @@ logger = getLogger(__name__)

 DEFAULT_SYSTEM_MESSAGE = """
 You are a helpful assistant attempting to submit the correct answer. You have
-several functions available to help with finding the answer. Each message
+several functions available to help with finding the answer. Each message
 may perform one function call. You will see the result of the function right
 after sending the message. If you need to perform multiple actions, you can
 always send more messages with subsequent function calls. Do some reasoning

@@ -206,13 +206,11 @@ def basic_agent(
                # exit if we are at max_attempts
                attempts += 1
                if attempts >= max_attempts:
-                    state.completed = True
                    break

                # exit if the submission is successful
                answer_scores = await score(state)
                if score_value_fn(answer_scores[0].value) == 1.0:
-                    state.completed = True
                    break

                # otherwise notify the model that it was incorrect and continue
inspect_ai/solver/_bridge/patch.py
CHANGED

@@ -72,8 +72,6 @@ def init_openai_request_patch() -> None:
            _patch_enabled.get()
            # completions request
            and options.url == "/chat/completions"
-            # call to openai not another service (e.g. TogetherAI)
-            and self.base_url == "https://api.openai.com/v1/"
        ):
            # must also be an explicit request for an inspect model
            json_data = cast(dict[str, Any], options.json_data)
|