inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (103)
  1. inspect_ai/_cli/eval.py +14 -3
  2. inspect_ai/_cli/sandbox.py +3 -3
  3. inspect_ai/_cli/score.py +6 -4
  4. inspect_ai/_cli/trace.py +53 -6
  5. inspect_ai/_display/core/config.py +1 -1
  6. inspect_ai/_display/core/display.py +2 -1
  7. inspect_ai/_display/core/footer.py +6 -6
  8. inspect_ai/_display/plain/display.py +11 -6
  9. inspect_ai/_display/rich/display.py +23 -13
  10. inspect_ai/_display/textual/app.py +10 -9
  11. inspect_ai/_display/textual/display.py +2 -2
  12. inspect_ai/_display/textual/widgets/footer.py +4 -0
  13. inspect_ai/_display/textual/widgets/samples.py +14 -5
  14. inspect_ai/_eval/context.py +1 -2
  15. inspect_ai/_eval/eval.py +54 -41
  16. inspect_ai/_eval/loader.py +9 -2
  17. inspect_ai/_eval/run.py +148 -81
  18. inspect_ai/_eval/score.py +13 -8
  19. inspect_ai/_eval/task/images.py +31 -21
  20. inspect_ai/_eval/task/run.py +62 -59
  21. inspect_ai/_eval/task/rundir.py +16 -9
  22. inspect_ai/_eval/task/sandbox.py +7 -8
  23. inspect_ai/_eval/task/util.py +7 -0
  24. inspect_ai/_util/_async.py +118 -10
  25. inspect_ai/_util/constants.py +0 -2
  26. inspect_ai/_util/file.py +15 -29
  27. inspect_ai/_util/future.py +37 -0
  28. inspect_ai/_util/http.py +3 -99
  29. inspect_ai/_util/httpx.py +60 -0
  30. inspect_ai/_util/interrupt.py +2 -2
  31. inspect_ai/_util/json.py +5 -52
  32. inspect_ai/_util/logger.py +30 -86
  33. inspect_ai/_util/retry.py +10 -61
  34. inspect_ai/_util/trace.py +2 -2
  35. inspect_ai/_view/server.py +86 -3
  36. inspect_ai/_view/www/dist/assets/index.js +25837 -13269
  37. inspect_ai/_view/www/log-schema.json +253 -186
  38. inspect_ai/_view/www/package.json +2 -2
  39. inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
  40. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
  41. inspect_ai/_view/www/src/types/log.d.ts +122 -94
  42. inspect_ai/approval/_human/manager.py +6 -10
  43. inspect_ai/approval/_human/panel.py +2 -2
  44. inspect_ai/dataset/_sources/util.py +7 -6
  45. inspect_ai/log/__init__.py +4 -0
  46. inspect_ai/log/_file.py +35 -61
  47. inspect_ai/log/_log.py +18 -1
  48. inspect_ai/log/_recorders/eval.py +14 -23
  49. inspect_ai/log/_recorders/json.py +3 -18
  50. inspect_ai/log/_samples.py +27 -2
  51. inspect_ai/log/_transcript.py +8 -8
  52. inspect_ai/model/__init__.py +2 -1
  53. inspect_ai/model/_call_tools.py +60 -40
  54. inspect_ai/model/_chat_message.py +3 -2
  55. inspect_ai/model/_generate_config.py +25 -0
  56. inspect_ai/model/_model.py +74 -36
  57. inspect_ai/model/_openai.py +9 -1
  58. inspect_ai/model/_providers/anthropic.py +24 -26
  59. inspect_ai/model/_providers/azureai.py +11 -9
  60. inspect_ai/model/_providers/bedrock.py +33 -24
  61. inspect_ai/model/_providers/cloudflare.py +8 -9
  62. inspect_ai/model/_providers/goodfire.py +7 -3
  63. inspect_ai/model/_providers/google.py +47 -13
  64. inspect_ai/model/_providers/groq.py +15 -15
  65. inspect_ai/model/_providers/hf.py +24 -17
  66. inspect_ai/model/_providers/mistral.py +36 -20
  67. inspect_ai/model/_providers/openai.py +30 -25
  68. inspect_ai/model/_providers/openai_o1.py +1 -1
  69. inspect_ai/model/_providers/providers.py +1 -1
  70. inspect_ai/model/_providers/together.py +3 -4
  71. inspect_ai/model/_providers/util/__init__.py +2 -2
  72. inspect_ai/model/_providers/util/chatapi.py +6 -19
  73. inspect_ai/model/_providers/util/hooks.py +165 -0
  74. inspect_ai/model/_providers/vertex.py +20 -3
  75. inspect_ai/model/_providers/vllm.py +16 -19
  76. inspect_ai/scorer/_multi.py +5 -2
  77. inspect_ai/solver/_bridge/patch.py +31 -1
  78. inspect_ai/solver/_fork.py +5 -3
  79. inspect_ai/solver/_human_agent/agent.py +3 -2
  80. inspect_ai/tool/__init__.py +8 -2
  81. inspect_ai/tool/_tool_info.py +4 -90
  82. inspect_ai/tool/_tool_params.py +4 -34
  83. inspect_ai/tool/_tools/_web_search.py +30 -24
  84. inspect_ai/util/__init__.py +4 -0
  85. inspect_ai/util/_concurrency.py +5 -6
  86. inspect_ai/util/_display.py +6 -0
  87. inspect_ai/util/_json.py +170 -0
  88. inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
  89. inspect_ai/util/_sandbox/docker/docker.py +5 -0
  90. inspect_ai/util/_sandbox/environment.py +56 -9
  91. inspect_ai/util/_sandbox/service.py +12 -5
  92. inspect_ai/util/_subprocess.py +94 -113
  93. inspect_ai/util/_subtask.py +2 -4
  94. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
  95. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
  96. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
  97. inspect_ai/_util/timeouts.py +0 -160
  98. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  99. inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
  100. inspect_ai/model/_providers/util/tracker.py +0 -92
  101. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
  102. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
  103. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/_util/http.py CHANGED
@@ -1,99 +1,3 @@
-import glob
-import json
-import os
-import posixpath
-from http import HTTPStatus
-from http.server import SimpleHTTPRequestHandler
-from io import BytesIO
-from typing import Any
-from urllib.parse import parse_qs, urlparse
-
-from .dev import is_dev_mode
-
-
-class InspectHTTPRequestHandler(SimpleHTTPRequestHandler):
-    def __init__(self, *args: Any, directory: str, **kwargs: Any) -> None:
-        # note whether we are in dev mode (i.e. developing the package)
-        self.dev_mode = is_dev_mode()
-
-        # initialize file serving directory
-        directory = os.path.abspath(directory)
-        super().__init__(*args, directory=directory, **kwargs)
-
-    def do_GET(self) -> None:
-        if self.path.startswith("/api/events"):
-            self.handle_events()
-        else:
-            super().do_GET()
-
-    def handle_events(self) -> None:
-        """Client polls for events (e.g. dev reload) ~ every 1 second."""
-        query = parse_qs(urlparse(self.path).query)
-        params = dict(zip(query.keys(), [value[0] for value in query.values()]))
-        self.send_json(json.dumps(self.events_response(params)))
-
-    def events_response(self, params: dict[str, str]) -> list[str]:
-        """Send back a 'reload' event if we have modified source files."""
-        loaded_time = params.get("loaded_time", None)
-        return (
-            ["reload"] if loaded_time and self.should_reload(int(loaded_time)) else []
-        )
-
-    def translate_path(self, path: str) -> str:
-        """Ensure that paths don't escape self.directory."""
-        translated = super().translate_path(path)
-        if not os.path.abspath(translated).startswith(self.directory):
-            return self.directory
-        else:
-            return translated
-
-    def send_json(self, json: str | bytes) -> None:
-        if isinstance(json, str):
-            json = json.encode()
-        self.send_response(HTTPStatus.OK)
-        self.send_header("Content-type", "application/json")
-        self.end_headers()
-        self.copyfile(BytesIO(json), self.wfile)  # type: ignore
-
-    def send_response(self, code: int, message: str | None = None) -> None:
-        """No client side or proxy caches."""
-        super().send_response(code, message)
-        self.send_header("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
-        self.send_header("Pragma", "no-cache")
-        self.send_header(
-            "Cache-Control", "no-cache, no-store, max-age=0, must-revalidate"
-        )
-
-    def guess_type(self, path: str | os.PathLike[str]) -> str:
-        _, ext = posixpath.splitext(path)
-        if not ext or ext == ".mjs" or ext == ".js":
-            return "application/javascript"
-        elif ext == ".md":
-            return "text/markdown"
-        else:
-            return super().guess_type(path)
-
-    def log_error(self, format: str, *args: Any) -> None:
-        if self.dev_mode:
-            super().log_error(format, *args)
-
-    def log_request(self, code: int | str = "-", size: int | str = "-") -> None:
-        """Don't log status 200 or 404 (too chatty)."""
-        if code not in [200, 404]:
-            super().log_request(code, size)
-
-    def should_reload(self, loaded_time: int) -> bool:
-        if self.dev_mode:
-            for dir in self.reload_dirs():
-                files = [
-                    os.stat(file).st_mtime
-                    for file in glob.glob(f"{dir}/**/*", recursive=True)
-                ]
-                last_modified = max(files) * 1000
-                if last_modified > loaded_time:
-                    return True
-
-        return False
-
-    def reload_dirs(self) -> list[str]:
-        return [self.directory]
+# see https://cloud.google.com/storage/docs/retry-strategy
+def is_retryable_http_status(status_code: int) -> bool:
+    return status_code in [408, 429] or (500 <= status_code < 600)
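Usage note: the module now reduces to a single status-code predicate. A minimal sketch of how a caller might apply it, assuming a hypothetical `fetch_with_retry` wrapper (not part of the package):

import time

import httpx

from inspect_ai._util.http import is_retryable_http_status


def fetch_with_retry(url: str, max_attempts: int = 3) -> httpx.Response:
    # hypothetical helper: retry on 408/429/5xx per the Google Cloud
    # retry-strategy guidance referenced in the module comment
    for attempt in range(max_attempts):
        response = httpx.get(url)
        if not is_retryable_http_status(response.status_code):
            break
        time.sleep(2**attempt)  # simple exponential backoff
    return response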
inspect_ai/_util/httpx.py ADDED
@@ -0,0 +1,60 @@
+import logging
+from typing import Callable
+
+from httpx import ConnectError, ConnectTimeout, HTTPStatusError, ReadTimeout
+from tenacity import RetryCallState
+
+from inspect_ai._util.constants import HTTP
+
+logger = logging.getLogger(__name__)
+
+
+def httpx_should_retry(ex: BaseException) -> bool:
+    """Check whether an exception raised from httpx should be retried.
+
+    Implements the strategy described here: https://cloud.google.com/storage/docs/retry-strategy
+
+    Args:
+       ex (BaseException): Exception to examine for retry behavior
+
+    Returns:
+       True if a retry should occur
+    """
+    # httpx status exception
+    if isinstance(ex, HTTPStatusError):
+        # request timeout
+        if ex.response.status_code == 408:
+            return True
+        # lock timeout
+        elif ex.response.status_code == 409:
+            return True
+        # rate limit
+        elif ex.response.status_code == 429:
+            return True
+        # internal errors
+        elif ex.response.status_code >= 500:
+            return True
+        else:
+            return False
+
+    # connection error
+    elif is_httpx_connection_error(ex):
+        return True
+
+    # don't retry
+    else:
+        return False
+
+
+def log_httpx_retry_attempt(context: str) -> Callable[[RetryCallState], None]:
+    def log_attempt(retry_state: RetryCallState) -> None:
+        logger.log(
+            HTTP,
+            f"{context} connection retry {retry_state.attempt_number} after waiting for {retry_state.idle_for}",
+        )
+
+    return log_attempt
+
+
+def is_httpx_connection_error(ex: BaseException) -> bool:
+    return isinstance(ex, ConnectTimeout | ConnectError | ConnectionError | ReadTimeout)
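These helpers are designed to plug into tenacity retry loops. A hedged sketch of the intended wiring (the `call_api` coroutine, retry bounds, and context string are illustrative assumptions, not part of the package):

import httpx
from tenacity import (
    AsyncRetrying,
    retry_if_exception,
    stop_after_attempt,
    wait_exponential_jitter,
)

from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt


async def call_api(client: httpx.AsyncClient, url: str) -> httpx.Response:
    # hypothetical request; raise_for_status() raises HTTPStatusError on 4xx/5xx
    response = await client.get(url)
    response.raise_for_status()
    return response


async def call_api_with_retry(client: httpx.AsyncClient, url: str) -> httpx.Response:
    # retry only on exceptions httpx_should_retry approves, logging each
    # attempt at the HTTP log level via log_httpx_retry_attempt
    async for attempt in AsyncRetrying(
        retry=retry_if_exception(httpx_should_retry),
        wait=wait_exponential_jitter(),
        stop=stop_after_attempt(5),
        before_sleep=log_httpx_retry_attempt(url),
    ):
        with attempt:
            return await call_api(client, url)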
inspect_ai/_util/interrupt.py CHANGED
@@ -1,4 +1,4 @@
-import asyncio
+import anyio
 
 from .working import check_sample_working_limit
 
@@ -9,7 +9,7 @@ def check_sample_interrupt() -> None:
     # check for user interrupt
     sample = sample_active()
    if sample and sample.interrupt_action:
-        raise asyncio.CancelledError()
+        raise anyio.get_cancelled_exc_class()
 
     # check for working_limit
     check_sample_working_limit()
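`anyio.get_cancelled_exc_class()` returns the cancellation exception class of the running backend (`asyncio.CancelledError` under asyncio, `trio.Cancelled` under Trio), so the raise works regardless of event loop. A minimal sketch of the standard anyio pattern for handling such cancellation:

import anyio


async def run_step() -> None:
    try:
        await anyio.sleep(10)  # stand-in for real work
    except anyio.get_cancelled_exc_class():
        # perform (synchronous) cleanup here, then re-raise so the
        # cancellation continues to propagate
        raise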
inspect_ai/_util/json.py CHANGED
@@ -1,13 +1,13 @@
-from typing import Any, Literal, cast
+from typing import (
+    Any,
+    Literal,
+    cast,
+)
 
 import jsonpatch
 from pydantic import BaseModel, Field, JsonValue
 from pydantic_core import to_jsonable_python
 
-JSONType = Literal["string", "integer", "number", "boolean", "array", "object", "null"]
-
-PythonType = Literal["str", "int", "float", "bool", "list", "dict", "None"]
-
 
 def jsonable_python(x: Any) -> Any:
     return to_jsonable_python(x, exclude_none=True, fallback=lambda _x: None)
@@ -23,53 +23,6 @@ def jsonable_dict(x: Any) -> dict[str, JsonValue]:
     )
 
 
-def python_type_to_json_type(python_type: str | None) -> JSONType:
-    match python_type:
-        case "str":
-            return "string"
-        case "int":
-            return "integer"
-        case "float":
-            return "number"
-        case "bool":
-            return "boolean"
-        case "list":
-            return "array"
-        case "dict":
-            return "object"
-        case "None":
-            return "null"
-        # treat 'unknown' as string as anything can be converted to string
-        case None:
-            return "string"
-        case _:
-            raise ValueError(
-                f"Unsupported type: {python_type} for Python to JSON conversion."
-            )
-
-
-def json_type_to_python_type(json_type: str) -> PythonType:
-    match json_type:
-        case "string":
-            return "str"
-        case "integer":
-            return "int"
-        case "number":
-            return "float"
-        case "boolean":
-            return "bool"
-        case "array":
-            return "list"
-        case "object":
-            return "dict"
-        case "null":
-            return "None"
-        case _:
-            raise ValueError(
-                f"Unsupported type: {json_type} for JSON to Python conversion."
-            )
-
-
 class JsonChange(BaseModel):
     """Describes a change to data using JSON Patch format."""
 
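The type-mapping helpers are gone from this module; what remains, `JsonChange`, models entries in JSON Patch format via the jsonpatch library imported above. For reference, a minimal jsonpatch round-trip (the document and patch values are illustrative):

import jsonpatch

# one JSON Patch operation, the same shape that JsonChange describes
patch = jsonpatch.JsonPatch([{"op": "replace", "path": "/status", "value": "done"}])
print(patch.apply({"status": "running"}))  # -> {'status': 'done'}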
inspect_ai/_util/logger.py CHANGED
@@ -1,8 +1,6 @@
 import atexit
 import os
-import re
 from logging import (
-    DEBUG,
     INFO,
     WARNING,
     FileHandler,
@@ -44,10 +42,12 @@ TRACE_FILE_NAME = "trace.log"
 
 # log handler that filters messages to stderr and the log file
 class LogHandler(RichHandler):
-    def __init__(self, levelno: int, transcript_levelno: int) -> None:
-        super().__init__(levelno, console=rich.get_console())
+    def __init__(
+        self, capture_levelno: int, display_levelno: int, transcript_levelno: int
+    ) -> None:
+        super().__init__(capture_levelno, console=rich.get_console())
         self.transcript_levelno = transcript_levelno
-        self.display_level = WARNING
+        self.display_level = display_levelno
         # log into an external file if requested via env var
         file_logger = os.environ.get("INSPECT_PY_LOGGER_FILE", None)
         self.file_logger = FileHandler(file_logger) if file_logger else None
@@ -77,23 +77,6 @@ class LogHandler(RichHandler):
 
     @override
     def emit(self, record: LogRecord) -> None:
-        # demote httpx and return notifications to log_level http
-        if (
-            record.name == "httpx"
-            or "http" in record.name
-            or "Retrying request" in record.getMessage()
-        ):
-            record.levelno = HTTP
-            record.levelname = HTTP_LOG_LEVEL
-
-        # skip httpx event loop is closed errors
-        if "Event loop is closed" in record.getMessage():
-            return
-
-        # skip google-genai AFC message
-        if "AFC is enabled with max remote calls" in record.getMessage():
-            return
-
         # write to stderr if we are at or above the threshold
         if record.levelno >= self.display_level:
             super().emit(record)
@@ -110,10 +93,9 @@ class LogHandler(RichHandler):
         if self.trace_logger and record.levelno >= self.trace_logger_level:
             self.trace_logger.emit(record)
 
-        # eval log always gets info level and higher records
-        # eval log only gets debug or http if we opt-in
-        write = record.levelno >= self.transcript_levelno
-        notify_logger_record(record, write)
+        # eval log gets transcript level or higher
+        if record.levelno >= self.transcript_levelno:
+            log_to_transcript(record)
 
     @override
     def render_message(self, record: LogRecord, message: str) -> ConsoleRenderable:
@@ -122,9 +104,7 @@ class LogHandler(RichHandler):
 
 # initialize logging -- this function can be called multiple times
 # in the lifetime of the process (the levelno will update globally)
-def init_logger(
-    log_level: str | None = None, log_level_transcript: str | None = None
-) -> None:
+def init_logger(log_level: str | None, log_level_transcript: str | None = None) -> None:
     # backwards compatibility for 'tools'
     if log_level == "sandbox" or log_level == "tools":
         log_level = "trace"
@@ -146,7 +126,7 @@ def init_logger(
     ).upper()
     validate_level("log level", log_level)
 
-    # reolve log file level
+    # reolve transcript log level
     log_level_transcript = (
         log_level_transcript
         if log_level_transcript
@@ -158,76 +138,40 @@ def init_logger(
     levelno = getLevelName(log_level)
     transcript_levelno = getLevelName(log_level_transcript)
 
+    # set capture level for our logs (we won't actually display/write all of them)
+    capture_level = min(TRACE, levelno, transcript_levelno)
+
     # init logging handler on demand
     global _logHandler
-    removed_root_handlers = False
     if not _logHandler:
-        removed_root_handlers = remove_non_pytest_root_logger_handlers()
-        _logHandler = LogHandler(min(DEBUG, levelno), transcript_levelno)
-        getLogger().addHandler(_logHandler)
-
-    # establish default capture level
-    capture_level = min(TRACE, levelno, transcript_levelno)
-
-    # see all the messages (we won't actually display/write all of them)
-    getLogger().setLevel(capture_level)
-    getLogger(PKG_NAME).setLevel(capture_level)
-    getLogger("httpx").setLevel(capture_level)
-    getLogger("botocore").setLevel(DEBUG)
-
-    if removed_root_handlers:
-        getLogger(PKG_NAME).warning(
-            "Inspect removed pre-existing root logger handlers and replaced them with its own handler."
+        _logHandler = LogHandler(
+            capture_levelno=capture_level,
+            display_levelno=levelno,
+            transcript_levelno=transcript_levelno,
         )
 
-    # set the levelno on the global handler
-    _logHandler.display_level = levelno
+    # set the log level for our package
+    getLogger(PKG_NAME).setLevel(capture_level)
+    getLogger(PKG_NAME).addHandler(_logHandler)
+    getLogger(PKG_NAME).propagate = False
 
+    # add our logger to the global handlers
+    getLogger().addHandler(_logHandler)
 
-_logHandler: LogHandler | None = None
+    # httpx currently logs all requests at the INFO level
+    # this is a bit aggressive and we already do this at
+    # our own HTTP level
+    getLogger("httpx").setLevel(WARNING)
 
 
-def remove_non_pytest_root_logger_handlers() -> bool:
-    root_logger = getLogger()
-    non_pytest_handlers = [
-        handler
-        for handler in root_logger.handlers
-        if handler.__module__ != "_pytest.logging"
-    ]
-    for handler in non_pytest_handlers:
-        root_logger.removeHandler(handler)
-    return len(non_pytest_handlers) > 0
+_logHandler: LogHandler | None = None
 
 
-def notify_logger_record(record: LogRecord, write: bool) -> None:
+def log_to_transcript(record: LogRecord) -> None:
     from inspect_ai.log._message import LoggingMessage
     from inspect_ai.log._transcript import LoggerEvent, transcript
 
-    if write:
-        transcript()._event(
-            LoggerEvent(message=LoggingMessage._from_log_record(record))
-        )
-    global _rate_limit_count
-    if (record.levelno <= INFO and re.search(r"\b429\b", record.getMessage())) or (
-        record.levelno == DEBUG
-        # See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#validating-retry-attempts
-        # for boto retry logic / log messages (this is tracking standard or adapative retries)
-        and "botocore.retries.standard" in record.name
-        and "Retry needed, retrying request after delay of:" in record.getMessage()
-    ):
-        _rate_limit_count = _rate_limit_count + 1
-
-
-_rate_limit_count = 0
-
-
-def init_http_rate_limit_count() -> None:
-    global _rate_limit_count
-    _rate_limit_count = 0
-
-
-def http_rate_limit_count() -> int:
-    return _rate_limit_count
+    transcript()._event(LoggerEvent(message=LoggingMessage._from_log_record(record)))
 
 
 def warn_once(logger: Logger, message: str) -> None:
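The refactor distinguishes three thresholds: a capture level (what the handler receives at all), a display level (what is echoed to stderr), and a transcript level (what is written to the eval log), and it now attaches the handler to the package logger with propagation disabled. A minimal stdlib-only sketch of that wiring pattern (names are illustrative):

import logging

# package logger gets its own handler and stops propagating, so
# root-level handlers don't double-log its records
pkg_logger = logging.getLogger("inspect_ai")
pkg_logger.setLevel(logging.DEBUG)  # capture level: see everything
pkg_logger.addHandler(logging.StreamHandler())
pkg_logger.propagate = False

# httpx logs every request at INFO; raise its bar to keep output quiet
logging.getLogger("httpx").setLevel(logging.WARNING)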
inspect_ai/_util/retry.py CHANGED
@@ -1,67 +1,16 @@
-import logging
-from typing import Callable
+_http_retries_count: int = 0
 
-from httpx import ConnectError, ConnectTimeout, HTTPStatusError, ReadTimeout
-from tenacity import RetryCallState
 
-from inspect_ai._util.constants import HTTP
+def report_http_retry() -> None:
+    from inspect_ai.log._samples import report_active_sample_retry
 
-logger = logging.getLogger(__name__)
+    # bump global counter
+    global _http_retries_count
+    _http_retries_count = _http_retries_count + 1
 
+    # report sample retry
+    report_active_sample_retry()
 
-def httpx_should_retry(ex: BaseException) -> bool:
-    """Check whether an exception raised from httpx should be retried.
 
-    Implements the strategy described here: https://cloud.google.com/storage/docs/retry-strategy
-
-    Args:
-       ex (BaseException): Exception to examine for retry behavior
-
-    Returns:
-       True if a retry should occur
-    """
-    # httpx status exception
-    if isinstance(ex, HTTPStatusError):
-        # request timeout
-        if ex.response.status_code == 408:
-            return True
-        # lock timeout
-        elif ex.response.status_code == 409:
-            return True
-        # rate limit
-        elif ex.response.status_code == 429:
-            return True
-        # internal errors
-        elif ex.response.status_code >= 500:
-            return True
-        else:
-            return False
-
-    # connection error
-    elif is_httpx_connection_error(ex):
-        return True
-
-    # don't retry
-    else:
-        return False
-
-
-def log_rate_limit_retry(context: str, retry_state: RetryCallState) -> None:
-    logger.log(
-        HTTP,
-        f"{context} rate limit retry {retry_state.attempt_number} after waiting for {retry_state.idle_for}",
-    )
-
-
-def log_retry_attempt(context: str) -> Callable[[RetryCallState], None]:
-    def log_attempt(retry_state: RetryCallState) -> None:
-        logger.log(
-            HTTP,
-            f"{context} connection retry {retry_state.attempt_number} after waiting for {retry_state.idle_for}",
-        )
-
-    return log_attempt
-
-
-def is_httpx_connection_error(ex: BaseException) -> bool:
-    return isinstance(ex, ConnectTimeout | ConnectError | ConnectionError | ReadTimeout)
+def http_retries_count() -> int:
+    return _http_retries_count
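The httpx helpers have moved to the new inspect_ai/_util/httpx.py; this module now only counts retries and surfaces them to the active sample. A hedged sketch of how a provider might report a retry from a tenacity hook (the hook wiring is an assumption for illustration):

from tenacity import RetryCallState

from inspect_ai._util.retry import report_http_retry


def before_sleep(retry_state: RetryCallState) -> None:
    # bump the global retry counter and notify the active sample's display
    report_http_retry()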
inspect_ai/_util/trace.py CHANGED
@@ -1,4 +1,3 @@
-import asyncio
 import datetime
 import gzip
 import json
@@ -13,6 +12,7 @@ from logging import FileHandler, Logger
 from pathlib import Path
 from typing import Any, Callable, Generator, Literal, TextIO
 
+import anyio
 import jsonlines
 from pydantic import BaseModel, Field, JsonValue
 from shortuuid import uuid
@@ -83,7 +83,7 @@ def trace_action(
                 "duration": duration,
             },
         )
-    except (KeyboardInterrupt, asyncio.CancelledError):
+    except (KeyboardInterrupt, anyio.get_cancelled_exc_class()):
         duration = time.monotonic() - start_monotonic
         logger.log(
             TRACE,
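For context, `trace_action` is a context manager that logs enter, exit, cancel, and error records for the enclosed block at TRACE level, now catching the backend-appropriate cancellation exception. A hedged usage sketch (the action name and traced work are illustrative):

import logging

from inspect_ai._util.trace import trace_action

logger = logging.getLogger(__name__)

# times the block and emits TRACE records for its outcome
with trace_action(logger, "Example", "connecting to server"):
    ...  # the traced work would go here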
inspect_ai/_view/server.py CHANGED
@@ -1,24 +1,27 @@
 import asyncio
+import contextlib
 import logging
 import os
 import urllib.parse
 from logging import LogRecord, getLogger
 from pathlib import Path
-from typing import Any, Awaitable, Callable
+from typing import Any, AsyncIterator, Awaitable, Callable, Literal, cast
 
 import fsspec  # type: ignore
 from aiohttp import web
 from fsspec.asyn import AsyncFileSystem  # type: ignore
 from fsspec.core import split_protocol  # type: ignore
 from pydantic_core import to_jsonable_python
+from s3fs import S3FileSystem  # type: ignore
 
 from inspect_ai._display import display
 from inspect_ai._util.constants import DEFAULT_SERVER_HOST, DEFAULT_VIEW_PORT
-from inspect_ai._util.file import filesystem, size_in_mb
+from inspect_ai._util.file import default_fs_options, filesystem, size_in_mb
 from inspect_ai.log._file import (
     EvalLogInfo,
     eval_log_json,
-    list_eval_logs_async,
+    list_eval_logs,
+    log_files_from_ls,
     read_eval_log_async,
     read_eval_log_headers_async,
 )
@@ -297,6 +300,62 @@ def resolve_header_only(path: str, header_only: int | None) -> bool:
         return False
 
 
+async def list_eval_logs_async(
+    log_dir: str = os.environ.get("INSPECT_LOG_DIR", "./logs"),
+    formats: list[Literal["eval", "json"]] | None = None,
+    recursive: bool = True,
+    descending: bool = True,
+    fs_options: dict[str, Any] = {},
+) -> list[EvalLogInfo]:
+    """List all eval logs in a directory.
+
+    Will be async for filesystem providers that support async (e.g. s3, gcs, etc.)
+    otherwise will fallback to sync implementation.
+
+    Args:
+      log_dir (str): Log directory (defaults to INSPECT_LOG_DIR)
+      formats (Literal["eval", "json"]): Formats to list (default
+        to listing all formats)
+      recursive (bool): List log files recursively (defaults to True).
+      descending (bool): List in descending order.
+      fs_options (dict[str, Any]): Optional. Additional arguments to pass through
+        to the filesystem provider (e.g. `S3FileSystem`).
+
+    Returns:
+       List of EvalLog Info.
+    """
+    # async filesystem if we can
+    fs = filesystem(log_dir, fs_options)
+    if fs.is_async():
+        async with async_fileystem(log_dir, fs_options=fs_options) as async_fs:
+            if await async_fs._exists(log_dir):
+                # prevent caching of listings
+                async_fs.invalidate_cache(log_dir)
+                # list logs
+                if recursive:
+                    files: list[dict[str, Any]] = []
+                    async for _, _, filenames in async_fs._walk(log_dir, detail=True):
+                        files.extend(filenames.values())
+                else:
+                    files = cast(
+                        list[dict[str, Any]],
+                        await async_fs._ls(log_dir, detail=True),
+                    )
+                logs = [fs._file_info(file) for file in files]
+                # resolve to eval logs
+                return log_files_from_ls(logs, formats, descending)
+            else:
+                return []
+    else:
+        return list_eval_logs(
+            log_dir=log_dir,
+            formats=formats,
+            recursive=recursive,
+            descending=descending,
+            fs_options=fs_options,
+        )
+
+
 def filter_aiohttp_log() -> None:
     # filter overly chatty /api/events messages
     class RequestFilter(logging.Filter):
@@ -329,3 +388,27 @@ def async_connection(log_file: str) -> AsyncFileSystem:
 
     # return async file-system
     return _async_connections.get(protocol)
+
+
+@contextlib.asynccontextmanager
+async def async_fileystem(
+    location: str, fs_options: dict[str, Any] = {}
+) -> AsyncIterator[AsyncFileSystem]:
+    # determine protocol
+    protocol, _ = split_protocol(location)
+    protocol = protocol or "file"
+
+    # build options
+    options = default_fs_options(location)
+    options.update(fs_options)
+
+    if protocol == "s3":
+        s3 = S3FileSystem(asynchronous=True, **options)
+        session = await s3.set_session()
+        try:
+            yield s3
+        finally:
+            await session.close()
+    else:
+        options.update({"asynchronous": True, "loop": asyncio.get_event_loop()})
+        yield fsspec.filesystem(protocol, **options)
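A hedged usage sketch of the relocated `list_eval_logs_async` (the S3 bucket URL is illustrative): on an async-capable filesystem like S3 the listing runs without blocking the event loop, otherwise it falls back to the sync `list_eval_logs`.

import anyio

from inspect_ai._view.server import list_eval_logs_async


async def main() -> None:
    # lists .eval/.json log files, in descending order by default
    logs = await list_eval_logs_async("s3://my-eval-logs")
    for log in logs:
        print(log.name)


anyio.run(main)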