inspect-ai 0.3.68__py3-none-any.whl → 0.3.70__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- inspect_ai/_cli/eval.py +13 -1
- inspect_ai/_display/plain/display.py +9 -11
- inspect_ai/_display/textual/app.py +5 -5
- inspect_ai/_display/textual/widgets/samples.py +47 -18
- inspect_ai/_display/textual/widgets/transcript.py +25 -12
- inspect_ai/_eval/eval.py +14 -2
- inspect_ai/_eval/evalset.py +6 -1
- inspect_ai/_eval/run.py +6 -0
- inspect_ai/_eval/task/run.py +44 -15
- inspect_ai/_eval/task/task.py +26 -3
- inspect_ai/_util/interrupt.py +15 -0
- inspect_ai/_util/logger.py +23 -0
- inspect_ai/_util/rich.py +7 -8
- inspect_ai/_util/text.py +301 -1
- inspect_ai/_util/transcript.py +10 -2
- inspect_ai/_util/working.py +46 -0
- inspect_ai/_view/www/dist/assets/index.css +56 -12
- inspect_ai/_view/www/dist/assets/index.js +905 -751
- inspect_ai/_view/www/log-schema.json +337 -2
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +149 -0
- inspect_ai/_view/www/node_modules/flatted/python/test.py +63 -0
- inspect_ai/_view/www/src/appearance/icons.ts +3 -1
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +0 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.module.css +9 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.tsx +28 -1
- inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +4 -0
- inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +23 -2
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +1 -1
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +4 -0
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.module.css +32 -0
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.tsx +152 -0
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +9 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +19 -1
- inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +6 -3
- inspect_ai/_view/www/src/samples/transcript/types.ts +3 -1
- inspect_ai/_view/www/src/types/log.d.ts +188 -108
- inspect_ai/_view/www/src/utils/format.ts +7 -4
- inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +9 -6
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_condense.py +1 -0
- inspect_ai/log/_log.py +72 -12
- inspect_ai/log/_samples.py +5 -5
- inspect_ai/log/_transcript.py +31 -1
- inspect_ai/model/_call_tools.py +1 -1
- inspect_ai/model/_conversation.py +1 -1
- inspect_ai/model/_model.py +35 -16
- inspect_ai/model/_model_call.py +10 -3
- inspect_ai/model/_providers/anthropic.py +13 -2
- inspect_ai/model/_providers/bedrock.py +7 -0
- inspect_ai/model/_providers/cloudflare.py +20 -7
- inspect_ai/model/_providers/google.py +358 -302
- inspect_ai/model/_providers/groq.py +57 -23
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +81 -52
- inspect_ai/model/_providers/openai.py +9 -0
- inspect_ai/model/_providers/providers.py +6 -6
- inspect_ai/model/_providers/util/tracker.py +92 -0
- inspect_ai/model/_providers/vllm.py +13 -5
- inspect_ai/solver/_basic_agent.py +1 -3
- inspect_ai/solver/_bridge/patch.py +0 -2
- inspect_ai/solver/_limit.py +4 -4
- inspect_ai/solver/_plan.py +3 -3
- inspect_ai/solver/_solver.py +3 -0
- inspect_ai/solver/_task_state.py +10 -1
- inspect_ai/tool/_tools/_web_search.py +3 -3
- inspect_ai/util/_concurrency.py +14 -8
- inspect_ai/util/_sandbox/context.py +15 -0
- inspect_ai/util/_sandbox/docker/cleanup.py +8 -3
- inspect_ai/util/_sandbox/docker/compose.py +5 -9
- inspect_ai/util/_sandbox/docker/docker.py +20 -6
- inspect_ai/util/_sandbox/docker/util.py +10 -1
- inspect_ai/util/_sandbox/environment.py +32 -1
- inspect_ai/util/_sandbox/events.py +149 -0
- inspect_ai/util/_sandbox/local.py +3 -3
- inspect_ai/util/_sandbox/self_check.py +2 -1
- inspect_ai/util/_subprocess.py +4 -1
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/METADATA +5 -5
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/RECORD +82 -74
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/groq.py
CHANGED
@@ -1,5 +1,6 @@
 import json
 import os
+from copy import copy
 from typing import Any, Dict, Iterable, List, Optional
 
 import httpx
@@ -19,9 +20,14 @@ from groq.types.chat import (
     ChatCompletionToolMessageParam,
     ChatCompletionUserMessageParam,
 )
+from pydantic import JsonValue
 from typing_extensions import override
 
-from inspect_ai._util.constants import
+from inspect_ai._util.constants import (
+    BASE_64_DATA_REMOVED,
+    DEFAULT_MAX_RETRIES,
+    DEFAULT_MAX_TOKENS,
+)
 from inspect_ai._util.content import Content
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.url import is_http_url
@@ -48,6 +54,7 @@ from .util import (
     environment_prerequisite_error,
     model_base_url,
 )
+from .util.tracker import HttpxTimeTracker
 
 GROQ_API_KEY = "GROQ_API_KEY"
 
@@ -87,6 +94,9 @@ class GroqAPI(ModelAPI):
             http_client=httpx.AsyncClient(limits=httpx.Limits(max_connections=None)),
         )
 
+        # create time tracker
+        self._time_tracker = HttpxTimeTracker(self.client._client)
+
     @override
     async def close(self) -> None:
         await self.client.close()
@@ -98,6 +108,21 @@ class GroqAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> tuple[ModelOutput, ModelCall]:
+        # allocate request_id (so we can see it from ModelCall)
+        request_id = self._time_tracker.start_request()
+
+        # setup request and response for ModelCall
+        request: dict[str, Any] = {}
+        response: dict[str, Any] = {}
+
+        def model_call() -> ModelCall:
+            return ModelCall.create(
+                request=request,
+                response=response,
+                filter=model_call_filter,
+                time=self._time_tracker.end_request(request_id),
+            )
+
         messages = await as_groq_chat_messages(input)
 
         params = self.completion_params(config)
@@ -109,51 +134,52 @@ class GroqAPI(ModelAPI):
         if config.parallel_tool_calls is not None:
             params["parallel_tool_calls"] = config.parallel_tool_calls
 
-
+        request = dict(
             messages=messages,
             model=self.model_name,
+            extra_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
             **params,
         )
 
+        completion: ChatCompletion = await self.client.chat.completions.create(
+            **request,
+        )
+
+        response = completion.model_dump()
+
         # extract metadata
         metadata: dict[str, Any] = {
-            "id":
-            "system_fingerprint":
-            "created":
+            "id": completion.id,
+            "system_fingerprint": completion.system_fingerprint,
+            "created": completion.created,
         }
-        if
+        if completion.usage:
             metadata = metadata | {
-                "queue_time":
-                "prompt_time":
-                "completion_time":
-                "total_time":
+                "queue_time": completion.usage.queue_time,
+                "prompt_time": completion.usage.prompt_time,
+                "completion_time": completion.usage.completion_time,
+                "total_time": completion.usage.total_time,
            }
 
         # extract output
-        choices = self._chat_choices_from_response(
+        choices = self._chat_choices_from_response(completion, tools)
         output = ModelOutput(
-            model=
+            model=completion.model,
            choices=choices,
            usage=(
                ModelUsage(
-                    input_tokens=
-                    output_tokens=
-                    total_tokens=
+                    input_tokens=completion.usage.prompt_tokens,
+                    output_tokens=completion.usage.completion_tokens,
+                    total_tokens=completion.usage.total_tokens,
                )
-                if
+                if completion.usage
                else None
            ),
            metadata=metadata,
        )
 
-        # record call
-        call = ModelCall.create(
-            request=dict(messages=messages, model=self.model_name, **params),
-            response=response.model_dump(),
-        )
-
         # return
-        return output,
+        return output, model_call()
 
     def completion_params(self, config: GenerateConfig) -> Dict[str, Any]:
         params: dict[str, Any] = {}
@@ -307,3 +333,11 @@ def chat_message_assistant(message: Any, tools: list[ToolInfo]) -> ChatMessageAssistant:
         tool_calls=chat_tool_calls(message, tools),
         reasoning=reasoning,
     )
+
+
+def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
+    # remove base64 encoded images
+    if key == "image_url" and isinstance(value, dict):
+        value = copy(value)
+        value.update(url=BASE_64_DATA_REMOVED)
+    return value
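As an aside, the model_call_filter added above is passed to ModelCall.create so that recorded calls do not retain base64 image payloads. Below is a standalone sketch (not from the package) of its effect; BASE_64_DATA_REMOVED is a stand-in for the inspect_ai._util.constants value, whose exact string is not shown in this diff.

from copy import copy
from typing import Any

BASE_64_DATA_REMOVED = "<base64-data-removed>"  # placeholder for the package constant


def model_call_filter(key: str | None, value: Any) -> Any:
    # replace the url of any image_url dict before the call is logged
    if key == "image_url" and isinstance(value, dict):
        value = copy(value)
        value.update(url=BASE_64_DATA_REMOVED)
    return value


print(model_call_filter("image_url", {"url": "data:image/png;base64,iVBORw0KGgo..."}))
# -> {'url': '<base64-data-removed>'}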
inspect_ai/model/_providers/hf.py
CHANGED
@@ -4,6 +4,7 @@ import functools
 import gc
 import json
 import os
+import time
 from dataclasses import dataclass
 from queue import Empty, Queue
 from threading import Thread
@@ -220,6 +221,7 @@ class HuggingFaceAPI(ModelAPI):
                 output_tokens=response.output_tokens,
                 total_tokens=response.total_tokens,
             ),
+            time=response.time,
         )
 
     @override
@@ -377,6 +379,7 @@ class GenerateOutput:
     output_tokens: int
     total_tokens: int
     logprobs: torch.Tensor | None
+    time: float
 
 
 @dataclass
@@ -432,6 +435,7 @@ def process_batches() -> None:
 
         try:
             # capture the generator and decoder functions
+            start_time = time.monotonic()
             first_input = inputs[0][0]
             device = first_input.device
             tokenizer = first_input.tokenizer
@@ -467,6 +471,7 @@ def process_batches() -> None:
             outputs = decoder(sequences=generated_tokens)
 
             # call back futures
+            total_time = time.monotonic() - start_time
             for i, output in enumerate(outputs):
                 future = inputs[i][1]
                 input_tokens = input_ids.size(dim=1)
@@ -483,6 +488,7 @@ def process_batches() -> None:
                         output_tokens=output_tokens,
                         total_tokens=input_tokens + output_tokens,
                         logprobs=logprobs[i] if logprobs is not None else None,
+                        time=total_time,
                     ),
                 )
 
inspect_ai/model/_providers/mistral.py
CHANGED
@@ -61,6 +61,7 @@ from .._model_output import (
     StopReason,
 )
 from .util import environment_prerequisite_error, model_base_url
+from .util.tracker import HttpxTimeTracker
 
 AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
 AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"
@@ -111,16 +112,12 @@ class MistralAPI(ModelAPI):
         if base_url:
             model_args["server_url"] = base_url
 
-
-        self.client = Mistral(
-            api_key=self.api_key,
-            timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT) * 1000,
-            **model_args,
-        )
+        self.model_args = model_args
 
     @override
     async def close(self) -> None:
-
+        # client is created and destroyed in generate
+        pass
 
     async def generate(
         self,
@@ -129,51 +126,83 @@ class MistralAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
-        #
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-        # send request
-        try:
-            response = await self.client.chat.complete_async(**request)
-        except SDKError as ex:
-            if ex.status_code == 400:
-                return self.handle_bad_request(ex), mistral_model_call(request, None)
-            else:
-                raise ex
-
-        if response is None:
-            raise RuntimeError("Mistral model did not return a response from generate.")
-
-        # return model output (w/ tool calls if they exist)
-        choices = completion_choices_from_response(response, tools)
-        return ModelOutput(
-            model=response.model,
-            choices=choices,
-            usage=ModelUsage(
-                input_tokens=response.usage.prompt_tokens,
-                output_tokens=(
-                    response.usage.completion_tokens
-                    if response.usage.completion_tokens
-                    else response.usage.total_tokens - response.usage.prompt_tokens
+        # create client
+        with Mistral(
+            api_key=self.api_key,
+            timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT) * 1000,
+            **self.model_args,
+        ) as client:
+            # create time tracker
+            time_tracker = HttpxTimeTracker(client.sdk_configuration.async_client)
+
+            # build request
+            request_id = time_tracker.start_request()
+            request: dict[str, Any] = dict(
+                model=self.model_name,
+                messages=await mistral_chat_messages(input),
+                tools=mistral_chat_tools(tools) if len(tools) > 0 else None,
+                tool_choice=(
+                    mistral_chat_tool_choice(tool_choice) if len(tools) > 0 else None
                 ),
-
-            )
-
+                http_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
+            )
+            if config.temperature is not None:
+                request["temperature"] = config.temperature
+            if config.top_p is not None:
+                request["top_p"] = config.top_p
+            if config.max_tokens is not None:
+                request["max_tokens"] = config.max_tokens
+            if config.seed is not None:
+                request["random_seed"] = config.seed
+
+            # prepare response for inclusion in model call
+            response: dict[str, Any] = {}
+
+            def model_call() -> ModelCall:
+                req = request.copy()
+                req.update(
+                    messages=[message.model_dump() for message in req["messages"]]
+                )
+                if req.get("tools", None) is not None:
+                    req["tools"] = [tool.model_dump() for tool in req["tools"]]
+
+                return ModelCall.create(
+                    request=req,
+                    response=response,
+                    time=time_tracker.end_request(request_id),
+                )
+
+            # send request
+            try:
+                completion = await client.chat.complete_async(**request)
+                response = completion.model_dump()
+            except SDKError as ex:
+                if ex.status_code == 400:
+                    return self.handle_bad_request(ex), model_call()
+                else:
+                    raise ex
+
+            if completion is None:
+                raise RuntimeError(
+                    "Mistral model did not return a response from generate."
+                )
+
+            # return model output (w/ tool calls if they exist)
+            choices = completion_choices_from_response(completion, tools)
+            return ModelOutput(
+                model=completion.model,
+                choices=choices,
+                usage=ModelUsage(
+                    input_tokens=completion.usage.prompt_tokens,
+                    output_tokens=(
+                        completion.usage.completion_tokens
+                        if completion.usage.completion_tokens
+                        else completion.usage.total_tokens
+                        - completion.usage.prompt_tokens
+                    ),
+                    total_tokens=completion.usage.total_tokens,
+                ),
+            ), model_call()
 
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
@@ -205,7 +234,7 @@ def mistral_model_call(
     request.update(messages=[message.model_dump() for message in request["messages"]])
     if request.get("tools", None) is not None:
        request["tools"] = [tool.model_dump() for tool in request["tools"]]
-    return ModelCall(
+    return ModelCall.create(
        request=request, response=response.model_dump() if response else {}
    )
 
inspect_ai/model/_providers/openai.py
CHANGED
@@ -21,6 +21,7 @@ from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.model._providers.util.tracker import HttpxTimeTracker
 from inspect_ai.tool import ToolChoice, ToolInfo
 
 from .._chat_message import ChatMessage
@@ -137,6 +138,9 @@ class OpenAIAPI(ModelAPI):
                 **model_args,
             )
 
+        # create time tracker
+        self._time_tracker = HttpxTimeTracker(self.client._client)
+
     def is_azure(self) -> bool:
         return self.service == "azure"
 
@@ -172,6 +176,9 @@ class OpenAIAPI(ModelAPI):
                 **self.completion_params(config, False),
             )
 
+        # allocate request_id (so we can see it from ModelCall)
+        request_id = self._time_tracker.start_request()
+
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
         response: dict[str, Any] = {}
@@ -181,6 +188,7 @@ class OpenAIAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=image_url_filter,
+                time=self._time_tracker.end_request(request_id),
             )
 
         # unlike text models, vision models require a max_tokens (and set it to a very low
@@ -199,6 +207,7 @@ class OpenAIAPI(ModelAPI):
             tool_choice=openai_chat_tool_choice(tool_choice)
             if len(tools) > 0
             else NOT_GIVEN,
+            extra_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
             **self.completion_params(config, len(tools) > 0),
         )
 
inspect_ai/model/_providers/providers.py
CHANGED
@@ -93,8 +93,8 @@ def vertex() -> type[ModelAPI]:
 @modelapi(name="google")
 def google() -> type[ModelAPI]:
     FEATURE = "Google API"
-    PACKAGE = "google-
-    MIN_VERSION = "
+    PACKAGE = "google-genai"
+    MIN_VERSION = "1.2.0"
 
     # workaround log spam
     # https://github.com/ray-project/ray/issues/24917
@@ -102,7 +102,7 @@ def google() -> type[ModelAPI]:
 
     # verify we have the package
     try:
-        import google.
+        import google.genai  # type: ignore  # noqa: F401
     except ImportError:
         raise pip_dependency_error(FEATURE, [PACKAGE])
 
@@ -110,9 +110,9 @@ def google() -> type[ModelAPI]:
     verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
 
     # in the clear
-    from .google import
+    from .google import GoogleGenAIAPI
 
-    return
+    return GoogleGenAIAPI
 
 
 @modelapi(name="hf")
@@ -148,7 +148,7 @@ def cf() -> type[ModelAPI]:
 def mistral() -> type[ModelAPI]:
     FEATURE = "Mistral API"
     PACKAGE = "mistralai"
-    MIN_VERSION = "1.
+    MIN_VERSION = "1.5.0"
 
     # verify we have the package
     try:
inspect_ai/model/_providers/util/tracker.py
ADDED
@@ -0,0 +1,92 @@
+import re
+import time
+from typing import Any, cast
+
+import httpx
+from shortuuid import uuid
+
+
+class HttpTimeTracker:
+    def __init__(self) -> None:
+        # track request start times
+        self._requests: dict[str, float] = {}
+
+    def start_request(self) -> str:
+        request_id = uuid()
+        self._requests[request_id] = time.monotonic()
+        return request_id
+
+    def end_request(self, request_id: str) -> float:
+        # read the request time if (if available) and purge from dict
+        request_time = self._requests.pop(request_id, None)
+        if request_time is None:
+            raise RuntimeError(f"request_id not registered: {request_id}")
+
+        # return elapsed time
+        return time.monotonic() - request_time
+
+    def update_request_time(self, request_id: str) -> None:
+        request_time = self._requests.get(request_id, None)
+        if not request_time:
+            raise RuntimeError(f"No request registered for request_id: {request_id}")
+
+        # update the request time
+        self._requests[request_id] = time.monotonic()
+
+
+class BotoTimeTracker(HttpTimeTracker):
+    def __init__(self, session: Any) -> None:
+        from aiobotocore.session import AioSession
+
+        super().__init__()
+
+        # register hook
+        session = cast(AioSession, session._session)
+        session.register(
+            "before-send.bedrock-runtime.Converse", self.converse_before_send
+        )
+
+    def converse_before_send(self, **kwargs: Any) -> None:
+        user_agent = kwargs["request"].headers["User-Agent"].decode()
+        match = re.search(rf"{self.USER_AGENT_PREFIX}(\w+)", user_agent)
+        if match:
+            request_id = match.group(1)
+            self.update_request_time(request_id)
+
+    def user_agent_extra(self, request_id: str) -> str:
+        return f"{self.USER_AGENT_PREFIX}{request_id}"
+
+    USER_AGENT_PREFIX = "ins/rid#"
+
+
+class HttpxTimeTracker(HttpTimeTracker):
+    """Class which tracks the duration of successful (200 status) http requests.
+
+    A special header is injected into requests which is then read from
+    an httpx 'request' event hook -- this creates a record of when the request
+    started. Note that with retries a single request id could be started
+    several times; our request hook makes sure we always track the time of
+    the last request.
+
+    To determine the total time, we also install an httpx response hook. In
+    this hook we look for 200 responses which have a registered request id.
+    When we find one, we update the end time of the request.
+
+    There is an 'end_request()' method which gets the total requeset time
+    for a request_id and then purges the request_id from our tracking (so
+    the dict doesn't grow unbounded)
+    """
+
+    REQUEST_ID_HEADER = "x-irid"
+
+    def __init__(self, client: httpx.AsyncClient):
+        super().__init__()
+
+        # install httpx request hook
+        client.event_hooks["request"].append(self.request_hook)
+
+    async def request_hook(self, request: httpx.Request) -> None:
+        # update the last request time for this request id (as there could be retries)
+        request_id = request.headers.get(self.REQUEST_ID_HEADER, None)
+        if request_id:
+            self.update_request_time(request_id)
inspect_ai/model/_providers/vllm.py
CHANGED
@@ -2,6 +2,7 @@ import asyncio
 import functools
 import gc
 import os
+import time
 from dataclasses import dataclass
 from queue import Empty, Queue
 from threading import Thread
@@ -48,7 +49,8 @@ class GenerateOutput:
     output_tokens: int
     total_tokens: int
     stop_reason: StopReason
-    logprobs: Logprobs | None
+    logprobs: Logprobs | None
+    time: float
 
 
 class VLLMAPI(ModelAPI):
@@ -258,6 +260,7 @@ class VLLMAPI(ModelAPI):
         ]
 
         # TODO: what's the best way to calculate token usage for num_choices > 1
+        total_time = responses[0].time
         input_tokens = responses[0].input_tokens
         output_tokens = sum(response.output_tokens for response in responses)
         total_tokens = input_tokens + output_tokens
@@ -270,6 +273,7 @@ class VLLMAPI(ModelAPI):
                 output_tokens=output_tokens,
                 total_tokens=total_tokens,
             ),
+            time=total_time,
         )
 
 
@@ -356,7 +360,7 @@ def get_stop_reason(finish_reason: str | None) -> StopReason:
 
 
 def post_process_output(
-    output: RequestOutput, i: int, num_top_logprobs: int | None
+    output: RequestOutput, i: int, num_top_logprobs: int | None, total_time: float
 ) -> GenerateOutput:
     completion = output.outputs[i]
     output_text: str = completion.text
@@ -377,14 +381,15 @@ def post_process_output(
         total_tokens=total_tokens,
         stop_reason=get_stop_reason(completion.finish_reason),
         logprobs=extract_logprobs(completion, num_top_logprobs),
+        time=total_time,
     )
 
 
 def post_process_outputs(
-    output: RequestOutput, num_top_logprobs: int | None
+    output: RequestOutput, num_top_logprobs: int | None, total_time: float
 ) -> list[GenerateOutput]:
     return [
-        post_process_output(output, i, num_top_logprobs)
+        post_process_output(output, i, num_top_logprobs, total_time)
         for i in range(len(output.outputs))
     ]
 
@@ -412,6 +417,7 @@ def process_batches() -> None:
             continue
 
         try:
+            start_time = time.monotonic()
             first_input = inputs[0][0]
             generator = first_input.generator
             num_top_logprobs = first_input.num_top_logprobs
@@ -419,6 +425,7 @@ def process_batches() -> None:
             # generate
             outputs = generator([input[0].input for input in inputs])
 
+            total_time = time.monotonic() - start_time
             for i, output in enumerate(outputs):
                 future = inputs[i][1]
 
@@ -426,7 +433,8 @@ def process_batches() -> None:
                 # down to this point, so we can mark the future as done in a thread safe manner.
                 # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
                 loop.call_soon_threadsafe(
-                    future.set_result,
+                    future.set_result,
+                    post_process_outputs(output, num_top_logprobs, total_time),
                 )
 
         except Exception as e:
inspect_ai/solver/_basic_agent.py
CHANGED
@@ -24,7 +24,7 @@ logger = getLogger(__name__)
 
 DEFAULT_SYSTEM_MESSAGE = """
 You are a helpful assistant attempting to submit the correct answer. You have
-several functions available to help with finding the answer. Each message
+several functions available to help with finding the answer. Each message
 may perform one function call. You will see the result of the function right
 after sending the message. If you need to perform multiple actions, you can
 always send more messages with subsequent function calls. Do some reasoning
@@ -206,13 +206,11 @@ def basic_agent(
             # exit if we are at max_attempts
             attempts += 1
             if attempts >= max_attempts:
-                state.completed = True
                 break
 
             # exit if the submission is successful
             answer_scores = await score(state)
             if score_value_fn(answer_scores[0].value) == 1.0:
-                state.completed = True
                 break
 
             # otherwise notify the model that it was incorrect and continue
inspect_ai/solver/_bridge/patch.py
CHANGED
@@ -72,8 +72,6 @@ def init_openai_request_patch() -> None:
             _patch_enabled.get()
             # completions request
             and options.url == "/chat/completions"
-            # call to openai not another service (e.g. TogetherAI)
-            and self.base_url == "https://api.openai.com/v1/"
         ):
             # must also be an explicit request for an inspect model
             json_data = cast(dict[str, Any], options.json_data)
inspect_ai/solver/_limit.py
CHANGED
@@ -7,15 +7,15 @@ class SampleLimitExceededError(Exception):
     """Exception raised when a sample limit is exceeded.
 
     Args:
-        type
-        value
-        limit
+        type: Type of limit exceeded.
+        value: Value compared to
+        limit: Limit applied.
         message (str | None): Optional. Human readable message.
     """
 
     def __init__(
         self,
-        type: Literal["message", "time", "token", "operator", "custom"],
+        type: Literal["message", "time", "working", "token", "operator", "custom"],
         *,
         value: int,
         limit: int,