inspect-ai 0.3.53__py3-none-any.whl → 0.3.55__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. inspect_ai/_cli/eval.py +26 -1
  2. inspect_ai/_cli/main.py +2 -0
  3. inspect_ai/_cli/trace.py +244 -0
  4. inspect_ai/_display/textual/app.py +5 -1
  5. inspect_ai/_display/textual/widgets/tasks.py +13 -3
  6. inspect_ai/_eval/eval.py +17 -0
  7. inspect_ai/_eval/task/images.py +4 -14
  8. inspect_ai/_eval/task/log.py +2 -1
  9. inspect_ai/_eval/task/run.py +26 -10
  10. inspect_ai/_util/constants.py +3 -3
  11. inspect_ai/_util/display.py +1 -0
  12. inspect_ai/_util/logger.py +34 -8
  13. inspect_ai/_util/trace.py +275 -0
  14. inspect_ai/log/_log.py +3 -0
  15. inspect_ai/log/_message.py +2 -2
  16. inspect_ai/log/_recorders/eval.py +6 -17
  17. inspect_ai/log/_recorders/json.py +19 -17
  18. inspect_ai/model/_cache.py +22 -16
  19. inspect_ai/model/_call_tools.py +9 -1
  20. inspect_ai/model/_generate_config.py +2 -2
  21. inspect_ai/model/_model.py +11 -12
  22. inspect_ai/model/_providers/bedrock.py +1 -1
  23. inspect_ai/model/_providers/openai.py +11 -1
  24. inspect_ai/tool/_tools/_web_browser/_web_browser.py +1 -1
  25. inspect_ai/util/_sandbox/context.py +6 -1
  26. inspect_ai/util/_sandbox/docker/compose.py +58 -19
  27. inspect_ai/util/_sandbox/docker/docker.py +11 -11
  28. inspect_ai/util/_sandbox/docker/util.py +0 -6
  29. inspect_ai/util/_sandbox/service.py +17 -7
  30. inspect_ai/util/_subprocess.py +6 -1
  31. inspect_ai/util/_subtask.py +8 -2
  32. {inspect_ai-0.3.53.dist-info → inspect_ai-0.3.55.dist-info}/METADATA +7 -7
  33. {inspect_ai-0.3.53.dist-info → inspect_ai-0.3.55.dist-info}/RECORD +37 -35
  34. {inspect_ai-0.3.53.dist-info → inspect_ai-0.3.55.dist-info}/LICENSE +0 -0
  35. {inspect_ai-0.3.53.dist-info → inspect_ai-0.3.55.dist-info}/WHEEL +0 -0
  36. {inspect_ai-0.3.53.dist-info → inspect_ai-0.3.55.dist-info}/entry_points.txt +0 -0
  37. {inspect_ai-0.3.53.dist-info → inspect_ai-0.3.55.dist-info}/top_level.txt +0 -0
inspect_ai/_util/constants.py CHANGED
@@ -14,12 +14,12 @@ DEFAULT_VIEW_PORT = 7575
 DEFAULT_SERVER_HOST = "127.0.0.1"
 HTTP = 15
 HTTP_LOG_LEVEL = "HTTP"
-SANDBOX = 17
-SANDBOX_LOG_LEVEL = "SANDBOX"
+TRACE = 13
+TRACE_LOG_LEVEL = "TRACE"
 ALL_LOG_LEVELS = [
     "DEBUG",
+    TRACE_LOG_LEVEL,
     HTTP_LOG_LEVEL,
-    SANDBOX_LOG_LEVEL,
     "INFO",
     "WARNING",
     "ERROR",
inspect_ai/_util/display.py CHANGED
@@ -14,6 +14,7 @@ _display_type: DisplayType | None = None

 def init_display_type(display: str | None = None) -> DisplayType:
     global _display_type
+    global _display_metrics
     display = (
         display or os.environ.get("INSPECT_DISPLAY", DEFAULT_DISPLAY).lower().strip()
     )
inspect_ai/_util/logger.py CHANGED
@@ -11,6 +11,7 @@ from logging import (
     getLevelName,
     getLogger,
 )
+from pathlib import Path

 import rich
 from rich.console import ConsoleRenderable
@@ -18,17 +19,20 @@ from rich.logging import RichHandler
 from rich.text import Text
 from typing_extensions import override

-from inspect_ai._util.constants import (
+from .constants import (
     ALL_LOG_LEVELS,
     DEFAULT_LOG_LEVEL,
     DEFAULT_LOG_LEVEL_TRANSCRIPT,
     HTTP,
     HTTP_LOG_LEVEL,
     PKG_NAME,
-    SANDBOX,
-    SANDBOX_LOG_LEVEL,
+    TRACE,
+    TRACE_LOG_LEVEL,
 )
-from inspect_ai._util.error import PrerequisiteError
+from .error import PrerequisiteError
+from .trace import TraceFileHandler, TraceFormatter, inspect_trace_dir
+
+TRACE_FILE_NAME = "trace.log"


 # log handler that filters messages to stderr and the log file
@@ -52,6 +56,24 @@ class LogHandler(RichHandler):
         else:
             self.file_logger_level = 0

+        # add a trace handler
+        default_trace_file = inspect_trace_dir() / TRACE_FILE_NAME
+        have_existing_trace_file = default_trace_file.exists()
+        env_trace_file = os.environ.get("INSPECT_TRACE_FILE", None)
+        trace_file = Path(env_trace_file) if env_trace_file else default_trace_file
+        trace_total_files = 10
+        self.trace_logger = TraceFileHandler(
+            trace_file.as_posix(),
+            backupCount=trace_total_files - 1,  # exclude the current file (10 total)
+        )
+        self.trace_logger.setFormatter(TraceFormatter())
+        if have_existing_trace_file:
+            self.trace_logger.doRollover()
+
+        # set trace level
+        trace_level = os.environ.get("INSPECT_TRACE_LEVEL", TRACE_LOG_LEVEL)
+        self.trace_logger_level = int(getLevelName(trace_level.upper()))
+
     @override
     def emit(self, record: LogRecord) -> None:
         # demote httpx and return notifications to log_level http
@@ -79,6 +101,10 @@ class LogHandler(RichHandler):
         ):
             self.file_logger.emit(record)

+        # write to trace if the trace level matches.
+        if self.trace_logger and record.levelno >= self.trace_logger_level:
+            self.trace_logger.emit(record)
+
         # eval log always gets info level and higher records
         # eval log only gets debug or http if we opt-in
         write = record.levelno >= self.transcript_levelno
@@ -95,12 +121,12 @@ def init_logger(
     log_level: str | None = None, log_level_transcript: str | None = None
 ) -> None:
     # backwards compatibility for 'tools'
-    if log_level == "tools":
-        log_level = "sandbox"
+    if log_level == "sandbox" or log_level == "tools":
+        log_level = "trace"

     # register http and tools levels
     addLevelName(HTTP, HTTP_LOG_LEVEL)
-    addLevelName(SANDBOX, SANDBOX_LOG_LEVEL)
+    addLevelName(TRACE, TRACE_LOG_LEVEL)

     def validate_level(option: str, level: str) -> None:
         if level not in ALL_LOG_LEVELS:
@@ -134,7 +160,7 @@ def init_logger(
     getLogger().addHandler(_logHandler)

     # establish default capture level
-    capture_level = min(HTTP, levelno)
+    capture_level = min(TRACE, levelno)

     # see all the messages (we won't actually display/write all of them)
     getLogger().setLevel(capture_level)
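Per the handler changes above, trace output always goes to a rotating JSONL file, and two environment variables control it. A hedged sketch of overriding them (the path is illustrative, and both must be set before `LogHandler` is constructed since they are read once in `__init__`):

```python
import os

# INSPECT_TRACE_FILE relocates the trace log (default: trace.log under
# inspect_trace_dir()); INSPECT_TRACE_LEVEL changes the threshold checked
# in emit() (default: TRACE).
os.environ["INSPECT_TRACE_FILE"] = "/tmp/inspect-trace.log"  # illustrative path
os.environ["INSPECT_TRACE_LEVEL"] = "HTTP"  # only records at HTTP (15) or higher
```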
inspect_ai/_util/trace.py ADDED
@@ -0,0 +1,275 @@
+import asyncio
+import datetime
+import gzip
+import json
+import logging
+import os
+import shutil
+import time
+import traceback
+from contextlib import contextmanager
+from logging import Logger
+from logging.handlers import RotatingFileHandler
+from pathlib import Path
+from typing import Any, Generator, Literal, TextIO
+
+import jsonlines
+from pydantic import BaseModel, Field, JsonValue
+from shortuuid import uuid
+
+from .appdirs import inspect_data_dir
+from .constants import TRACE
+
+
+def inspect_trace_dir() -> Path:
+    return inspect_data_dir("traces")
+
+
+@contextmanager
+def trace_action(
+    logger: Logger, action: str, message: str, *args: Any, **kwargs: Any
+) -> Generator[None, None, None]:
+    trace_id = uuid()
+    start_monotonic = time.monotonic()
+    start_wall = time.time()
+    pid = os.getpid()
+    detail = message % args if args else message % kwargs if kwargs else message
+
+    def trace_message(event: str) -> str:
+        return f"{action}: {detail} ({event})"
+
+    logger.log(
+        TRACE,
+        trace_message("enter"),
+        extra={
+            "action": action,
+            "detail": detail,
+            "event": "enter",
+            "trace_id": str(trace_id),
+            "start_time": start_wall,
+            "pid": pid,
+        },
+    )
+
+    try:
+        yield
+        duration = time.monotonic() - start_monotonic
+        logger.log(
+            TRACE,
+            trace_message("exit"),
+            extra={
+                "action": action,
+                "detail": detail,
+                "event": "exit",
+                "trace_id": str(trace_id),
+                "duration": duration,
+                "pid": pid,
+            },
+        )
+    except (KeyboardInterrupt, asyncio.CancelledError):
+        duration = time.monotonic() - start_monotonic
+        logger.log(
+            TRACE,
+            trace_message("cancel"),
+            extra={
+                "action": action,
+                "detail": detail,
+                "event": "cancel",
+                "trace_id": str(trace_id),
+                "duration": duration,
+                "pid": pid,
+            },
+        )
+        raise
+    except TimeoutError:
+        duration = time.monotonic() - start_monotonic
+        logger.log(
+            TRACE,
+            trace_message("timeout"),
+            extra={
+                "action": action,
+                "detail": detail,
+                "event": "timeout",
+                "trace_id": str(trace_id),
+                "duration": duration,
+                "pid": pid,
+            },
+        )
+        raise
+    except Exception as ex:
+        duration = time.monotonic() - start_monotonic
+        logger.log(
+            TRACE,
+            trace_message("error"),
+            extra={
+                "action": action,
+                "detail": detail,
+                "event": "error",
+                "trace_id": str(trace_id),
+                "duration": duration,
+                "error": getattr(ex, "message", str(ex)) or repr(ex),
+                "error_type": type(ex).__name__,
+                "stacktrace": traceback.format_exc(),
+                "pid": pid,
+            },
+        )
+        raise
+
+
+def trace_message(
+    logger: Logger, category: str, message: str, *args: Any, **kwargs: Any
+) -> None:
+    logger.log(TRACE, f"[{category}] {message}", *args, **kwargs)
+
+
+class TraceFormatter(logging.Formatter):
+    def format(self, record: logging.LogRecord) -> str:
+        # Base log entry with standard fields
+        output: dict[str, JsonValue] = {
+            "timestamp": self.formatTime(record),
+            "level": record.levelname,
+            "message": record.getMessage(),  # This handles the % formatting of the message
+        }
+
+        # Add basic context if it's not a TRACE message
+        if record.levelname != "TRACE":
+            if hasattr(record, "module"):
+                output["module"] = record.module
+            if hasattr(record, "funcName"):
+                output["function"] = record.funcName
+            if hasattr(record, "lineno"):
+                output["line"] = record.lineno
+
+        # Add any structured fields from extra
+        elif hasattr(record, "action"):
+            # This is a trace_action log
+            for key in [
+                "action",
+                "detail",
+                "event",
+                "trace_id",
+                "start_time",
+                "duration",
+                "error",
+                "error_type",
+                "stacktrace",
+                "pid",
+            ]:
+                if hasattr(record, key):
+                    output[key] = getattr(record, key)
+
+        # Handle any unexpected extra attributes
+        for key, value in record.__dict__.items():
+            if key not in output and key not in (
+                "args",
+                "lineno",
+                "funcName",
+                "module",
+                "asctime",
+                "created",
+                "exc_info",
+                "exc_text",
+                "filename",
+                "levelno",
+                "levelname",
+                "msecs",
+                "msg",
+                "name",
+                "pathname",
+                "process",
+                "processName",
+                "relativeCreated",
+                "stack_info",
+                "thread",
+                "threadName",
+            ):
+                output[key] = value
+
+        return json.dumps(
+            output, default=str
+        )  # default=str handles non-serializable objects
+
+    def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str:
+        # ISO format with timezone
+        dt = datetime.datetime.fromtimestamp(record.created)
+        return dt.isoformat()
+
+
+class TraceRecord(BaseModel):
+    timestamp: str
+    level: str
+    message: str
+
+
+class SimpleTraceRecord(TraceRecord):
+    action: None = Field(default=None)
+
+
+class ActionTraceRecord(TraceRecord):
+    action: str
+    event: Literal["enter", "cancel", "error", "timeout", "exit"]
+    trace_id: str
+    detail: str = Field(default="")
+    start_time: float | None = Field(default=None)
+    duration: float | None = Field(default=None)
+    error: str | None = Field(default=None)
+    error_type: str | None = Field(default=None)
+    stacktrace: str | None = Field(default=None)
+    pid: int | None = Field(default=None)
+
+
+def read_trace_file(file: Path) -> list[TraceRecord]:
+    def read_file(f: TextIO) -> list[TraceRecord]:
+        jsonlines_reader = jsonlines.Reader(f)
+        trace_records: list[TraceRecord] = []
+        for trace in jsonlines_reader.iter(type=dict):
+            if "action" in trace:
+                trace_records.append(ActionTraceRecord(**trace))
+            else:
+                trace_records.append(SimpleTraceRecord(**trace))
+        return trace_records
+
+    if file.name.endswith(".gz"):
+        with gzip.open(file, "rt") as f:
+            return read_file(f)
+    else:
+        with open(file, "r") as f:
+            return read_file(f)
+
+
+class TraceFileHandler(RotatingFileHandler):
+    def __init__(
+        self,
+        filename: str,
+        mode: str = "a",
+        maxBytes: int = 0,
+        backupCount: int = 0,
+        encoding: str | None = None,
+        delay: bool = False,
+    ) -> None:
+        super().__init__(filename, mode, maxBytes, backupCount, encoding, delay)
+
+    def rotation_filename(self, default_name: str) -> str:
+        """
+        Returns the name of the rotated file.
+
+        Args:
+            default_name: The default name that would be used for rotation
+
+        Returns:
+            The modified filename with .gz extension
+        """
+        return default_name + ".gz"
+
+    def rotate(self, source: str, dest: str) -> None:
+        """
+        Compresses the source file and moves it to destination.
+
+        Args:
+            source: The source file to be compressed
+            dest: The destination path for the compressed file
+        """
+        with open(source, "rb") as f_in:
+            with gzip.open(dest, "wb") as f_out:
+                shutil.copyfileobj(f_in, f_out)
+        os.remove(source)
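The new module's two logging entry points pair with `read_trace_file` for reading the output back. A minimal usage sketch based on the definitions above (`do_work()` and the action text are illustrative):

```python
import logging

from inspect_ai._util.trace import (
    inspect_trace_dir,
    read_trace_file,
    trace_action,
    trace_message,
)

logger = logging.getLogger(__name__)

# Paired records: "enter" on entry, then exactly one of "exit",
# "cancel", "timeout", or "error" (exceptions are re-raised).
with trace_action(logger, "Demo Action", "processing %s", "item-1"):
    do_work()  # illustrative callable

# One-shot record, prefixed with its category.
trace_message(logger, "Demo", "finished %s", "item-1")

# Parse the JSONL trace file back into typed records.
for record in read_trace_file(inspect_trace_dir() / "trace.log"):
    print(record.timestamp, record.level, record.message)
```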
inspect_ai/log/_log.py CHANGED
@@ -94,6 +94,9 @@ class EvalConfig(BaseModel):
     log_buffer: int | None = Field(default=None)
     """Number of samples to buffer before writing log file."""

+    score_display: bool | None = Field(default=None)
+    """Display scoring metrics realtime."""
+
     @property
     def max_messages(self) -> int | None:
         """Deprecated max_messages property."""
inspect_ai/log/_message.py CHANGED
@@ -64,7 +64,7 @@ class LoggingMessage(BaseModel):
     ) -> dict[str, Any]:
         if "level" in values:
             level = values["level"]
-            if level == "tools":
-                values["level"] = "sandbox"
+            if level == "tools" or level == "sandbox":
+                values["level"] = "trace"

         return values
inspect_ai/log/_recorders/eval.py CHANGED
@@ -17,6 +17,7 @@ from inspect_ai._util.content import ContentImage, ContentText
 from inspect_ai._util.error import EvalError
 from inspect_ai._util.file import FileSystem, async_fileystem, dirname, file, filesystem
 from inspect_ai._util.json import jsonable_python
+from inspect_ai._util.trace import trace_action
 from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.scorer._metric import Score

@@ -351,25 +352,13 @@ class ZipLogFile:
         self._temp_file.seek(0)
         log_bytes = self._temp_file.read()

-        # attempt async write
-        written = False
-        try:
-            if self._async_fs:
-                await self._async_fs._pipe_file(self._file, log_bytes)
-                written = True
-        except Exception as ex:
-            logger.warning(
-                f"Error occurred during async write to {self._file}: {ex}. Falling back to sync write."
-            )
-
-        try:
-            # write sync if we need to
-            if not written:
+        with trace_action(logger, "Log Write", self._file):
+            try:
                 with file(self._file, "wb") as f:
                     f.write(log_bytes)
-        finally:
-            # re-open zip file w/ self.temp_file pointer at end
-            self._open()
+            finally:
+                # re-open zip file w/ self.temp_file pointer at end
+                self._open()

     async def close(self) -> EvalLog:
         async with self._lock:
inspect_ai/log/_recorders/json.py CHANGED
@@ -15,6 +15,7 @@ from inspect_ai._util.file import (
     file,
     filesystem,
 )
+from inspect_ai._util.trace import trace_action

 from .._log import (
     EvalLog,
@@ -181,23 +182,24 @@ class JSONRecorder(FileRecorder):
         # get log as bytes
         log_bytes = eval_log_json(log)

-        # try to write async for async filesystems
-        written = False
-        try:
-            fs = filesystem(location)
-            if fs.is_async():
-                async with async_fileystem(location) as async_fs:
-                    await async_fs._pipe_file(location, log_bytes)
-                    written = True
-        except Exception as ex:
-            logger.warning(
-                f"Error occurred during async write to {location}: {ex}. Falling back to sync write."
-            )
-
-        # otherwise use sync
-        if not written:
-            with file(location, "wb") as f:
-                f.write(log_bytes)
+        with trace_action(logger, "Log Write", location):
+            # try to write async for async filesystems
+            written = False
+            try:
+                fs = filesystem(location)
+                if fs.is_async():
+                    async with async_fileystem(location) as async_fs:
+                        await async_fs._pipe_file(location, log_bytes)
+                        written = True
+            except Exception as ex:
+                logger.warning(
+                    f"Error occurred during async write to {location}: {ex}. Falling back to sync write."
+                )
+
+            # otherwise use sync
+            if not written:
+                with file(location, "wb") as f:
+                    f.write(log_bytes)


 def _validate_version(ver: int) -> None:
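Since both recorders now wrap their writes in `trace_action` under the action name "Log Write", slow or failing log saves can be spotted by filtering the trace file. A sketch using the record types defined in trace.py above (the 5-second threshold is arbitrary):

```python
from inspect_ai._util.trace import (
    ActionTraceRecord,
    inspect_trace_dir,
    read_trace_file,
)

# Scan the current trace file for problematic "Log Write" actions.
for record in read_trace_file(inspect_trace_dir() / "trace.log"):
    if isinstance(record, ActionTraceRecord) and record.action == "Log Write":
        if record.event == "error":
            print(f"failed write: {record.detail}: {record.error}")
        elif record.event == "exit" and (record.duration or 0) > 5:
            print(f"slow write ({record.duration:.1f}s): {record.detail}")
```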
inspect_ai/model/_cache.py CHANGED
@@ -6,10 +6,12 @@ from datetime import datetime, timezone
 from hashlib import md5
 from pathlib import Path
 from shutil import rmtree
+from typing import Any

 from dateutil.relativedelta import relativedelta

 from inspect_ai._util.appdirs import inspect_cache_dir
+from inspect_ai._util.trace import trace_message
 from inspect_ai.tool import ToolChoice, ToolInfo

 from ._chat_message import ChatMessage
@@ -19,6 +21,10 @@ from ._model_output import ModelOutput
 logger = logging.getLogger(__name__)


+def trace(msg: str, *args: Any) -> None:
+    trace_message(logger, "Cache", msg, *args)
+
+
 def _path_is_in_cache(path: Path | str) -> bool:
     """This ensures the path is in our cache directory, just in case the `model` is ../../../home/ubuntu/maliciousness"""
     if isinstance(path, str):
@@ -153,7 +159,7 @@ def _cache_key(entry: CacheEntry) -> str:

     base_string = "|".join([str(component) for component in components])

-    logger.debug(_cache_key_debug_string([str(component) for component in components]))
+    trace(_cache_key_debug_string([str(component) for component in components]))

     return md5(base_string.encode("utf-8")).hexdigest()

@@ -192,11 +198,11 @@ def cache_store(

         with open(filename, "wb") as f:
             expiry = _cache_expiry(entry.policy)
-            logger.debug("Storing in cache: %s (expires: %s)", filename, expiry)
+            trace("Storing in cache: %s (expires: %s)", filename, expiry)
             pickle.dump((expiry, output), f)
             return True
     except Exception as e:
-        logger.debug(f"Failed to cache {filename}: {e}")
+        trace(f"Failed to cache {filename}: {e}")
         return False


@@ -204,12 +210,12 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:
     """Fetch a value from the cache directory."""
     filename = cache_path(model=entry.model) / _cache_key(entry)
     try:
-        logger.debug("Fetching from cache: %s", filename)
+        trace("Fetching from cache: %s", filename)

         with open(filename, "rb") as f:
             expiry, output = pickle.load(f)
             if not isinstance(output, ModelOutput):
-                logger.debug(
+                trace(
                     "Unexpected cached type, can only fetch ModelOutput: %s (%s)",
                     type(output),
                     filename,
@@ -217,7 +223,7 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:
                 return None

             if _is_expired(expiry):
-                logger.debug("Cache expired for %s (%s)", filename, expiry)
+                trace("Cache expired for %s (%s)", filename, expiry)
                 # If it's expired, no point keeping it as we'll never access it
                 # successfully again.
                 filename.unlink(missing_ok=True)
@@ -225,7 +231,7 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:

             return output
     except Exception as e:
-        logger.debug(f"Failed to fetch from cache {filename}: {e}")
+        trace(f"Failed to fetch from cache {filename}: {e}")
         return None


@@ -235,7 +241,7 @@ def cache_clear(model: str = "") -> bool:
     path = cache_path(model)

     if (model == "" or _path_is_in_cache(path)) and path.exists():
-        logger.debug("Clearing cache: %s", path)
+        trace("Clearing cache: %s", path)
         rmtree(path)
         return True

@@ -351,24 +357,24 @@ def cache_list_expired(filter_by: list[str] = []) -> list[Path]:
         # "../../foo/bar") but we don't want to search the entire cache
         return []

-    logger.debug("Filtering by paths: %s", filter_by_paths)
+    trace("Filtering by paths: %s", filter_by_paths)
     for dirpath, _dirnames, filenames in os.walk(cache_path()):
         if filter_by_paths and Path(dirpath) not in filter_by_paths:
-            logger.debug("Skipping path %s", dirpath)
+            trace("Skipping path %s", dirpath)
            continue

-        logger.debug("Checking dirpath %s", dirpath)
+        trace("Checking dirpath %s", dirpath)
         for filename in filenames:
             path = Path(dirpath) / filename
-            logger.debug("Checking path %s", path)
+            trace("Checking path %s", path)
             try:
                 with open(path, "rb") as f:
                     expiry, _cache_entry = pickle.load(f)
                     if _is_expired(expiry):
-                        logger.debug("Expired cache entry found: %s (%s)", path, expiry)
+                        trace("Expired cache entry found: %s (%s)", path, expiry)
                         expired_cache_entries.append(path)
             except Exception as e:
-                logger.debug("Failed to load cached item %s: %s", path, e)
+                trace("Failed to load cached item %s: %s", path, e)
                 continue

     return expired_cache_entries
@@ -389,8 +395,8 @@ def cache_prune(files: list[Path] = []) -> None:
             with open(file, "rb") as f:
                 expiry, _cache_entry = pickle.load(f)
                 if _is_expired(expiry):
-                    logger.debug("Pruning expired cache: %s", file)
+                    trace("Pruning expired cache: %s", file)
                     file.unlink(missing_ok=True)
         except Exception as e:
-            logger.debug("Failed to prune cache %s: %s", file, e)
+            trace("Failed to prune cache %s: %s", file, e)
            continue
inspect_ai/model/_call_tools.py CHANGED
@@ -1,6 +1,7 @@
 import asyncio
 import inspect
 from dataclasses import is_dataclass
+from logging import getLogger
 from textwrap import dedent
 from typing import (
     Any,
@@ -19,7 +20,9 @@ from jsonschema import Draft7Validator
 from pydantic import BaseModel

 from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.format import format_function_call
 from inspect_ai._util.text import truncate_string_to_bytes
+from inspect_ai._util.trace import trace_action
 from inspect_ai.model._trace import trace_tool_mesage
 from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
 from inspect_ai.tool._tool import (
@@ -35,6 +38,8 @@ from inspect_ai.util import OutputLimitExceededError
 from ._chat_message import ChatMessageAssistant, ChatMessageTool
 from ._generate_config import active_generate_config

+logger = getLogger(__name__)
+

 async def call_tools(
     message: ChatMessageAssistant,
@@ -215,7 +220,10 @@ async def call_tool(tools: list[ToolDef], message: str, call: ToolCall) -> Any:
     arguments = tool_params(call.arguments, tool_def.tool)

     # call the tool
-    result = await tool_def.tool(**arguments)
+    with trace_action(
+        logger, "Tool Call", format_function_call(tool_def.name, arguments, width=1000)
+    ):
+        result = await tool_def.tool(**arguments)

     # return result
     return result
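With this change every tool invocation emits paired enter/exit trace records whose detail is the rendered call. A hedged sketch of the equivalent wiring outside the package, assuming `format_function_call` renders roughly like a Python call expression (the exact rendering isn't shown in this diff):

```python
import logging

from inspect_ai._util.format import format_function_call
from inspect_ai._util.trace import trace_action

logger = logging.getLogger(__name__)

# Mirrors the wrapped call site above: the detail string is the rendered
# call, e.g. something like bash(cmd='ls -la') (exact format assumed).
arguments = {"cmd": "ls -la"}
with trace_action(
    logger, "Tool Call", format_function_call("bash", arguments, width=1000)
):
    pass  # the real site awaits tool_def.tool(**arguments) here
```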