inspect-ai 0.3.53__py3-none-any.whl → 0.3.55__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +26 -1
- inspect_ai/_cli/main.py +2 -0
- inspect_ai/_cli/trace.py +244 -0
- inspect_ai/_display/textual/app.py +5 -1
- inspect_ai/_display/textual/widgets/tasks.py +13 -3
- inspect_ai/_eval/eval.py +17 -0
- inspect_ai/_eval/task/images.py +4 -14
- inspect_ai/_eval/task/log.py +2 -1
- inspect_ai/_eval/task/run.py +26 -10
- inspect_ai/_util/constants.py +3 -3
- inspect_ai/_util/display.py +1 -0
- inspect_ai/_util/logger.py +34 -8
- inspect_ai/_util/trace.py +275 -0
- inspect_ai/log/_log.py +3 -0
- inspect_ai/log/_message.py +2 -2
- inspect_ai/log/_recorders/eval.py +6 -17
- inspect_ai/log/_recorders/json.py +19 -17
- inspect_ai/model/_cache.py +22 -16
- inspect_ai/model/_call_tools.py +9 -1
- inspect_ai/model/_generate_config.py +2 -2
- inspect_ai/model/_model.py +11 -12
- inspect_ai/model/_providers/bedrock.py +1 -1
- inspect_ai/model/_providers/openai.py +11 -1
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +1 -1
- inspect_ai/util/_sandbox/context.py +6 -1
- inspect_ai/util/_sandbox/docker/compose.py +58 -19
- inspect_ai/util/_sandbox/docker/docker.py +11 -11
- inspect_ai/util/_sandbox/docker/util.py +0 -6
- inspect_ai/util/_sandbox/service.py +17 -7
- inspect_ai/util/_subprocess.py +6 -1
- inspect_ai/util/_subtask.py +8 -2
- {inspect_ai-0.3.53.dist-info → inspect_ai-0.3.55.dist-info}/METADATA +7 -7
- {inspect_ai-0.3.53.dist-info → inspect_ai-0.3.55.dist-info}/RECORD +37 -35
- {inspect_ai-0.3.53.dist-info → inspect_ai-0.3.55.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.53.dist-info → inspect_ai-0.3.55.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.53.dist-info → inspect_ai-0.3.55.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.53.dist-info → inspect_ai-0.3.55.dist-info}/top_level.txt +0 -0
inspect_ai/_util/constants.py
CHANGED
@@ -14,12 +14,12 @@ DEFAULT_VIEW_PORT = 7575
 DEFAULT_SERVER_HOST = "127.0.0.1"
 HTTP = 15
 HTTP_LOG_LEVEL = "HTTP"
-SANDBOX = 17
-SANDBOX_LOG_LEVEL = "SANDBOX"
+TRACE = 13
+TRACE_LOG_LEVEL = "TRACE"
 ALL_LOG_LEVELS = [
     "DEBUG",
+    TRACE_LOG_LEVEL,
     HTTP_LOG_LEVEL,
-    SANDBOX_LOG_LEVEL,
     "INFO",
     "WARNING",
     "ERROR",
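Editorial sketch (not part of the diff): how the new TRACE level slots into Python's logging hierarchy — TRACE (13) sits between DEBUG (10) and HTTP (15), so a logger configured at TRACE emits TRACE and above while still suppressing DEBUG.

    import logging

    TRACE = 13  # value taken from the diff above
    logging.addLevelName(TRACE, "TRACE")
    logging.basicConfig(level=TRACE)

    logger = logging.getLogger("demo")
    logger.log(TRACE, "emitted: TRACE (13) meets the configured level")
    logger.debug("suppressed: DEBUG (10) falls below TRACE (13)")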
inspect_ai/_util/display.py
CHANGED
@@ -14,6 +14,7 @@ _display_type: DisplayType | None = None

 def init_display_type(display: str | None = None) -> DisplayType:
     global _display_type
+    global _display_metrics
     display = (
         display or os.environ.get("INSPECT_DISPLAY", DEFAULT_DISPLAY).lower().strip()
     )
inspect_ai/_util/logger.py
CHANGED
@@ -11,6 +11,7 @@ from logging import (
     getLevelName,
     getLogger,
 )
+from pathlib import Path

 import rich
 from rich.console import ConsoleRenderable
@@ -18,17 +19,20 @@ from rich.logging import RichHandler
 from rich.text import Text
 from typing_extensions import override

 from .constants import (
     ALL_LOG_LEVELS,
     DEFAULT_LOG_LEVEL,
     DEFAULT_LOG_LEVEL_TRANSCRIPT,
     HTTP,
     HTTP_LOG_LEVEL,
     PKG_NAME,
-    SANDBOX,
-    SANDBOX_LOG_LEVEL,
+    TRACE,
+    TRACE_LOG_LEVEL,
 )
 from .error import PrerequisiteError
+from .trace import TraceFileHandler, TraceFormatter, inspect_trace_dir
+
+TRACE_FILE_NAME = "trace.log"


 # log handler that filters messages to stderr and the log file
@@ -52,6 +56,24 @@ class LogHandler(RichHandler):
         else:
             self.file_logger_level = 0

+        # add a trace handler
+        default_trace_file = inspect_trace_dir() / TRACE_FILE_NAME
+        have_existing_trace_file = default_trace_file.exists()
+        env_trace_file = os.environ.get("INSPECT_TRACE_FILE", None)
+        trace_file = Path(env_trace_file) if env_trace_file else default_trace_file
+        trace_total_files = 10
+        self.trace_logger = TraceFileHandler(
+            trace_file.as_posix(),
+            backupCount=trace_total_files - 1,  # exclude the current file (10 total)
+        )
+        self.trace_logger.setFormatter(TraceFormatter())
+        if have_existing_trace_file:
+            self.trace_logger.doRollover()
+
+        # set trace level
+        trace_level = os.environ.get("INSPECT_TRACE_LEVEL", TRACE_LOG_LEVEL)
+        self.trace_logger_level = int(getLevelName(trace_level.upper()))
+
     @override
     def emit(self, record: LogRecord) -> None:
         # demote httpx and return notifications to log_level http
@@ -79,6 +101,10 @@ class LogHandler(RichHandler):
         ):
             self.file_logger.emit(record)

+        # write to trace if the trace level matches.
+        if self.trace_logger and record.levelno >= self.trace_logger_level:
+            self.trace_logger.emit(record)
+
         # eval log always gets info level and higher records
         # eval log only gets debug or http if we opt-in
         write = record.levelno >= self.transcript_levelno
@@ -95,12 +121,12 @@ def init_logger(
     log_level: str | None = None, log_level_transcript: str | None = None
 ) -> None:
     # backwards compatibility for 'tools'
-    if log_level == "tools":
-        log_level = "sandbox"
+    if log_level == "sandbox" or log_level == "tools":
+        log_level = "trace"

     # register http and tools levels
     addLevelName(HTTP, HTTP_LOG_LEVEL)
-    addLevelName(SANDBOX, SANDBOX_LOG_LEVEL)
+    addLevelName(TRACE, TRACE_LOG_LEVEL)

     def validate_level(option: str, level: str) -> None:
         if level not in ALL_LOG_LEVELS:
@@ -134,7 +160,7 @@ def init_logger(
     getLogger().addHandler(_logHandler)

     # establish default capture level
-    capture_level = min(HTTP, levelno)
+    capture_level = min(TRACE, levelno)

     # see all the messages (we won't actually display/write all of them)
     getLogger().setLevel(capture_level)
inspect_ai/_util/trace.py
ADDED
@@ -0,0 +1,275 @@
+import asyncio
+import datetime
+import gzip
+import json
+import logging
+import os
+import shutil
+import time
+import traceback
+from contextlib import contextmanager
+from logging import Logger
+from logging.handlers import RotatingFileHandler
+from pathlib import Path
+from typing import Any, Generator, Literal, TextIO
+
+import jsonlines
+from pydantic import BaseModel, Field, JsonValue
+from shortuuid import uuid
+
+from .appdirs import inspect_data_dir
+from .constants import TRACE
+
+
+def inspect_trace_dir() -> Path:
+    return inspect_data_dir("traces")
+
+
+@contextmanager
+def trace_action(
+    logger: Logger, action: str, message: str, *args: Any, **kwargs: Any
+) -> Generator[None, None, None]:
+    trace_id = uuid()
+    start_monotonic = time.monotonic()
+    start_wall = time.time()
+    pid = os.getpid()
+    detail = message % args if args else message % kwargs if kwargs else message
+
+    def trace_message(event: str) -> str:
+        return f"{action}: {detail} ({event})"
+
+    logger.log(
+        TRACE,
+        trace_message("enter"),
+        extra={
+            "action": action,
+            "detail": detail,
+            "event": "enter",
+            "trace_id": str(trace_id),
+            "start_time": start_wall,
+            "pid": pid,
+        },
+    )
+
+    try:
+        yield
+        duration = time.monotonic() - start_monotonic
+        logger.log(
+            TRACE,
+            trace_message("exit"),
+            extra={
+                "action": action,
+                "detail": detail,
+                "event": "exit",
+                "trace_id": str(trace_id),
+                "duration": duration,
+                "pid": pid,
+            },
+        )
+    except (KeyboardInterrupt, asyncio.CancelledError):
+        duration = time.monotonic() - start_monotonic
+        logger.log(
+            TRACE,
+            trace_message("cancel"),
+            extra={
+                "action": action,
+                "detail": detail,
+                "event": "cancel",
+                "trace_id": str(trace_id),
+                "duration": duration,
+                "pid": pid,
+            },
+        )
+        raise
+    except TimeoutError:
+        duration = time.monotonic() - start_monotonic
+        logger.log(
+            TRACE,
+            trace_message("timeout"),
+            extra={
+                "action": action,
+                "detail": detail,
+                "event": "timeout",
+                "trace_id": str(trace_id),
+                "duration": duration,
+                "pid": pid,
+            },
+        )
+        raise
+    except Exception as ex:
+        duration = time.monotonic() - start_monotonic
+        logger.log(
+            TRACE,
+            trace_message("error"),
+            extra={
+                "action": action,
+                "detail": detail,
+                "event": "error",
+                "trace_id": str(trace_id),
+                "duration": duration,
+                "error": getattr(ex, "message", str(ex)) or repr(ex),
+                "error_type": type(ex).__name__,
+                "stacktrace": traceback.format_exc(),
+                "pid": pid,
+            },
+        )
+        raise
+
+
+def trace_message(
+    logger: Logger, category: str, message: str, *args: Any, **kwargs: Any
+) -> None:
+    logger.log(TRACE, f"[{category}] {message}", *args, **kwargs)
+
+
+class TraceFormatter(logging.Formatter):
+    def format(self, record: logging.LogRecord) -> str:
+        # Base log entry with standard fields
+        output: dict[str, JsonValue] = {
+            "timestamp": self.formatTime(record),
+            "level": record.levelname,
+            "message": record.getMessage(),  # This handles the % formatting of the message
+        }
+
+        # Add basic context if it's not a TRACE message
+        if record.levelname != "TRACE":
+            if hasattr(record, "module"):
+                output["module"] = record.module
+            if hasattr(record, "funcName"):
+                output["function"] = record.funcName
+            if hasattr(record, "lineno"):
+                output["line"] = record.lineno
+
+        # Add any structured fields from extra
+        elif hasattr(record, "action"):
+            # This is a trace_action log
+            for key in [
+                "action",
+                "detail",
+                "event",
+                "trace_id",
+                "start_time",
+                "duration",
+                "error",
+                "error_type",
+                "stacktrace",
+                "pid",
+            ]:
+                if hasattr(record, key):
+                    output[key] = getattr(record, key)
+
+        # Handle any unexpected extra attributes
+        for key, value in record.__dict__.items():
+            if key not in output and key not in (
+                "args",
+                "lineno",
+                "funcName",
+                "module",
+                "asctime",
+                "created",
+                "exc_info",
+                "exc_text",
+                "filename",
+                "levelno",
+                "levelname",
+                "msecs",
+                "msg",
+                "name",
+                "pathname",
+                "process",
+                "processName",
+                "relativeCreated",
+                "stack_info",
+                "thread",
+                "threadName",
+            ):
+                output[key] = value
+
+        return json.dumps(
+            output, default=str
+        )  # default=str handles non-serializable objects
+
+    def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str:
+        # ISO format with timezone
+        dt = datetime.datetime.fromtimestamp(record.created)
+        return dt.isoformat()
+
+
+class TraceRecord(BaseModel):
+    timestamp: str
+    level: str
+    message: str
+
+
+class SimpleTraceRecord(TraceRecord):
+    action: None = Field(default=None)
+
+
+class ActionTraceRecord(TraceRecord):
+    action: str
+    event: Literal["enter", "cancel", "error", "timeout", "exit"]
+    trace_id: str
+    detail: str = Field(default="")
+    start_time: float | None = Field(default=None)
+    duration: float | None = Field(default=None)
+    error: str | None = Field(default=None)
+    error_type: str | None = Field(default=None)
+    stacktrace: str | None = Field(default=None)
+    pid: int | None = Field(default=None)
+
+
+def read_trace_file(file: Path) -> list[TraceRecord]:
+    def read_file(f: TextIO) -> list[TraceRecord]:
+        jsonlines_reader = jsonlines.Reader(f)
+        trace_records: list[TraceRecord] = []
+        for trace in jsonlines_reader.iter(type=dict):
+            if "action" in trace:
+                trace_records.append(ActionTraceRecord(**trace))
+            else:
+                trace_records.append(SimpleTraceRecord(**trace))
+        return trace_records
+
+    if file.name.endswith(".gz"):
+        with gzip.open(file, "rt") as f:
+            return read_file(f)
+    else:
+        with open(file, "r") as f:
+            return read_file(f)
+
+
+class TraceFileHandler(RotatingFileHandler):
+    def __init__(
+        self,
+        filename: str,
+        mode: str = "a",
+        maxBytes: int = 0,
+        backupCount: int = 0,
+        encoding: str | None = None,
+        delay: bool = False,
+    ) -> None:
+        super().__init__(filename, mode, maxBytes, backupCount, encoding, delay)
+
+    def rotation_filename(self, default_name: str) -> str:
+        """
+        Returns the name of the rotated file.
+
+        Args:
+            default_name: The default name that would be used for rotation
+
+        Returns:
+            The modified filename with .gz extension
+        """
+        return default_name + ".gz"
+
+    def rotate(self, source: str, dest: str) -> None:
+        """
+        Compresses the source file and moves it to destination.
+
+        Args:
+            source: The source file to be compressed
+            dest: The destination path for the compressed file
+        """
+        with open(source, "rb") as f_in:
+            with gzip.open(dest, "wb") as f_out:
+                shutil.copyfileobj(f_in, f_out)
+        os.remove(source)
inspect_ai/log/_log.py
CHANGED
@@ -94,6 +94,9 @@ class EvalConfig(BaseModel):
     log_buffer: int | None = Field(default=None)
     """Number of samples to buffer before writing log file."""

+    score_display: bool | None = Field(default=None)
+    """Display scoring metrics realtime."""
+
     @property
     def max_messages(self) -> int | None:
         """Deprecated max_messages property."""
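Editorial sketch (not part of the diff) exercising the new field; whether eval() itself gained a matching score_display argument is not visible in this excerpt, so only the config model is shown:

    from inspect_ai.log import EvalConfig

    # opt out of realtime display of scoring metrics in a recorded config
    config = EvalConfig(score_display=False)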
inspect_ai/log/_recorders/eval.py
CHANGED
@@ -17,6 +17,7 @@ from inspect_ai._util.content import ContentImage, ContentText
 from inspect_ai._util.error import EvalError
 from inspect_ai._util.file import FileSystem, async_fileystem, dirname, file, filesystem
 from inspect_ai._util.json import jsonable_python
+from inspect_ai._util.trace import trace_action
 from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.scorer._metric import Score

@@ -351,25 +352,13 @@ class ZipLogFile:
         self._temp_file.seek(0)
         log_bytes = self._temp_file.read()

-        written = False
-
-        try:
-            if self._async_fs:
-                await self._async_fs._pipe_file(self._file, log_bytes)
-                written = True
-        except Exception as ex:
-            logger.warning(
-                f"Error occurred during async write to {self._file}: {ex}. Falling back to sync write."
-            )
-
-        try:
-            # write sync if we need to
-            if not written:
+        with trace_action(logger, "Log Write", self._file):
+            try:
                 with file(self._file, "wb") as f:
                     f.write(log_bytes)
-        finally:
-            # re-open zip file w/ self.temp_file pointer at end
-            self._open()
+            finally:
+                # re-open zip file w/ self.temp_file pointer at end
+                self._open()

     async def close(self) -> EvalLog:
         async with self._lock:
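Editorial sketch (not part of the diff): the rewrite drops ZipLogFile's async-write fallback and leans on try/finally so the zip writer is reopened even when the sync write raises; helper names below are hypothetical:

    with trace_action(logger, "Log Write", path):
        try:
            write_zip_bytes(path)  # hypothetical: the actual sync write
        finally:
            reopen_zip()  # hypothetical: restore the writer even after an error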
inspect_ai/log/_recorders/json.py
CHANGED
@@ -15,6 +15,7 @@ from inspect_ai._util.file import (
     file,
     filesystem,
 )
+from inspect_ai._util.trace import trace_action

 from .._log import (
     EvalLog,
@@ -181,23 +182,24 @@ class JSONRecorder(FileRecorder):
         # get log as bytes
         log_bytes = eval_log_json(log)

-        # try to write async for async filesystems
-        written = False
-        try:
-            fs = filesystem(location)
-            if fs.is_async():
-                async with async_fileystem(location) as async_fs:
-                    await async_fs._pipe_file(location, log_bytes)
-                    written = True
-        except Exception as ex:
-            logger.warning(
-                f"Error occurred during async write to {location}: {ex}. Falling back to sync write."
-            )
-
-        # otherwise use sync
-        if not written:
-            with file(location, "wb") as f:
-                f.write(log_bytes)
+        with trace_action(logger, "Log Write", location):
+            # try to write async for async filesystems
+            written = False
+            try:
+                fs = filesystem(location)
+                if fs.is_async():
+                    async with async_fileystem(location) as async_fs:
+                        await async_fs._pipe_file(location, log_bytes)
+                        written = True
+            except Exception as ex:
+                logger.warning(
+                    f"Error occurred during async write to {location}: {ex}. Falling back to sync write."
+                )
+
+            # otherwise use sync
+            if not written:
+                with file(location, "wb") as f:
+                    f.write(log_bytes)


 def _validate_version(ver: int) -> None:
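Editorial sketch (not part of the diff): JSONRecorder keeps its async-first, sync-fallback write, now wrapped in trace_action. The generalized shape, with try_async_write as a hypothetical helper:

    import logging

    logger = logging.getLogger(__name__)


    async def write_bytes(location: str, data: bytes) -> None:
        written = False
        try:
            # hypothetical async path, e.g. an fsspec-backed remote filesystem
            written = await try_async_write(location, data)
        except Exception as ex:
            logger.warning(f"Async write to {location} failed ({ex}); falling back to sync.")
        if not written:
            with open(location, "wb") as f:
                f.write(data)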
inspect_ai/model/_cache.py
CHANGED
@@ -6,10 +6,12 @@ from datetime import datetime, timezone
 from hashlib import md5
 from pathlib import Path
 from shutil import rmtree
+from typing import Any

 from dateutil.relativedelta import relativedelta

 from inspect_ai._util.appdirs import inspect_cache_dir
+from inspect_ai._util.trace import trace_message
 from inspect_ai.tool import ToolChoice, ToolInfo

 from ._chat_message import ChatMessage
@@ -19,6 +21,10 @@ from ._model_output import ModelOutput
 logger = logging.getLogger(__name__)


+def trace(msg: str, *args: Any) -> None:
+    trace_message(logger, "Cache", msg, *args)
+
+
 def _path_is_in_cache(path: Path | str) -> bool:
     """This ensures the path is in our cache directory, just in case the `model` is ../../../home/ubuntu/maliciousness"""
     if isinstance(path, str):
@@ -153,7 +159,7 @@ def _cache_key(entry: CacheEntry) -> str:

     base_string = "|".join([str(component) for component in components])

-    logger.debug(_cache_key_debug_string([str(component) for component in components]))
+    trace(_cache_key_debug_string([str(component) for component in components]))

     return md5(base_string.encode("utf-8")).hexdigest()

@@ -192,11 +198,11 @@ def cache_store(

        with open(filename, "wb") as f:
            expiry = _cache_expiry(entry.policy)
-           logger.debug("Storing in cache: %s (expires: %s)", filename, expiry)
+           trace("Storing in cache: %s (expires: %s)", filename, expiry)
            pickle.dump((expiry, output), f)
            return True
    except Exception as e:
-       logger.debug(f"Failed to cache {filename}: {e}")
+       trace(f"Failed to cache {filename}: {e}")
        return False


@@ -204,12 +210,12 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:
     """Fetch a value from the cache directory."""
     filename = cache_path(model=entry.model) / _cache_key(entry)
     try:
-        logger.debug("Fetching from cache: %s", filename)
+        trace("Fetching from cache: %s", filename)

         with open(filename, "rb") as f:
             expiry, output = pickle.load(f)
             if not isinstance(output, ModelOutput):
-                logger.debug(
+                trace(
                     "Unexpected cached type, can only fetch ModelOutput: %s (%s)",
                     type(output),
                     filename,
@@ -217,7 +223,7 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:
                 return None

             if _is_expired(expiry):
-                logger.debug("Cache expired for %s (%s)", filename, expiry)
+                trace("Cache expired for %s (%s)", filename, expiry)
                 # If it's expired, no point keeping it as we'll never access it
                 # successfully again.
                 filename.unlink(missing_ok=True)
@@ -225,7 +231,7 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:

             return output
     except Exception as e:
-        logger.debug(f"Failed to fetch from cache {filename}: {e}")
+        trace(f"Failed to fetch from cache {filename}: {e}")
         return None


@@ -235,7 +241,7 @@ def cache_clear(model: str = "") -> bool:
     path = cache_path(model)

     if (model == "" or _path_is_in_cache(path)) and path.exists():
-        logger.debug("Clearing cache: %s", path)
+        trace("Clearing cache: %s", path)
         rmtree(path)
         return True

@@ -351,24 +357,24 @@ def cache_list_expired(filter_by: list[str] = []) -> list[Path]:
         # "../../foo/bar") but we don't want to search the entire cache
         return []

-    logger.debug("Filtering by paths: %s", filter_by_paths)
+    trace("Filtering by paths: %s", filter_by_paths)
     for dirpath, _dirnames, filenames in os.walk(cache_path()):
         if filter_by_paths and Path(dirpath) not in filter_by_paths:
-            logger.debug("Skipping path %s", dirpath)
+            trace("Skipping path %s", dirpath)
             continue

-        logger.debug("Checking dirpath %s", dirpath)
+        trace("Checking dirpath %s", dirpath)
         for filename in filenames:
             path = Path(dirpath) / filename
-            logger.debug("Checking path %s", path)
+            trace("Checking path %s", path)
             try:
                 with open(path, "rb") as f:
                     expiry, _cache_entry = pickle.load(f)
                     if _is_expired(expiry):
-                        logger.debug("Expired cache entry found: %s (%s)", path, expiry)
+                        trace("Expired cache entry found: %s (%s)", path, expiry)
                         expired_cache_entries.append(path)
             except Exception as e:
-                logger.debug("Failed to load cached item %s: %s", path, e)
+                trace("Failed to load cached item %s: %s", path, e)
                 continue

     return expired_cache_entries
@@ -389,8 +395,8 @@ def cache_prune(files: list[Path] = []) -> None:
            with open(file, "rb") as f:
                expiry, _cache_entry = pickle.load(f)
                if _is_expired(expiry):
-                   logger.debug("Pruning expired cache: %s", file)
+                   trace("Pruning expired cache: %s", file)
                    file.unlink(missing_ok=True)
        except Exception as e:
-           logger.debug(f"Failed to prune cache {file}: {e}")
+           trace("Failed to prune cache %s: %s", file, e)
            continue
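Editorial sketch (not part of the diff) of what one of these cache traces looks like on disk: trace_message prefixes the category, and TraceFormatter serializes each record as a single JSON line.

    import logging

    from inspect_ai._util.trace import trace_message

    logger = logging.getLogger(__name__)
    trace_message(logger, "Cache", "Storing in cache: %s", "/path/to/entry")
    # with TraceFormatter attached (see trace.py above) this lands in
    # trace.log roughly as:
    # {"timestamp": "2025-01-01T12:00:00", "level": "TRACE",
    #  "message": "[Cache] Storing in cache: /path/to/entry"}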
inspect_ai/model/_call_tools.py
CHANGED
@@ -1,6 +1,7 @@
 import asyncio
 import inspect
 from dataclasses import is_dataclass
+from logging import getLogger
 from textwrap import dedent
 from typing import (
     Any,
@@ -19,7 +20,9 @@ from jsonschema import Draft7Validator
 from pydantic import BaseModel

 from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.format import format_function_call
 from inspect_ai._util.text import truncate_string_to_bytes
+from inspect_ai._util.trace import trace_action
 from inspect_ai.model._trace import trace_tool_mesage
 from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
 from inspect_ai.tool._tool import (
@@ -35,6 +38,8 @@ from inspect_ai.util import OutputLimitExceededError
 from ._chat_message import ChatMessageAssistant, ChatMessageTool
 from ._generate_config import active_generate_config

+logger = getLogger(__name__)
+

 async def call_tools(
     message: ChatMessageAssistant,
@@ -215,7 +220,10 @@ async def call_tool(tools: list[ToolDef], message: str, call: ToolCall) -> Any:
     arguments = tool_params(call.arguments, tool_def.tool)

     # call the tool
-    result = await tool_def.tool(**arguments)
+    with trace_action(
+        logger, "Tool Call", format_function_call(tool_def.name, arguments, width=1000)
+    ):
+        result = await tool_def.tool(**arguments)

     # return result
     return result