inspect-ai 0.3.69__py3-none-any.whl → 0.3.70__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +13 -1
- inspect_ai/_display/textual/app.py +3 -2
- inspect_ai/_display/textual/widgets/samples.py +4 -10
- inspect_ai/_display/textual/widgets/transcript.py +25 -12
- inspect_ai/_eval/eval.py +14 -2
- inspect_ai/_eval/evalset.py +6 -1
- inspect_ai/_eval/run.py +6 -0
- inspect_ai/_eval/task/run.py +44 -15
- inspect_ai/_eval/task/task.py +26 -3
- inspect_ai/_util/interrupt.py +6 -0
- inspect_ai/_util/logger.py +19 -0
- inspect_ai/_util/rich.py +7 -8
- inspect_ai/_util/text.py +13 -0
- inspect_ai/_util/transcript.py +10 -2
- inspect_ai/_util/working.py +46 -0
- inspect_ai/_view/www/dist/assets/index.css +56 -12
- inspect_ai/_view/www/dist/assets/index.js +904 -750
- inspect_ai/_view/www/log-schema.json +337 -2
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +149 -0
- inspect_ai/_view/www/node_modules/flatted/python/test.py +63 -0
- inspect_ai/_view/www/src/appearance/icons.ts +3 -1
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +0 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.module.css +9 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.tsx +28 -1
- inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +4 -0
- inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +23 -2
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +4 -0
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.module.css +32 -0
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.tsx +152 -0
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +9 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +19 -1
- inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +6 -3
- inspect_ai/_view/www/src/samples/transcript/types.ts +3 -1
- inspect_ai/_view/www/src/types/log.d.ts +188 -108
- inspect_ai/_view/www/src/utils/format.ts +7 -4
- inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +9 -6
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_condense.py +1 -0
- inspect_ai/log/_log.py +72 -12
- inspect_ai/log/_samples.py +5 -1
- inspect_ai/log/_transcript.py +31 -1
- inspect_ai/model/_call_tools.py +1 -1
- inspect_ai/model/_conversation.py +1 -1
- inspect_ai/model/_model.py +32 -16
- inspect_ai/model/_model_call.py +10 -3
- inspect_ai/model/_providers/anthropic.py +13 -2
- inspect_ai/model/_providers/bedrock.py +7 -0
- inspect_ai/model/_providers/cloudflare.py +20 -7
- inspect_ai/model/_providers/google.py +2 -0
- inspect_ai/model/_providers/groq.py +57 -23
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +78 -51
- inspect_ai/model/_providers/openai.py +9 -0
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/util/tracker.py +92 -0
- inspect_ai/model/_providers/vllm.py +13 -5
- inspect_ai/solver/_basic_agent.py +1 -3
- inspect_ai/solver/_bridge/patch.py +0 -2
- inspect_ai/solver/_limit.py +4 -4
- inspect_ai/solver/_plan.py +0 -3
- inspect_ai/solver/_task_state.py +7 -0
- inspect_ai/tool/_tools/_web_search.py +3 -3
- inspect_ai/util/_concurrency.py +14 -8
- inspect_ai/util/_sandbox/context.py +15 -0
- inspect_ai/util/_sandbox/docker/docker.py +7 -5
- inspect_ai/util/_sandbox/environment.py +32 -1
- inspect_ai/util/_sandbox/events.py +149 -0
- inspect_ai/util/_sandbox/local.py +3 -3
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.70.dist-info}/METADATA +3 -3
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.70.dist-info}/RECORD +74 -67
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.70.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.70.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.70.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.70.dist-info}/top_level.txt +0 -0
inspect_ai/log/__init__.py
CHANGED
@@ -41,6 +41,7 @@ from ._transcript import (
     ModelEvent,
     SampleInitEvent,
     SampleLimitEvent,
+    SandboxEvent,
     ScoreEvent,
     StateEvent,
     StepEvent,
@@ -82,6 +83,7 @@ __all__ = [
     "ModelEvent",
     "SampleInitEvent",
    "SampleLimitEvent",
+    "SandboxEvent",
     "ScoreEvent",
     "StateEvent",
     "StepEvent",
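The newly re-exported SandboxEvent can be filtered out of a stored log alongside the other event types. A minimal hedged sketch (read_eval_log and the samples/events fields are existing inspect_ai APIs; the log path is hypothetical):

    from inspect_ai.log import SandboxEvent, read_eval_log

    log = read_eval_log("logs/example-eval.json")  # hypothetical log file
    for sample in log.samples or []:
        sandbox_events = [ev for ev in sample.events if isinstance(ev, SandboxEvent)]
        print(sample.id, len(sandbox_events))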
inspect_ai/log/_condense.py
CHANGED
inspect_ai/log/_log.py
CHANGED
@@ -4,7 +4,7 @@ import sys
 import traceback
 from logging import getLogger
 from types import TracebackType
-from typing import Any, Literal, Type, TypedDict
+from typing import Any, Literal, Tuple, Type, TypedDict

 import click
 import tenacity
@@ -86,13 +86,16 @@ class EvalConfig(BaseModel):
     """

     message_limit: int | None = Field(default=None)
-    """Maximum messages to allow
+    """Maximum messages to allow per sample."""

     token_limit: int | None = Field(default=None)
-    """Maximum tokens
+    """Maximum tokens usage per sample."""

     time_limit: int | None = Field(default=None)
-    """Maximum
+    """Maximum clock time per sample."""
+
+    working_limit: int | None = Field(default=None)
+    """Meximum working time per sample."""

     max_samples: int | None = Field(default=None)
     """Maximum number of samples to run in parallel."""
@@ -141,7 +144,9 @@ class EvalConfig(BaseModel):
 class EvalSampleLimit(BaseModel):
     """Limit encontered by sample."""

-    type: Literal[
+    type: Literal[
+        "context", "time", "working", "message", "token", "operator", "custom"
+    ]
     """The type of limit"""

     limit: int
@@ -218,6 +223,15 @@ class EvalSample(BaseModel):
     model_usage: dict[str, ModelUsage] = Field(default_factory=dict)
     """Model token usage for sample."""

+    total_time: float | None = Field(default=None)
+    """Total time that the sample was running."""
+
+    working_time: float | None = Field(default=None)
+    """Time spent working (model generation, sandbox calls, etc.)"""
+
+    uuid: str | None = Field(default=None)
+    """Globally unique identifier for sample run (exists for samples created in Inspect >= 0.3.70)"""
+
     error: EvalError | None = Field(default=None)
     """Error that halted sample."""

@@ -601,14 +615,15 @@ def eval_error(
     exc_traceback: TracebackType | None,
 ) -> EvalError:
     # get text traceback
-    traceback_text =
-    traceback.format_exception(exc_type, exc_value, exc_traceback)
-    )
+    traceback_text, truncated = truncate_traceback(exc_type, exc_value, exc_traceback)

-
-
-
-
+    if not truncated:
+        with open(os.devnull, "w") as f:
+            console = Console(record=True, file=f, legacy_windows=True)
+            console.print(rich_traceback(exc_type, exc_value, exc_traceback))
+        traceback_ansi = console.export_text(styles=True)
+    else:
+        traceback_ansi = traceback_text

     # return error
     return EvalError(
@@ -632,6 +647,51 @@ def rich_traceback(
     return rich_tb


+def truncate_traceback(
+    exc_type: Type[Any],
+    exc_value: BaseException,
+    exc_traceback: TracebackType | None,
+    max_length: int = 1048576,  # 1MB
+) -> Tuple[str, bool]:
+    tb_list = traceback.format_exception(exc_type, exc_value, exc_traceback)
+
+    # Keep the front and back of the traceback
+    header = tb_list[0]
+    error_msg = tb_list[-1]
+
+    # Join the middle parts (stack frames)
+    frames = "".join(tb_list[1:-1])
+
+    # It all fits, use it as is
+    full_tb = header + frames + error_msg
+    if len(full_tb) <= max_length:
+        return full_tb, False
+
+    ellipsis = "\n...\n"
+
+    # Minimum header size
+    header_size = min(len(header), 1024)
+
+    # Minimum frames size
+    frames_size = min(len(frames), 1024)
+
+    # Remaining space for error message
+    error_msg_size = max(0, max_length - header_size - frames_size)
+
+    def truncate_middle(text: str, size: int) -> str:
+        if len(text) <= size:
+            return text
+        half = (size - len(ellipsis)) // 2
+        return f"{text[:half]}{ellipsis}{text[-half:]}"
+
+    # Truncate each part as needed
+    truncated_header = truncate_middle(header, header_size)
+    truncated_frames = truncate_middle(frames, frames_size)
+    truncated_error = truncate_middle(error_msg, error_msg_size)
+
+    return truncated_header + truncated_frames + truncated_error, True
+
+
 class EvalStats(BaseModel):
     """Timing and usage statistics."""

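EvalSample now carries per-sample timing (total_time, working_time) and a uuid. A brief sketch of reading these back from a completed log; the field names come from the diff above, while the log path is hypothetical:

    from inspect_ai.log import read_eval_log

    log = read_eval_log("logs/example-eval.json")  # hypothetical log file
    for sample in log.samples or []:
        if sample.total_time is not None and sample.working_time is not None:
            waiting = sample.total_time - sample.working_time
            print(f"{sample.uuid}: total={sample.total_time:.1f}s "
                  f"working={sample.working_time:.1f}s waiting={waiting:.1f}s")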
inspect_ai/log/_samples.py
CHANGED
@@ -23,6 +23,7 @@ class ActiveSample:
         message_limit: int | None,
         token_limit: int | None,
         time_limit: int | None,
+        working_limit: int | None,
         fails_on_error: bool,
         transcript: Transcript,
         sandboxes: dict[str, SandboxConnection],
@@ -37,6 +38,7 @@ class ActiveSample:
         self.message_limit = message_limit
         self.token_limit = token_limit
         self.time_limit = time_limit
+        self.working_limit = working_limit
         self.fails_on_error = fails_on_error
         self.total_messages = 0
         self.total_tokens = 0
@@ -45,7 +47,7 @@ class ActiveSample:
         self._interrupt_action: Literal["score", "error"] | None = None

     @property
-    def
+    def running_time(self) -> float:
         if self.started is not None:
             completed = (
                 self.completed
@@ -78,6 +80,7 @@ async def active_sample(
     message_limit: int | None,
     token_limit: int | None,
     time_limit: int | None,
+    working_limit: int | None,
     fails_on_error: bool,
     transcript: Transcript,
 ) -> AsyncGenerator[ActiveSample, None]:
@@ -90,6 +93,7 @@ async def active_sample(
         message_limit=message_limit,
         token_limit=token_limit,
         time_limit=time_limit,
+        working_limit=working_limit,
         sandboxes=await sandbox_connections(),
         fails_on_error=fails_on_error,
         transcript=transcript,
inspect_ai/log/_transcript.py
CHANGED
@@ -70,7 +70,7 @@ class SampleLimitEvent(BaseEvent):
     event: Literal["sample_limit"] = Field(default="sample_limit")
     """Event type."""

-    type: Literal["message", "time", "token", "operator", "custom"]
+    type: Literal["message", "time", "working", "token", "operator", "custom"]
     """Type of limit that halted processing"""

     message: str
@@ -207,6 +207,34 @@ class ToolEvent(BaseEvent):
     """Required so that we can include '_task' as a member."""


+class SandboxEvent(BaseEvent):
+    """Sandbox execution or I/O"""
+
+    event: Literal["sandbox"] = Field(default="sandbox")
+    """Event type"""
+
+    action: Literal["exec", "read_file", "write_file"]
+    """Sandbox action"""
+
+    cmd: str | None = Field(default=None)
+    """Command (for exec)"""
+
+    options: dict[str, JsonValue] | None = Field(default=None)
+    """Options (for exec)"""
+
+    file: str | None = Field(default=None)
+    """File (for read_file and write_file)"""
+
+    input: str | None = Field(default=None)
+    """Input (for cmd and write_file). Truncated to 100 lines."""
+
+    result: int | None = Field(default=None)
+    """Result (for exec)"""
+
+    output: str | None = Field(default=None)
+    """Output (for exec and read_file). Truncated to 100 lines."""
+
+
 class ApprovalEvent(BaseEvent):
     """Tool approval."""

@@ -342,10 +370,12 @@ class SubtaskEvent(BaseEvent):
 Event: TypeAlias = Union[
     SampleInitEvent
     | SampleLimitEvent
+    | SandboxEvent
     | StateEvent
     | StoreEvent
     | ModelEvent
     | ToolEvent
+    | SandboxEvent
     | ApprovalEvent
     | InputEvent
     | ScoreEvent
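A short sketch of summarizing the new SandboxEvent entries in a sample transcript; the events list is assumed to come from an EvalSample as in the earlier example:

    from inspect_ai.log import SandboxEvent

    def summarize_sandbox_events(events: list) -> None:
        for ev in events:
            if isinstance(ev, SandboxEvent):
                if ev.action == "exec":
                    print(f"exec: {ev.cmd!r} -> exit {ev.result}")
                else:
                    print(f"{ev.action}: {ev.file}")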
inspect_ai/model/_call_tools.py
CHANGED
@@ -407,7 +407,7 @@ def tool_param(type_hint: Type[Any], input: Any) -> Any:
         return tuple(input)
     elif origin is dict or origin is Dict:
         if args and len(args) > 1:
-            return {k: tool_param(args[1], v) for k, v in input}
+            return {k: tool_param(args[1], v) for k, v in input.items()}
         else:
             return input
     elif origin is Union or origin is types.UnionType:
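The _call_tools.py change fixes dict-typed tool parameters: iterating a dict directly yields its keys, so "for k, v in input" tries to unpack each key, while input.items() yields the intended (key, value) pairs. A standalone illustration (not inspect_ai code):

    data = {"temperature": 0.2, "top_p": 0.9}

    try:
        pairs = [(k, v) for k, v in data]  # old pattern: iterates keys only
    except ValueError as ex:
        print(f"old pattern fails: {ex}")  # too many values to unpack

    pairs = [(k, v) for k, v in data.items()]  # fixed pattern
    print(pairs)  # [('temperature', 0.2), ('top_p', 0.9)]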
inspect_ai/model/_conversation.py
CHANGED
@@ -19,7 +19,7 @@ def conversation_tool_mesage(message: ChatMessageTool) -> None:
         message.error.message.strip() if message.error else message.text.strip()
     )
     if output:
-        content = lines_display(output,
+        content = lines_display(output, 50)

         conversation_panel(
             title=f"Tool Output: {message.function}",
inspect_ai/model/_model.py
CHANGED
@@ -1,5 +1,5 @@
 import abc
-import
+import contextlib
 import functools
 import json
 import logging
@@ -8,7 +8,7 @@ import time
 from contextvars import ContextVar
 from copy import deepcopy
 from types import TracebackType
-from typing import Any, Callable, Literal, Type, cast
+from typing import Any, AsyncIterator, Callable, Literal, Type, cast

 from pydantic_core import to_jsonable_python
 from tenacity import (
@@ -33,6 +33,7 @@ from inspect_ai._util.registry import (
 )
 from inspect_ai._util.retry import log_rate_limit_retry
 from inspect_ai._util.trace import trace_action
+from inspect_ai._util.working import report_sample_waiting_time
 from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
@@ -435,14 +436,16 @@ class Model:
             )

             with trace_action(logger, "Model", f"generate ({str(self)})"):
-                time_start = time.
-
-
-
-
-
-
-
+                time_start = time.monotonic()
+                try:
+                    result = await self.api.generate(
+                        input=input,
+                        tools=tools,
+                        tool_choice=tool_choice,
+                        config=config,
+                    )
+                finally:
+                    time_elapsed = time.monotonic() - time_start

             if isinstance(result, tuple):
                 output, call = result
@@ -461,8 +464,12 @@ class Model:
                 error_message = f"{error}\n\nRequest:\n{request}"
                 raise RuntimeError(error_message)

-            # update output with time
-
+            # update output with time (call.time captures time spent
+            # on the actual request that succeeds w/ status 200)
+            if call and call.time is not None:
+                output.time = call.time
+            else:
+                output.time = time_elapsed

             # add views to tool calls
             for choice in output.choices:
@@ -488,8 +495,13 @@ class Model:

             return output

-        # call the model
+        # call the model (this will so retries, etc., so report waiting time
+        # as elapsed time - actual time for successful model call)
+        time_start = time.monotonic()
         model_output = await generate()
+        total_time = time.monotonic() - time_start
+        if model_output.time:
+            report_sample_waiting_time(total_time - model_output.time)

         # return results
         return model_output
@@ -513,7 +525,10 @@ class Model:
     # override the _connection_key() argument to provide a scope within which
     # to enforce max_connections (e.g. by account/api_key, by endpoint, etc.)

-
+    @contextlib.asynccontextmanager
+    async def _connection_concurrency(
+        self, config: GenerateConfig
+    ) -> AsyncIterator[None]:
         """Get the appropriate connection semaphore for this model instance."""
         max_connections = (
             config.max_connections
@@ -521,11 +536,12 @@ class Model:
             else self.api.max_connections()
         )
         model_name = ModelName(self)
-
+        async with concurrency(
             name=f"{model_name.api}",
             concurrency=max_connections,
             key=f"Model{self.api.connection_key()}",
-        )
+        ):
+            yield

     def _record_model_interaction(
         self,
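The generate path now measures total elapsed time around generate() and subtracts the successful call's own duration (output.time) to report time spent waiting on retries and connection queuing via report_sample_waiting_time. A standalone sketch of that bookkeeping; everything except time.monotonic() is an illustrative stand-in, not an inspect_ai API:

    import asyncio
    import time

    async def call_with_waiting_time(call, report_waiting):
        time_start = time.monotonic()
        result, call_time = await call()        # call_time: duration of the successful request
        total_time = time.monotonic() - time_start
        report_waiting(total_time - call_time)  # time spent outside the request (retries, queuing)
        return result

    async def main():
        async def fake_call():
            await asyncio.sleep(0.2)            # pretend 0.05s of this was retries/queuing
            return "ok", 0.15

        result = await call_with_waiting_time(fake_call, lambda w: print(f"waited {w:.2f}s"))
        print(result)

    asyncio.run(main())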
inspect_ai/model/_model_call.py
CHANGED
@@ -1,6 +1,6 @@
 from typing import Any, Callable

-from pydantic import BaseModel, JsonValue
+from pydantic import BaseModel, Field, JsonValue

 from inspect_ai._util.json import jsonable_python

@@ -22,9 +22,15 @@ class ModelCall(BaseModel):
     response: dict[str, JsonValue]
     """Raw response data from model."""

+    time: float | None = Field(default=None)
+    """Time taken for underlying model call."""
+
     @staticmethod
     def create(
-        request: Any,
+        request: Any,
+        response: Any,
+        filter: ModelCallFilter | None = None,
+        time: float | None = None,
     ) -> "ModelCall":
         """Create a ModelCall object.

@@ -36,6 +42,7 @@ class ModelCall(BaseModel):
             request (Any): Request object (dict, dataclass, BaseModel, etc.)
             response (Any): Response object (dict, dataclass, BaseModel, etc.)
             filter (ModelCallFilter): Function for filtering model call data.
+            time: Time taken for underlying ModelCall
         """
         request_dict = jsonable_python(request)
         if filter:
@@ -43,7 +50,7 @@ class ModelCall(BaseModel):
         response_dict = jsonable_python(response)
         if filter:
             response_dict = _walk_json_value(None, response_dict, filter)
-        return ModelCall(request=request_dict, response=response_dict)
+        return ModelCall(request=request_dict, response=response_dict, time=time)


 def _walk_json_value(
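ModelCall.create now accepts the measured request duration so providers can attach it to the recorded call. A hedged sketch using the signature shown above; the payloads are illustrative:

    from inspect_ai.model import ModelCall

    call = ModelCall.create(
        request={"model": "example-model", "messages": [{"role": "user", "content": "hi"}]},
        response={"choices": [{"message": {"role": "assistant", "content": "hello"}}]},
        time=1.37,  # seconds spent on the underlying HTTP request
    )
    print(call.time)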
inspect_ai/model/_providers/anthropic.py
CHANGED
@@ -5,6 +5,8 @@ from copy import copy
 from logging import getLogger
 from typing import Any, Literal, Tuple, TypedDict, cast

+from .util.tracker import HttpxTimeTracker
+
 if sys.version_info >= (3, 11):
     from typing import NotRequired
 else:
@@ -150,6 +152,9 @@ class AnthropicAPI(ModelAPI):
             **model_args,
         )

+        # create time tracker
+        self._time_tracker = HttpxTimeTracker(self.client._client)
+
     @override
     async def close(self) -> None:
         await self.client.close()
@@ -167,6 +172,9 @@ class AnthropicAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+        # allocate request_id (so we can see it from ModelCall)
+        request_id = self._time_tracker.start_request()
+
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
         response: dict[str, Any] = {}
@@ -176,6 +184,7 @@ class AnthropicAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=model_call_filter,
+                time=self._time_tracker.end_request(request_id),
             )

         # generate
@@ -200,9 +209,11 @@ class AnthropicAPI(ModelAPI):
         # additional options
         request = request | self.completion_params(config)

-        # computer use
+        # extra headers (for time tracker and computer use)
+        extra_headers = {HttpxTimeTracker.REQUEST_ID_HEADER: request_id}
         if computer_use:
-
+            extra_headers["anthropic-beta"] = "computer-use-2024-10-22"
+        request["extra_headers"] = extra_headers

         # extra_body
         if self.extra_body is not None:
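The provider changes share one pattern: allocate a request id from the time tracker, attach it to the outgoing request (a header for httpx-based clients), then call end_request(request_id) to obtain the duration of the last successful attempt for ModelCall. A standalone sketch of the idea; this is not the actual HttpxTimeTracker from the new _providers/util/tracker.py:

    import time
    import uuid

    class RequestTimeTracker:
        REQUEST_ID_HEADER = "x-request-id"  # illustrative header name

        def __init__(self) -> None:
            self._times: dict[str, float] = {}

        def start_request(self) -> str:
            return uuid.uuid4().hex

        def record_response(self, request_id: str, duration: float) -> None:
            # would be called from an HTTP hook when a request carrying the id completes
            self._times[request_id] = duration

        def end_request(self, request_id: str) -> float | None:
            return self._times.get(request_id)

    tracker = RequestTimeTracker()
    rid = tracker.start_request()
    start = time.monotonic()
    time.sleep(0.05)  # stand-in for the HTTP request
    tracker.record_response(rid, time.monotonic() - start)
    print(tracker.end_request(rid))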
inspect_ai/model/_providers/bedrock.py
CHANGED
@@ -31,6 +31,7 @@ from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .util import (
     model_base_url,
 )
+from .util.tracker import BotoTimeTracker

 # Model for Bedrock Converse API (Response)
 # generated from: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html#converse
@@ -256,6 +257,9 @@ class BedrockAPI(ModelAPI):
             # Create a shared session to be used when generating
             self.session = aioboto3.Session()

+            # create time tracker
+            self._time_tracker = BotoTimeTracker(self.session)
+
         except ImportError:
             raise pip_dependency_error("Bedrock API", ["aioboto3"])

@@ -313,6 +317,7 @@ class BedrockAPI(ModelAPI):
         from botocore.exceptions import ClientError

         # The bedrock client
+        request_id = self._time_tracker.start_request()
         async with self.session.client(  # type: ignore[call-overload]
             service_name="bedrock-runtime",
             endpoint_url=self.base_url,
@@ -325,6 +330,7 @@ class BedrockAPI(ModelAPI):
                     else DEFAULT_MAX_RETRIES,
                     mode="adaptive",
                 ),
+                user_agent_extra=self._time_tracker.user_agent_extra(request_id),
             ),
             **self.model_args,
         ) as client:
@@ -364,6 +370,7 @@ class BedrockAPI(ModelAPI):
                     request.model_dump(exclude_none=True)
                 ),
                 response=response,
+                time=self._time_tracker.end_request(request_id),
             )

             try:
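Because the Bedrock client is created per request, the request id is carried in the client's user agent rather than a header. A hedged sketch of attaching such a tag via botocore's Config; the suffix format is illustrative, not the actual BotoTimeTracker output:

    import uuid

    from botocore.config import Config

    request_id = uuid.uuid4().hex
    config = Config(
        retries=dict(max_attempts=3, mode="adaptive"),
        user_agent_extra=f"ins/rid#{request_id}",  # illustrative tag format
    )
    print(config.user_agent_extra)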
inspect_ai/model/_providers/cloudflare.py
CHANGED
@@ -19,6 +19,7 @@ from .util import (
     is_chat_api_rate_limit,
     model_base_url,
 )
+from .util.tracker import HttpxTimeTracker

 # https://developers.cloudflare.com/workers-ai/models/#text-generation

@@ -50,6 +51,7 @@ class CloudFlareAPI(ModelAPI):
         if not self.api_key:
             raise environment_prerequisite_error("CloudFlare", CLOUDFLARE_API_TOKEN)
         self.client = httpx.AsyncClient()
+        self._time_tracker = HttpxTimeTracker(self.client)
         base_url = model_base_url(base_url, "CLOUDFLARE_BASE_URL")
         self.base_url = (
             base_url if base_url else "https://api.cloudflare.com/client/v4/accounts"
@@ -76,12 +78,28 @@ class CloudFlareAPI(ModelAPI):
         json["max_tokens"] = config.max_tokens
         json["messages"] = chat_api_input(input, tools, self.chat_api_handler())

+        # request_id
+        request_id = self._time_tracker.start_request()
+
+        # setup response
+        response: dict[str, Any] = {}
+
+        def model_call() -> ModelCall:
+            return ModelCall.create(
+                request=json,
+                response=response,
+                time=self._time_tracker.end_request(request_id),
+            )
+
         # make the call
         response = await chat_api_request(
             self.client,
             model_name=self.model_name,
             url=f"{chat_url}/{self.model_name}",
-            headers={
+            headers={
+                "Authorization": f"Bearer {self.api_key}",
+                HttpxTimeTracker.REQUEST_ID_HEADER: request_id,
+            },
             json=json,
             config=config,
         )
@@ -102,13 +120,8 @@ class CloudFlareAPI(ModelAPI):
                 ],
             )

-            # record call
-            call = ModelCall.create(
-                request=dict(model_name=self.model_name, **json), response=response
-            )
-
             # return
-            return output,
+            return output, model_call()
         else:
             error = str(response.get("errors", "Unknown"))
             raise RuntimeError(f"Error calling {self.model_name}: {error}")
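The CloudFlare provider builds its ModelCall lazily through the model_call() closure, so the recorded call picks up the final response and the tracked request time at the moment it is returned. A standalone sketch of the same closure pattern; the names are illustrative, not provider code:

    import time

    def make_request(payload: dict) -> tuple[dict, dict]:
        response: dict = {}
        started = time.monotonic()

        def record_call() -> dict:
            # captures `response` and the elapsed time at the moment it is invoked
            return {
                "request": payload,
                "response": response,
                "time": time.monotonic() - started,
            }

        response.update({"result": {"response": "hello"}})  # stand-in for the HTTP call
        return response, record_call()

    resp, call = make_request({"prompt": "hi"})
    print(call["time"] >= 0.0, call["response"] is resp)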