inspect-ai 0.3.69__py3-none-any.whl → 0.3.70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. inspect_ai/_cli/eval.py +13 -1
  2. inspect_ai/_display/textual/app.py +3 -2
  3. inspect_ai/_display/textual/widgets/samples.py +4 -10
  4. inspect_ai/_display/textual/widgets/transcript.py +25 -12
  5. inspect_ai/_eval/eval.py +14 -2
  6. inspect_ai/_eval/evalset.py +6 -1
  7. inspect_ai/_eval/run.py +6 -0
  8. inspect_ai/_eval/task/run.py +44 -15
  9. inspect_ai/_eval/task/task.py +26 -3
  10. inspect_ai/_util/interrupt.py +6 -0
  11. inspect_ai/_util/logger.py +19 -0
  12. inspect_ai/_util/rich.py +7 -8
  13. inspect_ai/_util/text.py +13 -0
  14. inspect_ai/_util/transcript.py +10 -2
  15. inspect_ai/_util/working.py +46 -0
  16. inspect_ai/_view/www/dist/assets/index.css +56 -12
  17. inspect_ai/_view/www/dist/assets/index.js +904 -750
  18. inspect_ai/_view/www/log-schema.json +337 -2
  19. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +149 -0
  20. inspect_ai/_view/www/node_modules/flatted/python/test.py +63 -0
  21. inspect_ai/_view/www/src/appearance/icons.ts +3 -1
  22. inspect_ai/_view/www/src/metadata/RenderedContent.tsx +0 -1
  23. inspect_ai/_view/www/src/samples/SampleDisplay.module.css +9 -1
  24. inspect_ai/_view/www/src/samples/SampleDisplay.tsx +28 -1
  25. inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +4 -0
  26. inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +23 -2
  27. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +4 -0
  28. inspect_ai/_view/www/src/samples/transcript/SandboxEventView.module.css +32 -0
  29. inspect_ai/_view/www/src/samples/transcript/SandboxEventView.tsx +152 -0
  30. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +9 -2
  31. inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +19 -1
  32. inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +6 -3
  33. inspect_ai/_view/www/src/samples/transcript/types.ts +3 -1
  34. inspect_ai/_view/www/src/types/log.d.ts +188 -108
  35. inspect_ai/_view/www/src/utils/format.ts +7 -4
  36. inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +9 -6
  37. inspect_ai/log/__init__.py +2 -0
  38. inspect_ai/log/_condense.py +1 -0
  39. inspect_ai/log/_log.py +72 -12
  40. inspect_ai/log/_samples.py +5 -1
  41. inspect_ai/log/_transcript.py +31 -1
  42. inspect_ai/model/_call_tools.py +1 -1
  43. inspect_ai/model/_conversation.py +1 -1
  44. inspect_ai/model/_model.py +32 -16
  45. inspect_ai/model/_model_call.py +10 -3
  46. inspect_ai/model/_providers/anthropic.py +13 -2
  47. inspect_ai/model/_providers/bedrock.py +7 -0
  48. inspect_ai/model/_providers/cloudflare.py +20 -7
  49. inspect_ai/model/_providers/google.py +2 -0
  50. inspect_ai/model/_providers/groq.py +57 -23
  51. inspect_ai/model/_providers/hf.py +6 -0
  52. inspect_ai/model/_providers/mistral.py +78 -51
  53. inspect_ai/model/_providers/openai.py +9 -0
  54. inspect_ai/model/_providers/providers.py +1 -1
  55. inspect_ai/model/_providers/util/tracker.py +92 -0
  56. inspect_ai/model/_providers/vllm.py +13 -5
  57. inspect_ai/solver/_basic_agent.py +1 -3
  58. inspect_ai/solver/_bridge/patch.py +0 -2
  59. inspect_ai/solver/_limit.py +4 -4
  60. inspect_ai/solver/_plan.py +0 -3
  61. inspect_ai/solver/_task_state.py +7 -0
  62. inspect_ai/tool/_tools/_web_search.py +3 -3
  63. inspect_ai/util/_concurrency.py +14 -8
  64. inspect_ai/util/_sandbox/context.py +15 -0
  65. inspect_ai/util/_sandbox/docker/docker.py +7 -5
  66. inspect_ai/util/_sandbox/environment.py +32 -1
  67. inspect_ai/util/_sandbox/events.py +149 -0
  68. inspect_ai/util/_sandbox/local.py +3 -3
  69. {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.70.dist-info}/METADATA +3 -3
  70. {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.70.dist-info}/RECORD +74 -67
  71. {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.70.dist-info}/LICENSE +0 -0
  72. {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.70.dist-info}/WHEEL +0 -0
  73. {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.70.dist-info}/entry_points.txt +0 -0
  74. {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.70.dist-info}/top_level.txt +0 -0
inspect_ai/log/__init__.py CHANGED
@@ -41,6 +41,7 @@ from ._transcript import (
     ModelEvent,
     SampleInitEvent,
     SampleLimitEvent,
+    SandboxEvent,
     ScoreEvent,
     StateEvent,
     StepEvent,
@@ -82,6 +83,7 @@ __all__ = [
     "ModelEvent",
     "SampleInitEvent",
     "SampleLimitEvent",
+    "SandboxEvent",
     "ScoreEvent",
     "StateEvent",
     "StepEvent",
inspect_ai/log/_condense.py CHANGED
@@ -217,6 +217,7 @@ def walk_model_call(
         return ModelCall(
             request=walk_json_dict(call.request, content_fn),
             response=walk_json_dict(call.response, content_fn),
+            time=call.time,
         )
     else:
         return None
inspect_ai/log/_log.py CHANGED
@@ -4,7 +4,7 @@ import sys
 import traceback
 from logging import getLogger
 from types import TracebackType
-from typing import Any, Literal, Type, TypedDict
+from typing import Any, Literal, Tuple, Type, TypedDict
 
 import click
 import tenacity
@@ -86,13 +86,16 @@ class EvalConfig(BaseModel):
     """
 
     message_limit: int | None = Field(default=None)
-    """Maximum messages to allow in a chat conversation."""
+    """Maximum messages to allow per sample."""
 
     token_limit: int | None = Field(default=None)
-    """Maximum tokens to allow in a chat conversation."""
+    """Maximum tokens usage per sample."""
 
     time_limit: int | None = Field(default=None)
-    """Maximum seconds for chat conversation."""
+    """Maximum clock time per sample."""
+
+    working_limit: int | None = Field(default=None)
+    """Meximum working time per sample."""
 
     max_samples: int | None = Field(default=None)
     """Maximum number of samples to run in parallel."""
@@ -141,7 +144,9 @@ class EvalConfig(BaseModel):
 class EvalSampleLimit(BaseModel):
     """Limit encontered by sample."""
 
-    type: Literal["context", "time", "message", "token", "operator", "custom"]
+    type: Literal[
+        "context", "time", "working", "message", "token", "operator", "custom"
+    ]
     """The type of limit"""
 
     limit: int
@@ -218,6 +223,15 @@ class EvalSample(BaseModel):
     model_usage: dict[str, ModelUsage] = Field(default_factory=dict)
     """Model token usage for sample."""
 
+    total_time: float | None = Field(default=None)
+    """Total time that the sample was running."""
+
+    working_time: float | None = Field(default=None)
+    """Time spent working (model generation, sandbox calls, etc.)"""
+
+    uuid: str | None = Field(default=None)
+    """Globally unique identifier for sample run (exists for samples created in Inspect >= 0.3.70)"""
+
     error: EvalError | None = Field(default=None)
     """Error that halted sample."""
 
@@ -601,14 +615,15 @@ def eval_error(
     exc_traceback: TracebackType | None,
 ) -> EvalError:
     # get text traceback
-    traceback_text = "\n".join(
-        traceback.format_exception(exc_type, exc_value, exc_traceback)
-    )
+    traceback_text, truncated = truncate_traceback(exc_type, exc_value, exc_traceback)
 
-    with open(os.devnull, "w") as f:
-        console = Console(record=True, file=f, legacy_windows=True)
-        console.print(rich_traceback(exc_type, exc_value, exc_traceback))
-        traceback_ansi = console.export_text(styles=True)
+    if not truncated:
+        with open(os.devnull, "w") as f:
+            console = Console(record=True, file=f, legacy_windows=True)
+            console.print(rich_traceback(exc_type, exc_value, exc_traceback))
+            traceback_ansi = console.export_text(styles=True)
+    else:
+        traceback_ansi = traceback_text
 
     # return error
     return EvalError(
@@ -632,6 +647,51 @@ def rich_traceback(
     return rich_tb
 
 
+def truncate_traceback(
+    exc_type: Type[Any],
+    exc_value: BaseException,
+    exc_traceback: TracebackType | None,
+    max_length: int = 1048576,  # 1MB
+) -> Tuple[str, bool]:
+    tb_list = traceback.format_exception(exc_type, exc_value, exc_traceback)
+
+    # Keep the front and back of the traceback
+    header = tb_list[0]
+    error_msg = tb_list[-1]
+
+    # Join the middle parts (stack frames)
+    frames = "".join(tb_list[1:-1])
+
+    # It all fits, use it as is
+    full_tb = header + frames + error_msg
+    if len(full_tb) <= max_length:
+        return full_tb, False
+
+    ellipsis = "\n...\n"
+
+    # Minimum header size
+    header_size = min(len(header), 1024)
+
+    # Minimum frames size
+    frames_size = min(len(frames), 1024)
+
+    # Remaining space for error message
+    error_msg_size = max(0, max_length - header_size - frames_size)
+
+    def truncate_middle(text: str, size: int) -> str:
+        if len(text) <= size:
+            return text
+        half = (size - len(ellipsis)) // 2
+        return f"{text[:half]}{ellipsis}{text[-half:]}"
+
+    # Truncate each part as needed
+    truncated_header = truncate_middle(header, header_size)
+    truncated_frames = truncate_middle(frames, frames_size)
+    truncated_error = truncate_middle(error_msg, error_msg_size)
+
+    return truncated_header + truncated_frames + truncated_error, True
+
+
 class EvalStats(BaseModel):
     """Timing and usage statistics."""
 
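
To make the truncation strategy above concrete, here is a standalone sketch (not an inspect_ai import) of the same middle-truncation idea applied to an oversized block of stack frames:

    # keep the head and tail of an oversized string, eliding the middle
    ELLIPSIS = "\n...\n"

    def truncate_middle(text: str, size: int) -> str:
        if len(text) <= size:
            return text
        half = (size - len(ELLIPSIS)) // 2
        return f"{text[:half]}{ELLIPSIS}{text[-half:]}"

    frames = 'File "task.py", line 1, in solve\n' * 5000  # pretend stack frames
    print(len(truncate_middle(frames, 1024)))  # bounded to at most 1024 chars
    print(truncate_middle("short", 1024))      # returned unchanged when it fits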
inspect_ai/log/_samples.py CHANGED
@@ -23,6 +23,7 @@ class ActiveSample:
         message_limit: int | None,
         token_limit: int | None,
         time_limit: int | None,
+        working_limit: int | None,
         fails_on_error: bool,
         transcript: Transcript,
         sandboxes: dict[str, SandboxConnection],
@@ -37,6 +38,7 @@ class ActiveSample:
         self.message_limit = message_limit
         self.token_limit = token_limit
         self.time_limit = time_limit
+        self.working_limit = working_limit
         self.fails_on_error = fails_on_error
         self.total_messages = 0
         self.total_tokens = 0
@@ -45,7 +47,7 @@ class ActiveSample:
         self._interrupt_action: Literal["score", "error"] | None = None
 
     @property
-    def execution_time(self) -> float:
+    def running_time(self) -> float:
         if self.started is not None:
             completed = (
                 self.completed
@@ -78,6 +80,7 @@ async def active_sample(
     message_limit: int | None,
     token_limit: int | None,
    time_limit: int | None,
+    working_limit: int | None,
     fails_on_error: bool,
     transcript: Transcript,
 ) -> AsyncGenerator[ActiveSample, None]:
@@ -90,6 +93,7 @@ async def active_sample(
         message_limit=message_limit,
         token_limit=token_limit,
         time_limit=time_limit,
+        working_limit=working_limit,
         sandboxes=await sandbox_connections(),
         fails_on_error=fails_on_error,
         transcript=transcript,
inspect_ai/log/_transcript.py CHANGED
@@ -70,7 +70,7 @@ class SampleLimitEvent(BaseEvent):
     event: Literal["sample_limit"] = Field(default="sample_limit")
     """Event type."""
 
-    type: Literal["message", "time", "token", "operator", "custom"]
+    type: Literal["message", "time", "working", "token", "operator", "custom"]
     """Type of limit that halted processing"""
 
     message: str
@@ -207,6 +207,34 @@ class ToolEvent(BaseEvent):
     """Required so that we can include '_task' as a member."""
 
 
+class SandboxEvent(BaseEvent):
+    """Sandbox execution or I/O"""
+
+    event: Literal["sandbox"] = Field(default="sandbox")
+    """Event type"""
+
+    action: Literal["exec", "read_file", "write_file"]
+    """Sandbox action"""
+
+    cmd: str | None = Field(default=None)
+    """Command (for exec)"""
+
+    options: dict[str, JsonValue] | None = Field(default=None)
+    """Options (for exec)"""
+
+    file: str | None = Field(default=None)
+    """File (for read_file and write_file)"""
+
+    input: str | None = Field(default=None)
+    """Input (for cmd and write_file). Truncated to 100 lines."""
+
+    result: int | None = Field(default=None)
+    """Result (for exec)"""
+
+    output: str | None = Field(default=None)
+    """Output (for exec and read_file). Truncated to 100 lines."""
+
+
 class ApprovalEvent(BaseEvent):
     """Tool approval."""
 
@@ -342,10 +370,12 @@ class SubtaskEvent(BaseEvent):
 Event: TypeAlias = Union[
     SampleInitEvent
     | SampleLimitEvent
+    | SandboxEvent
     | StateEvent
     | StoreEvent
     | ModelEvent
     | ToolEvent
+    | SandboxEvent
     | ApprovalEvent
     | InputEvent
     | ScoreEvent
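
For reference (not part of the diff), a minimal sketch of constructing the new event type, assuming SandboxEvent is imported from inspect_ai.log as the __init__.py hunk above indicates; the field values are purely illustrative:

    from inspect_ai.log import SandboxEvent

    # hypothetical record of a sandbox exec call
    event = SandboxEvent(
        action="exec",
        cmd="python solve.py",
        options={"timeout": 30},
        result=0,
        output="hello\n",
    )
    print(event.event)  # "sandbox"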
inspect_ai/model/_call_tools.py CHANGED
@@ -407,7 +407,7 @@ def tool_param(type_hint: Type[Any], input: Any) -> Any:
         return tuple(input)
     elif origin is dict or origin is Dict:
         if args and len(args) > 1:
-            return {k: tool_param(args[1], v) for k, v in input}
+            return {k: tool_param(args[1], v) for k, v in input.items()}
         else:
             return input
     elif origin is Union or origin is types.UnionType:
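
The one-line fix above matters because iterating a dict yields keys, not key/value pairs, so the old comprehension either raised or silently mangled the mapping. A quick illustration:

    data = {"ab": 1, "cd": 2}

    # old behavior: unpacks each *key* into (k, v) -> {'a': 'b', 'c': 'd'}, silently wrong
    # (and a ValueError for any key whose length isn't exactly 2)
    print({k: v for k, v in data})

    # fixed behavior with .items(): {'ab': 1, 'cd': 2}
    print({k: v for k, v in data.items()})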
inspect_ai/model/_conversation.py CHANGED
@@ -19,7 +19,7 @@ def conversation_tool_mesage(message: ChatMessageTool) -> None:
         message.error.message.strip() if message.error else message.text.strip()
     )
     if output:
-        content = lines_display(output, 100)
+        content = lines_display(output, 50)
 
         conversation_panel(
             title=f"Tool Output: {message.function}",
inspect_ai/model/_model.py CHANGED
@@ -1,5 +1,5 @@
 import abc
-import asyncio
+import contextlib
 import functools
 import json
 import logging
@@ -8,7 +8,7 @@ import time
 from contextvars import ContextVar
 from copy import deepcopy
 from types import TracebackType
-from typing import Any, Callable, Literal, Type, cast
+from typing import Any, AsyncIterator, Callable, Literal, Type, cast
 
 from pydantic_core import to_jsonable_python
 from tenacity import (
@@ -33,6 +33,7 @@ from inspect_ai._util.registry import (
 )
 from inspect_ai._util.retry import log_rate_limit_retry
 from inspect_ai._util.trace import trace_action
+from inspect_ai._util.working import report_sample_waiting_time
 from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
@@ -435,14 +436,16 @@ class Model:
             )
 
             with trace_action(logger, "Model", f"generate ({str(self)})"):
-                time_start = time.perf_counter()
-                result = await self.api.generate(
-                    input=input,
-                    tools=tools,
-                    tool_choice=tool_choice,
-                    config=config,
-                )
-                time_elapsed = time.perf_counter() - time_start
+                time_start = time.monotonic()
+                try:
+                    result = await self.api.generate(
+                        input=input,
+                        tools=tools,
+                        tool_choice=tool_choice,
+                        config=config,
+                    )
+                finally:
+                    time_elapsed = time.monotonic() - time_start
 
             if isinstance(result, tuple):
                 output, call = result
@@ -461,8 +464,12 @@ class Model:
                 error_message = f"{error}\n\nRequest:\n{request}"
                 raise RuntimeError(error_message)
 
-            # update output with time elapsed
-            output.time = time_elapsed
+            # update output with time (call.time captures time spent
+            # on the actual request that succeeds w/ status 200)
+            if call and call.time is not None:
+                output.time = call.time
+            else:
+                output.time = time_elapsed
 
             # add views to tool calls
             for choice in output.choices:
@@ -488,8 +495,13 @@ class Model:
 
             return output
 
-        # call the model
+        # call the model (this will so retries, etc., so report waiting time
+        # as elapsed time - actual time for successful model call)
+        time_start = time.monotonic()
         model_output = await generate()
+        total_time = time.monotonic() - time_start
+        if model_output.time:
+            report_sample_waiting_time(total_time - model_output.time)
 
         # return results
         return model_output
@@ -513,7 +525,10 @@ class Model:
     # override the _connection_key() argument to provide a scope within which
     # to enforce max_connections (e.g. by account/api_key, by endpoint, etc.)
 
-    def _connection_concurrency(self, config: GenerateConfig) -> asyncio.Semaphore:
+    @contextlib.asynccontextmanager
+    async def _connection_concurrency(
+        self, config: GenerateConfig
+    ) -> AsyncIterator[None]:
         """Get the appropriate connection semaphore for this model instance."""
         max_connections = (
             config.max_connections
@@ -521,11 +536,12 @@ class Model:
             else self.api.max_connections()
         )
         model_name = ModelName(self)
-        return concurrency(
+        async with concurrency(
             name=f"{model_name.api}",
             concurrency=max_connections,
             key=f"Model{self.api.connection_key()}",
-        )
+        ):
+            yield
 
     def _record_model_interaction(
         self,
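
The _connection_concurrency change above swaps a method that returned an asyncio.Semaphore for an async context manager built with contextlib.asynccontextmanager. A self-contained sketch of that pattern (stand-in names, not inspect_ai code):

    import asyncio
    import contextlib
    from typing import AsyncIterator

    _slots = asyncio.Semaphore(4)  # stand-in for inspect_ai.util.concurrency()

    @contextlib.asynccontextmanager
    async def connection_concurrency() -> AsyncIterator[None]:
        async with _slots:  # hold a connection slot for the duration of the block
            yield

    async def main() -> None:
        async with connection_concurrency():
            ...  # issue the model request here

    asyncio.run(main())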
inspect_ai/model/_model_call.py CHANGED
@@ -1,6 +1,6 @@
 from typing import Any, Callable
 
-from pydantic import BaseModel, JsonValue
+from pydantic import BaseModel, Field, JsonValue
 
 from inspect_ai._util.json import jsonable_python
 
@@ -22,9 +22,15 @@ class ModelCall(BaseModel):
     response: dict[str, JsonValue]
     """Raw response data from model."""
 
+    time: float | None = Field(default=None)
+    """Time taken for underlying model call."""
+
     @staticmethod
     def create(
-        request: Any, response: Any, filter: ModelCallFilter | None = None
+        request: Any,
+        response: Any,
+        filter: ModelCallFilter | None = None,
+        time: float | None = None,
     ) -> "ModelCall":
         """Create a ModelCall object.
 
@@ -36,6 +42,7 @@ class ModelCall(BaseModel):
            request (Any): Request object (dict, dataclass, BaseModel, etc.)
            response (Any): Response object (dict, dataclass, BaseModel, etc.)
            filter (ModelCallFilter): Function for filtering model call data.
+           time: Time taken for underlying ModelCall
         """
         request_dict = jsonable_python(request)
         if filter:
@@ -43,7 +50,7 @@ class ModelCall(BaseModel):
         response_dict = jsonable_python(response)
         if filter:
             response_dict = _walk_json_value(None, response_dict, filter)
-        return ModelCall(request=request_dict, response=response_dict)
+        return ModelCall(request=request_dict, response=response_dict, time=time)
 
 
 def _walk_json_value(
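
Providers can now attach the measured request time when they record a call. A hedged usage sketch, assuming ModelCall stays exported from inspect_ai.model; the request/response payloads are placeholders:

    import time

    from inspect_ai.model import ModelCall

    start = time.monotonic()
    response = {"id": "resp_123", "content": "..."}  # placeholder provider response
    elapsed = time.monotonic() - start

    call = ModelCall.create(
        request={"model": "example-model", "messages": []},
        response=response,
        time=elapsed,  # new in 0.3.70: time taken for the underlying request
    )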
inspect_ai/model/_providers/anthropic.py CHANGED
@@ -5,6 +5,8 @@ from copy import copy
 from logging import getLogger
 from typing import Any, Literal, Tuple, TypedDict, cast
 
+from .util.tracker import HttpxTimeTracker
+
 if sys.version_info >= (3, 11):
     from typing import NotRequired
 else:
@@ -150,6 +152,9 @@ class AnthropicAPI(ModelAPI):
             **model_args,
         )
 
+        # create time tracker
+        self._time_tracker = HttpxTimeTracker(self.client._client)
+
     @override
     async def close(self) -> None:
         await self.client.close()
@@ -167,6 +172,9 @@ class AnthropicAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+        # allocate request_id (so we can see it from ModelCall)
+        request_id = self._time_tracker.start_request()
+
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
         response: dict[str, Any] = {}
@@ -176,6 +184,7 @@ class AnthropicAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=model_call_filter,
+                time=self._time_tracker.end_request(request_id),
             )
 
         # generate
@@ -200,9 +209,11 @@ class AnthropicAPI(ModelAPI):
             # additional options
             request = request | self.completion_params(config)
 
-            # computer use beta
+            # extra headers (for time tracker and computer use)
+            extra_headers = {HttpxTimeTracker.REQUEST_ID_HEADER: request_id}
             if computer_use:
-                request["extra_headers"] = {"anthropic-beta": "computer-use-2024-10-22"}
+                extra_headers["anthropic-beta"] = "computer-use-2024-10-22"
+            request["extra_headers"] = extra_headers
 
             # extra_body
             if self.extra_body is not None:
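
The tracker interface visible above (start_request(), end_request(), REQUEST_ID_HEADER) suggests per-request timing keyed by a request-id header. The actual util/tracker.py is not shown in this excerpt; the following is only a hypothetical sketch of how such a tracker could be built on httpx event hooks:

    import time
    import uuid

    import httpx

    class RequestTimer:
        REQUEST_ID_HEADER = "x-request-timer-id"  # placeholder header name

        def __init__(self, client: httpx.AsyncClient) -> None:
            self._start: dict[str, float] = {}
            self._elapsed: dict[str, float] = {}
            client.event_hooks["request"].append(self._on_request)
            client.event_hooks["response"].append(self._on_response)

        def start_request(self) -> str:
            # caller adds this id to the outgoing request headers
            return uuid.uuid4().hex

        def end_request(self, request_id: str) -> float | None:
            self._start.pop(request_id, None)
            return self._elapsed.pop(request_id, None)

        async def _on_request(self, request: httpx.Request) -> None:
            request_id = request.headers.get(self.REQUEST_ID_HEADER)
            if request_id:
                self._start[request_id] = time.monotonic()

        async def _on_response(self, response: httpx.Response) -> None:
            request_id = response.request.headers.get(self.REQUEST_ID_HEADER)
            if request_id in self._start and response.status_code == 200:
                self._elapsed[request_id] = time.monotonic() - self._start[request_id]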
inspect_ai/model/_providers/bedrock.py CHANGED
@@ -31,6 +31,7 @@ from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .util import (
     model_base_url,
 )
+from .util.tracker import BotoTimeTracker
 
 # Model for Bedrock Converse API (Response)
 # generated from: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html#converse
@@ -256,6 +257,9 @@ class BedrockAPI(ModelAPI):
             # Create a shared session to be used when generating
             self.session = aioboto3.Session()
 
+            # create time tracker
+            self._time_tracker = BotoTimeTracker(self.session)
+
         except ImportError:
             raise pip_dependency_error("Bedrock API", ["aioboto3"])
 
@@ -313,6 +317,7 @@ class BedrockAPI(ModelAPI):
         from botocore.exceptions import ClientError
 
         # The bedrock client
+        request_id = self._time_tracker.start_request()
         async with self.session.client(  # type: ignore[call-overload]
             service_name="bedrock-runtime",
             endpoint_url=self.base_url,
@@ -325,6 +330,7 @@ class BedrockAPI(ModelAPI):
                     else DEFAULT_MAX_RETRIES,
                     mode="adaptive",
                 ),
+                user_agent_extra=self._time_tracker.user_agent_extra(request_id),
             ),
             **self.model_args,
         ) as client:
@@ -364,6 +370,7 @@ class BedrockAPI(ModelAPI):
                     request.model_dump(exclude_none=True)
                 ),
                 response=response,
+                time=self._time_tracker.end_request(request_id),
             )
 
             try:
inspect_ai/model/_providers/cloudflare.py CHANGED
@@ -19,6 +19,7 @@ from .util import (
     is_chat_api_rate_limit,
     model_base_url,
 )
+from .util.tracker import HttpxTimeTracker
 
 # https://developers.cloudflare.com/workers-ai/models/#text-generation
 
@@ -50,6 +51,7 @@ class CloudFlareAPI(ModelAPI):
         if not self.api_key:
             raise environment_prerequisite_error("CloudFlare", CLOUDFLARE_API_TOKEN)
         self.client = httpx.AsyncClient()
+        self._time_tracker = HttpxTimeTracker(self.client)
         base_url = model_base_url(base_url, "CLOUDFLARE_BASE_URL")
         self.base_url = (
             base_url if base_url else "https://api.cloudflare.com/client/v4/accounts"
@@ -76,12 +78,28 @@ class CloudFlareAPI(ModelAPI):
         json["max_tokens"] = config.max_tokens
         json["messages"] = chat_api_input(input, tools, self.chat_api_handler())
 
+        # request_id
+        request_id = self._time_tracker.start_request()
+
+        # setup response
+        response: dict[str, Any] = {}
+
+        def model_call() -> ModelCall:
+            return ModelCall.create(
+                request=json,
+                response=response,
+                time=self._time_tracker.end_request(request_id),
+            )
+
         # make the call
         response = await chat_api_request(
             self.client,
             model_name=self.model_name,
             url=f"{chat_url}/{self.model_name}",
-            headers={"Authorization": f"Bearer {self.api_key}"},
+            headers={
+                "Authorization": f"Bearer {self.api_key}",
+                HttpxTimeTracker.REQUEST_ID_HEADER: request_id,
+            },
             json=json,
             config=config,
         )
@@ -102,13 +120,8 @@ class CloudFlareAPI(ModelAPI):
                 ],
             )
 
-            # record call
-            call = ModelCall.create(
-                request=dict(model_name=self.model_name, **json), response=response
-            )
-
             # return
-            return output, call
+            return output, model_call()
         else:
             error = str(response.get("errors", "Unknown"))
             raise RuntimeError(f"Error calling {self.model_name}: {error}")
inspect_ai/model/_providers/google.py CHANGED
@@ -229,6 +229,8 @@ class GoogleGenAIAPI(ModelAPI):
             response=response,
         )
 
+        # TODO: would need to monkey patch AuthorizedSession.request
+
        try:
            response = await self.client.aio.models.generate_content(
                model=self.model_name,