inspect-ai 0.3.68__py3-none-any.whl → 0.3.70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. inspect_ai/_cli/eval.py +13 -1
  2. inspect_ai/_display/plain/display.py +9 -11
  3. inspect_ai/_display/textual/app.py +5 -5
  4. inspect_ai/_display/textual/widgets/samples.py +47 -18
  5. inspect_ai/_display/textual/widgets/transcript.py +25 -12
  6. inspect_ai/_eval/eval.py +14 -2
  7. inspect_ai/_eval/evalset.py +6 -1
  8. inspect_ai/_eval/run.py +6 -0
  9. inspect_ai/_eval/task/run.py +44 -15
  10. inspect_ai/_eval/task/task.py +26 -3
  11. inspect_ai/_util/interrupt.py +15 -0
  12. inspect_ai/_util/logger.py +23 -0
  13. inspect_ai/_util/rich.py +7 -8
  14. inspect_ai/_util/text.py +301 -1
  15. inspect_ai/_util/transcript.py +10 -2
  16. inspect_ai/_util/working.py +46 -0
  17. inspect_ai/_view/www/dist/assets/index.css +56 -12
  18. inspect_ai/_view/www/dist/assets/index.js +905 -751
  19. inspect_ai/_view/www/log-schema.json +337 -2
  20. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +149 -0
  21. inspect_ai/_view/www/node_modules/flatted/python/test.py +63 -0
  22. inspect_ai/_view/www/src/appearance/icons.ts +3 -1
  23. inspect_ai/_view/www/src/metadata/RenderedContent.tsx +0 -1
  24. inspect_ai/_view/www/src/samples/SampleDisplay.module.css +9 -1
  25. inspect_ai/_view/www/src/samples/SampleDisplay.tsx +28 -1
  26. inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +4 -0
  27. inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +23 -2
  28. inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +1 -1
  29. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +4 -0
  30. inspect_ai/_view/www/src/samples/transcript/SandboxEventView.module.css +32 -0
  31. inspect_ai/_view/www/src/samples/transcript/SandboxEventView.tsx +152 -0
  32. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +9 -2
  33. inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +19 -1
  34. inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +6 -3
  35. inspect_ai/_view/www/src/samples/transcript/types.ts +3 -1
  36. inspect_ai/_view/www/src/types/log.d.ts +188 -108
  37. inspect_ai/_view/www/src/utils/format.ts +7 -4
  38. inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +9 -6
  39. inspect_ai/log/__init__.py +2 -0
  40. inspect_ai/log/_condense.py +1 -0
  41. inspect_ai/log/_log.py +72 -12
  42. inspect_ai/log/_samples.py +5 -5
  43. inspect_ai/log/_transcript.py +31 -1
  44. inspect_ai/model/_call_tools.py +1 -1
  45. inspect_ai/model/_conversation.py +1 -1
  46. inspect_ai/model/_model.py +35 -16
  47. inspect_ai/model/_model_call.py +10 -3
  48. inspect_ai/model/_providers/anthropic.py +13 -2
  49. inspect_ai/model/_providers/bedrock.py +7 -0
  50. inspect_ai/model/_providers/cloudflare.py +20 -7
  51. inspect_ai/model/_providers/google.py +358 -302
  52. inspect_ai/model/_providers/groq.py +57 -23
  53. inspect_ai/model/_providers/hf.py +6 -0
  54. inspect_ai/model/_providers/mistral.py +81 -52
  55. inspect_ai/model/_providers/openai.py +9 -0
  56. inspect_ai/model/_providers/providers.py +6 -6
  57. inspect_ai/model/_providers/util/tracker.py +92 -0
  58. inspect_ai/model/_providers/vllm.py +13 -5
  59. inspect_ai/solver/_basic_agent.py +1 -3
  60. inspect_ai/solver/_bridge/patch.py +0 -2
  61. inspect_ai/solver/_limit.py +4 -4
  62. inspect_ai/solver/_plan.py +3 -3
  63. inspect_ai/solver/_solver.py +3 -0
  64. inspect_ai/solver/_task_state.py +10 -1
  65. inspect_ai/tool/_tools/_web_search.py +3 -3
  66. inspect_ai/util/_concurrency.py +14 -8
  67. inspect_ai/util/_sandbox/context.py +15 -0
  68. inspect_ai/util/_sandbox/docker/cleanup.py +8 -3
  69. inspect_ai/util/_sandbox/docker/compose.py +5 -9
  70. inspect_ai/util/_sandbox/docker/docker.py +20 -6
  71. inspect_ai/util/_sandbox/docker/util.py +10 -1
  72. inspect_ai/util/_sandbox/environment.py +32 -1
  73. inspect_ai/util/_sandbox/events.py +149 -0
  74. inspect_ai/util/_sandbox/local.py +3 -3
  75. inspect_ai/util/_sandbox/self_check.py +2 -1
  76. inspect_ai/util/_subprocess.py +4 -1
  77. {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/METADATA +5 -5
  78. {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/RECORD +82 -74
  79. {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/LICENSE +0 -0
  80. {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/WHEEL +0 -0
  81. {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/entry_points.txt +0 -0
  82. {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/top_level.txt +0 -0

inspect_ai/model/_providers/groq.py
@@ -1,5 +1,6 @@
 import json
 import os
+from copy import copy
 from typing import Any, Dict, Iterable, List, Optional
 
 import httpx
@@ -19,9 +20,14 @@ from groq.types.chat import (
     ChatCompletionToolMessageParam,
     ChatCompletionUserMessageParam,
 )
+from pydantic import JsonValue
 from typing_extensions import override
 
-from inspect_ai._util.constants import DEFAULT_MAX_RETRIES, DEFAULT_MAX_TOKENS
+from inspect_ai._util.constants import (
+    BASE_64_DATA_REMOVED,
+    DEFAULT_MAX_RETRIES,
+    DEFAULT_MAX_TOKENS,
+)
 from inspect_ai._util.content import Content
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.url import is_http_url
@@ -48,6 +54,7 @@ from .util import (
     environment_prerequisite_error,
     model_base_url,
 )
+from .util.tracker import HttpxTimeTracker
 
 GROQ_API_KEY = "GROQ_API_KEY"
 
@@ -87,6 +94,9 @@ class GroqAPI(ModelAPI):
             http_client=httpx.AsyncClient(limits=httpx.Limits(max_connections=None)),
         )
 
+        # create time tracker
+        self._time_tracker = HttpxTimeTracker(self.client._client)
+
     @override
     async def close(self) -> None:
         await self.client.close()
@@ -98,6 +108,21 @@ class GroqAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> tuple[ModelOutput, ModelCall]:
+        # allocate request_id (so we can see it from ModelCall)
+        request_id = self._time_tracker.start_request()
+
+        # setup request and response for ModelCall
+        request: dict[str, Any] = {}
+        response: dict[str, Any] = {}
+
+        def model_call() -> ModelCall:
+            return ModelCall.create(
+                request=request,
+                response=response,
+                filter=model_call_filter,
+                time=self._time_tracker.end_request(request_id),
+            )
+
         messages = await as_groq_chat_messages(input)
 
         params = self.completion_params(config)
@@ -109,51 +134,52 @@
         if config.parallel_tool_calls is not None:
             params["parallel_tool_calls"] = config.parallel_tool_calls
 
-        response: ChatCompletion = await self.client.chat.completions.create(
+        request = dict(
             messages=messages,
             model=self.model_name,
+            extra_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
             **params,
         )
 
+        completion: ChatCompletion = await self.client.chat.completions.create(
+            **request,
+        )
+
+        response = completion.model_dump()
+
         # extract metadata
         metadata: dict[str, Any] = {
-            "id": response.id,
-            "system_fingerprint": response.system_fingerprint,
-            "created": response.created,
+            "id": completion.id,
+            "system_fingerprint": completion.system_fingerprint,
+            "created": completion.created,
         }
-        if response.usage:
+        if completion.usage:
             metadata = metadata | {
-                "queue_time": response.usage.queue_time,
-                "prompt_time": response.usage.prompt_time,
-                "completion_time": response.usage.completion_time,
-                "total_time": response.usage.total_time,
+                "queue_time": completion.usage.queue_time,
+                "prompt_time": completion.usage.prompt_time,
+                "completion_time": completion.usage.completion_time,
+                "total_time": completion.usage.total_time,
             }
 
         # extract output
-        choices = self._chat_choices_from_response(response, tools)
+        choices = self._chat_choices_from_response(completion, tools)
         output = ModelOutput(
-            model=response.model,
+            model=completion.model,
             choices=choices,
             usage=(
                 ModelUsage(
-                    input_tokens=response.usage.prompt_tokens,
-                    output_tokens=response.usage.completion_tokens,
-                    total_tokens=response.usage.total_tokens,
+                    input_tokens=completion.usage.prompt_tokens,
+                    output_tokens=completion.usage.completion_tokens,
+                    total_tokens=completion.usage.total_tokens,
                 )
-                if response.usage
+                if completion.usage
                 else None
             ),
            metadata=metadata,
        )
 
-        # record call
-        call = ModelCall.create(
-            request=dict(messages=messages, model=self.model_name, **params),
-            response=response.model_dump(),
-        )
-
        # return
-        return output, call
+        return output, model_call()
 
    def completion_params(self, config: GenerateConfig) -> Dict[str, Any]:
        params: dict[str, Any] = {}
@@ -307,3 +333,11 @@ def chat_message_assistant(message: Any, tools: list[ToolInfo]) -> ChatMessageAssistant:
        tool_calls=chat_tool_calls(message, tools),
        reasoning=reasoning,
    )
+
+
+def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
+    # remove base64 encoded images
+    if key == "image_url" and isinstance(value, dict):
+        value = copy(value)
+        value.update(url=BASE_64_DATA_REMOVED)
+    return value
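
Note: the `model_call_filter` added above is handed to `ModelCall.create`, which applies it while serializing the request for the transcript so base64 image payloads are not stored verbatim. The traversal inside `ModelCall.create` is not shown in this diff, so the following is only an illustrative sketch of the idea: the `walk` helper and the placeholder value standing in for `BASE_64_DATA_REMOVED` are assumptions, not the library's implementation.

# Illustrative sketch only: apply a (key, value) filter of the same shape as
# model_call_filter while walking a JSON-like request structure.
from copy import copy
from typing import Any, Callable

BASE_64_DATA_REMOVED = "<base64 data removed>"  # assumed placeholder value


def model_call_filter(key: str | None, value: Any) -> Any:
    # remove base64 encoded images (mirrors the filter added in groq.py)
    if key == "image_url" and isinstance(value, dict):
        value = copy(value)
        value.update(url=BASE_64_DATA_REMOVED)
    return value


def walk(value: Any, fn: Callable[[str | None, Any], Any], key: str | None = None) -> Any:
    # hypothetical traversal: filter each value, then recurse into containers
    value = fn(key, value)
    if isinstance(value, dict):
        return {k: walk(v, fn, k) for k, v in value.items()}
    if isinstance(value, list):
        return [walk(v, fn, key) for v in value]
    return value


request = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA..."}}
            ],
        }
    ]
}
print(walk(request, model_call_filter)["messages"][0]["content"][0]["image_url"]["url"])
# -> <base64 data removed>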

inspect_ai/model/_providers/hf.py
@@ -4,6 +4,7 @@ import functools
 import gc
 import json
 import os
+import time
 from dataclasses import dataclass
 from queue import Empty, Queue
 from threading import Thread
@@ -220,6 +221,7 @@ class HuggingFaceAPI(ModelAPI):
                 output_tokens=response.output_tokens,
                 total_tokens=response.total_tokens,
             ),
+            time=response.time,
         )
 
     @override
@@ -377,6 +379,7 @@ class GenerateOutput:
     output_tokens: int
     total_tokens: int
     logprobs: torch.Tensor | None
+    time: float
 
 
 @dataclass
@@ -432,6 +435,7 @@ def process_batches() -> None:
 
         try:
             # capture the generator and decoder functions
+            start_time = time.monotonic()
             first_input = inputs[0][0]
             device = first_input.device
             tokenizer = first_input.tokenizer
@@ -467,6 +471,7 @@
             outputs = decoder(sequences=generated_tokens)
 
             # call back futures
+            total_time = time.monotonic() - start_time
             for i, output in enumerate(outputs):
                 future = inputs[i][1]
                 input_tokens = input_ids.size(dim=1)
@@ -483,6 +488,7 @@
                         output_tokens=output_tokens,
                         total_tokens=input_tokens + output_tokens,
                         logprobs=logprobs[i] if logprobs is not None else None,
+                        time=total_time,
                     ),
                 )
 

inspect_ai/model/_providers/mistral.py
@@ -61,6 +61,7 @@ from .._model_output import (
     StopReason,
 )
 from .util import environment_prerequisite_error, model_base_url
+from .util.tracker import HttpxTimeTracker
 
 AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
 AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"
@@ -111,16 +112,12 @@ class MistralAPI(ModelAPI):
         if base_url:
             model_args["server_url"] = base_url
 
-        # create client
-        self.client = Mistral(
-            api_key=self.api_key,
-            timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT) * 1000,
-            **model_args,
-        )
+        self.model_args = model_args
 
     @override
     async def close(self) -> None:
-        await self.client.sdk_configuration.async_client.aclose()
+        # client is created and destroyed in generate
+        pass
 
     async def generate(
         self,
@@ -129,51 +126,83 @@
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
-        # build request
-        request: dict[str, Any] = dict(
-            model=self.model_name,
-            messages=await mistral_chat_messages(input),
-            tools=mistral_chat_tools(tools) if len(tools) > 0 else None,
-            tool_choice=(
-                mistral_chat_tool_choice(tool_choice) if len(tools) > 0 else None
-            ),
-        )
-        if config.temperature is not None:
-            request["temperature"] = config.temperature
-        if config.top_p is not None:
-            request["top_p"] = config.top_p
-        if config.max_tokens is not None:
-            request["max_tokens"] = config.max_tokens
-        if config.seed is not None:
-            request["random_seed"] = config.seed
-
-        # send request
-        try:
-            response = await self.client.chat.complete_async(**request)
-        except SDKError as ex:
-            if ex.status_code == 400:
-                return self.handle_bad_request(ex), mistral_model_call(request, None)
-            else:
-                raise ex
-
-        if response is None:
-            raise RuntimeError("Mistral model did not return a response from generate.")
-
-        # return model output (w/ tool calls if they exist)
-        choices = completion_choices_from_response(response, tools)
-        return ModelOutput(
-            model=response.model,
-            choices=choices,
-            usage=ModelUsage(
-                input_tokens=response.usage.prompt_tokens,
-                output_tokens=(
-                    response.usage.completion_tokens
-                    if response.usage.completion_tokens
-                    else response.usage.total_tokens - response.usage.prompt_tokens
+        # create client
+        with Mistral(
+            api_key=self.api_key,
+            timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT) * 1000,
+            **self.model_args,
+        ) as client:
+            # create time tracker
+            time_tracker = HttpxTimeTracker(client.sdk_configuration.async_client)
+
+            # build request
+            request_id = time_tracker.start_request()
+            request: dict[str, Any] = dict(
+                model=self.model_name,
+                messages=await mistral_chat_messages(input),
+                tools=mistral_chat_tools(tools) if len(tools) > 0 else None,
+                tool_choice=(
+                    mistral_chat_tool_choice(tool_choice) if len(tools) > 0 else None
                 ),
-                total_tokens=response.usage.total_tokens,
-            ),
-        ), mistral_model_call(request, response)
+                http_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
+            )
+            if config.temperature is not None:
+                request["temperature"] = config.temperature
+            if config.top_p is not None:
+                request["top_p"] = config.top_p
+            if config.max_tokens is not None:
+                request["max_tokens"] = config.max_tokens
+            if config.seed is not None:
+                request["random_seed"] = config.seed
+
+            # prepare response for inclusion in model call
+            response: dict[str, Any] = {}
+
+            def model_call() -> ModelCall:
+                req = request.copy()
+                req.update(
+                    messages=[message.model_dump() for message in req["messages"]]
+                )
+                if req.get("tools", None) is not None:
+                    req["tools"] = [tool.model_dump() for tool in req["tools"]]
+
+                return ModelCall.create(
+                    request=req,
+                    response=response,
+                    time=time_tracker.end_request(request_id),
+                )
+
+            # send request
+            try:
+                completion = await client.chat.complete_async(**request)
+                response = completion.model_dump()
+            except SDKError as ex:
+                if ex.status_code == 400:
+                    return self.handle_bad_request(ex), model_call()
+                else:
+                    raise ex
+
+            if completion is None:
+                raise RuntimeError(
+                    "Mistral model did not return a response from generate."
+                )
+
+            # return model output (w/ tool calls if they exist)
+            choices = completion_choices_from_response(completion, tools)
+            return ModelOutput(
+                model=completion.model,
+                choices=choices,
+                usage=ModelUsage(
+                    input_tokens=completion.usage.prompt_tokens,
+                    output_tokens=(
+                        completion.usage.completion_tokens
+                        if completion.usage.completion_tokens
+                        else completion.usage.total_tokens
+                        - completion.usage.prompt_tokens
+                    ),
+                    total_tokens=completion.usage.total_tokens,
+                ),
+            ), model_call()
 
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
@@ -205,7 +234,7 @@ def mistral_model_call(
     request.update(messages=[message.model_dump() for message in request["messages"]])
     if request.get("tools", None) is not None:
         request["tools"] = [tool.model_dump() for tool in request["tools"]]
-    return ModelCall(
+    return ModelCall.create(
         request=request, response=response.model_dump() if response else {}
     )
 

inspect_ai/model/_providers/openai.py
@@ -21,6 +21,7 @@ from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.model._providers.util.tracker import HttpxTimeTracker
 from inspect_ai.tool import ToolChoice, ToolInfo
 
 from .._chat_message import ChatMessage
@@ -137,6 +138,9 @@ class OpenAIAPI(ModelAPI):
             **model_args,
         )
 
+        # create time tracker
+        self._time_tracker = HttpxTimeTracker(self.client._client)
+
     def is_azure(self) -> bool:
         return self.service == "azure"
 
@@ -172,6 +176,9 @@ class OpenAIAPI(ModelAPI):
                 **self.completion_params(config, False),
             )
 
+        # allocate request_id (so we can see it from ModelCall)
+        request_id = self._time_tracker.start_request()
+
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
         response: dict[str, Any] = {}
@@ -181,6 +188,7 @@ class OpenAIAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=image_url_filter,
+                time=self._time_tracker.end_request(request_id),
             )
 
         # unlike text models, vision models require a max_tokens (and set it to a very low
@@ -199,6 +207,7 @@ class OpenAIAPI(ModelAPI):
                 tool_choice=openai_chat_tool_choice(tool_choice)
                 if len(tools) > 0
                 else NOT_GIVEN,
+                extra_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
                 **self.completion_params(config, len(tools) > 0),
             )
 

inspect_ai/model/_providers/providers.py
@@ -93,8 +93,8 @@ def vertex() -> type[ModelAPI]:
 @modelapi(name="google")
 def google() -> type[ModelAPI]:
     FEATURE = "Google API"
-    PACKAGE = "google-generativeai"
-    MIN_VERSION = "0.8.4"
+    PACKAGE = "google-genai"
+    MIN_VERSION = "1.2.0"
 
     # workaround log spam
     # https://github.com/ray-project/ray/issues/24917
@@ -102,7 +102,7 @@ def google() -> type[ModelAPI]:
 
     # verify we have the package
     try:
-        import google.generativeai  # type: ignore  # noqa: F401
+        import google.genai  # type: ignore  # noqa: F401
     except ImportError:
         raise pip_dependency_error(FEATURE, [PACKAGE])
 
@@ -110,9 +110,9 @@
     verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
 
     # in the clear
-    from .google import GoogleAPI
+    from .google import GoogleGenAIAPI
 
-    return GoogleAPI
+    return GoogleGenAIAPI
 
 
 @modelapi(name="hf")
@@ -148,7 +148,7 @@ def cf() -> type[ModelAPI]:
 def mistral() -> type[ModelAPI]:
     FEATURE = "Mistral API"
     PACKAGE = "mistralai"
-    MIN_VERSION = "1.2.0"
+    MIN_VERSION = "1.5.0"
 
     # verify we have the package
     try:

inspect_ai/model/_providers/util/tracker.py (new file)
@@ -0,0 +1,92 @@
+import re
+import time
+from typing import Any, cast
+
+import httpx
+from shortuuid import uuid
+
+
+class HttpTimeTracker:
+    def __init__(self) -> None:
+        # track request start times
+        self._requests: dict[str, float] = {}
+
+    def start_request(self) -> str:
+        request_id = uuid()
+        self._requests[request_id] = time.monotonic()
+        return request_id
+
+    def end_request(self, request_id: str) -> float:
+        # read the request time if (if available) and purge from dict
+        request_time = self._requests.pop(request_id, None)
+        if request_time is None:
+            raise RuntimeError(f"request_id not registered: {request_id}")
+
+        # return elapsed time
+        return time.monotonic() - request_time
+
+    def update_request_time(self, request_id: str) -> None:
+        request_time = self._requests.get(request_id, None)
+        if not request_time:
+            raise RuntimeError(f"No request registered for request_id: {request_id}")
+
+        # update the request time
+        self._requests[request_id] = time.monotonic()
+
+
+class BotoTimeTracker(HttpTimeTracker):
+    def __init__(self, session: Any) -> None:
+        from aiobotocore.session import AioSession
+
+        super().__init__()
+
+        # register hook
+        session = cast(AioSession, session._session)
+        session.register(
+            "before-send.bedrock-runtime.Converse", self.converse_before_send
+        )
+
+    def converse_before_send(self, **kwargs: Any) -> None:
+        user_agent = kwargs["request"].headers["User-Agent"].decode()
+        match = re.search(rf"{self.USER_AGENT_PREFIX}(\w+)", user_agent)
+        if match:
+            request_id = match.group(1)
+            self.update_request_time(request_id)
+
+    def user_agent_extra(self, request_id: str) -> str:
+        return f"{self.USER_AGENT_PREFIX}{request_id}"
+
+    USER_AGENT_PREFIX = "ins/rid#"
+
+
+class HttpxTimeTracker(HttpTimeTracker):
+    """Class which tracks the duration of successful (200 status) http requests.
+
+    A special header is injected into requests which is then read from
+    an httpx 'request' event hook -- this creates a record of when the request
+    started. Note that with retries a single request id could be started
+    several times; our request hook makes sure we always track the time of
+    the last request.
+
+    To determine the total time, we also install an httpx response hook. In
+    this hook we look for 200 responses which have a registered request id.
+    When we find one, we update the end time of the request.
+
+    There is an 'end_request()' method which gets the total requeset time
+    for a request_id and then purges the request_id from our tracking (so
+    the dict doesn't grow unbounded)
+    """
+
+    REQUEST_ID_HEADER = "x-irid"
+
+    def __init__(self, client: httpx.AsyncClient):
+        super().__init__()
+
+        # install httpx request hook
+        client.event_hooks["request"].append(self.request_hook)
+
+    async def request_hook(self, request: httpx.Request) -> None:
+        # update the last request time for this request id (as there could be retries)
+        request_id = request.headers.get(self.REQUEST_ID_HEADER, None)
+        if request_id:
+            self.update_request_time(request_id)
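
Note: this new tracker is what the groq, mistral, and openai changes wire into their SDK clients: the provider allocates a request id, forwards it as the x-irid header (via extra_headers / http_headers), and the httpx request hook re-stamps the start time on each (re)send so the elapsed time reflects the final attempt. Below is a minimal standalone usage sketch, not part of the package: the class is simplified and the URL is a placeholder.

# Minimal sketch of attaching request timing to an httpx.AsyncClient via an
# event hook, in the style of HttpxTimeTracker above (simplified, illustrative).
import asyncio
import time

import httpx


class TimeTracker:
    REQUEST_ID_HEADER = "x-irid"

    def __init__(self, client: httpx.AsyncClient) -> None:
        self._requests: dict[str, float] = {}
        # re-stamp the start time on every (re)send of a tracked request
        client.event_hooks["request"].append(self._on_request)

    def start_request(self, request_id: str) -> None:
        self._requests[request_id] = time.monotonic()

    def end_request(self, request_id: str) -> float:
        return time.monotonic() - self._requests.pop(request_id)

    async def _on_request(self, request: httpx.Request) -> None:
        request_id = request.headers.get(self.REQUEST_ID_HEADER)
        if request_id in self._requests:
            self._requests[request_id] = time.monotonic()


async def main() -> None:
    async with httpx.AsyncClient() as client:
        tracker = TimeTracker(client)
        request_id = "demo-1"
        tracker.start_request(request_id)
        await client.get(
            "https://example.com",  # placeholder endpoint
            headers={TimeTracker.REQUEST_ID_HEADER: request_id},
        )
        print(f"request took {tracker.end_request(request_id):.3f}s")


asyncio.run(main())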

inspect_ai/model/_providers/vllm.py
@@ -2,6 +2,7 @@ import asyncio
 import functools
 import gc
 import os
+import time
 from dataclasses import dataclass
 from queue import Empty, Queue
 from threading import Thread
@@ -48,7 +49,8 @@ class GenerateOutput:
     output_tokens: int
     total_tokens: int
     stop_reason: StopReason
-    logprobs: Logprobs | None = None
+    logprobs: Logprobs | None
+    time: float
 
 
 class VLLMAPI(ModelAPI):
@@ -258,6 +260,7 @@ class VLLMAPI(ModelAPI):
         ]
 
         # TODO: what's the best way to calculate token usage for num_choices > 1
+        total_time = responses[0].time
         input_tokens = responses[0].input_tokens
         output_tokens = sum(response.output_tokens for response in responses)
         total_tokens = input_tokens + output_tokens
@@ -270,6 +273,7 @@
                 output_tokens=output_tokens,
                 total_tokens=total_tokens,
             ),
+            time=total_time,
         )
 
 
@@ -356,7 +360,7 @@ def get_stop_reason(finish_reason: str | None) -> StopReason:
 
 
 def post_process_output(
-    output: RequestOutput, i: int, num_top_logprobs: int | None
+    output: RequestOutput, i: int, num_top_logprobs: int | None, total_time: float
 ) -> GenerateOutput:
     completion = output.outputs[i]
     output_text: str = completion.text
@@ -377,14 +381,15 @@
         total_tokens=total_tokens,
         stop_reason=get_stop_reason(completion.finish_reason),
         logprobs=extract_logprobs(completion, num_top_logprobs),
+        time=total_time,
     )
 
 
 def post_process_outputs(
-    output: RequestOutput, num_top_logprobs: int | None
+    output: RequestOutput, num_top_logprobs: int | None, total_time: float
 ) -> list[GenerateOutput]:
     return [
-        post_process_output(output, i, num_top_logprobs)
+        post_process_output(output, i, num_top_logprobs, total_time)
         for i in range(len(output.outputs))
     ]
 
@@ -412,6 +417,7 @@ def process_batches() -> None:
             continue
 
         try:
+            start_time = time.monotonic()
            first_input = inputs[0][0]
            generator = first_input.generator
            num_top_logprobs = first_input.num_top_logprobs
@@ -419,6 +425,7 @@
            # generate
            outputs = generator([input[0].input for input in inputs])
 
+            total_time = time.monotonic() - start_time
            for i, output in enumerate(outputs):
                future = inputs[i][1]
 
@@ -426,7 +433,8 @@
                # down to this point, so we can mark the future as done in a thread safe manner.
                # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
                loop.call_soon_threadsafe(
-                    future.set_result, post_process_outputs(output, num_top_logprobs)
+                    future.set_result,
+                    post_process_outputs(output, num_top_logprobs, total_time),
                )
 
        except Exception as e:
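
Note: the hf.py and vllm.py changes above follow the same pattern: measure one wall-clock span around the whole batch with time.monotonic() and stamp that same elapsed value onto every per-sample GenerateOutput, so each sample reports the batch-level duration rather than its own. A small illustrative sketch of the pattern (names are placeholders, not the provider code):

# Sketch of per-batch timing stamped onto every output in the batch.
import time
from dataclasses import dataclass


@dataclass
class GenerateOutput:
    output: str
    time: float  # wall-clock seconds for the batch that produced this output


def run_batch(prompts: list[str]) -> list[GenerateOutput]:
    start_time = time.monotonic()
    completions = [p.upper() for p in prompts]  # stand-in for model generation
    total_time = time.monotonic() - start_time
    # every sample in the batch reports the same batch-level duration
    return [GenerateOutput(output=c, time=total_time) for c in completions]


print(run_batch(["hello", "world"]))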

inspect_ai/solver/_basic_agent.py
@@ -24,7 +24,7 @@ logger = getLogger(__name__)
 
 DEFAULT_SYSTEM_MESSAGE = """
 You are a helpful assistant attempting to submit the correct answer. You have
-several functions available to help with finding the answer. Each message may
+several functions available to help with finding the answer. Each message
 may perform one function call. You will see the result of the function right
 after sending the message. If you need to perform multiple actions, you can
 always send more messages with subsequent function calls. Do some reasoning
@@ -206,13 +206,11 @@ def basic_agent(
                     # exit if we are at max_attempts
                     attempts += 1
                     if attempts >= max_attempts:
-                        state.completed = True
                         break
 
                     # exit if the submission is successful
                     answer_scores = await score(state)
                     if score_value_fn(answer_scores[0].value) == 1.0:
-                        state.completed = True
                         break
 
                     # otherwise notify the model that it was incorrect and continue

inspect_ai/solver/_bridge/patch.py
@@ -72,8 +72,6 @@ def init_openai_request_patch() -> None:
             _patch_enabled.get()
             # completions request
             and options.url == "/chat/completions"
-            # call to openai not another service (e.g. TogetherAI)
-            and self.base_url == "https://api.openai.com/v1/"
         ):
             # must also be an explicit request for an inspect model
             json_data = cast(dict[str, Any], options.json_data)

inspect_ai/solver/_limit.py
@@ -7,15 +7,15 @@ class SampleLimitExceededError(Exception):
     """Exception raised when a sample limit is exceeded.
 
     Args:
-        type (Literal["message", "time", "token", "operator"]): Type of limit exceeded.
-        value (int): Value compared to
-        limit (int): Limit applied.
+        type: Type of limit exceeded.
+        value: Value compared to
+        limit: Limit applied.
         message (str | None): Optional. Human readable message.
     """
 
     def __init__(
         self,
-        type: Literal["message", "time", "token", "operator", "custom"],
+        type: Literal["message", "time", "working", "token", "operator", "custom"],
         *,
         value: int,
         limit: int,