inspect_ai-0.3.49-py3-none-any.whl → inspect_ai-0.3.51-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. inspect_ai/_cli/info.py +2 -2
  2. inspect_ai/_cli/log.py +2 -2
  3. inspect_ai/_cli/score.py +2 -2
  4. inspect_ai/_display/core/display.py +19 -0
  5. inspect_ai/_display/core/panel.py +37 -7
  6. inspect_ai/_display/core/progress.py +29 -2
  7. inspect_ai/_display/core/results.py +79 -40
  8. inspect_ai/_display/core/textual.py +21 -0
  9. inspect_ai/_display/rich/display.py +28 -8
  10. inspect_ai/_display/textual/app.py +107 -1
  11. inspect_ai/_display/textual/display.py +1 -1
  12. inspect_ai/_display/textual/widgets/samples.py +132 -91
  13. inspect_ai/_display/textual/widgets/task_detail.py +236 -0
  14. inspect_ai/_display/textual/widgets/tasks.py +74 -6
  15. inspect_ai/_display/textual/widgets/toggle.py +32 -0
  16. inspect_ai/_eval/context.py +2 -0
  17. inspect_ai/_eval/eval.py +4 -3
  18. inspect_ai/_eval/loader.py +1 -1
  19. inspect_ai/_eval/run.py +35 -2
  20. inspect_ai/_eval/task/log.py +13 -11
  21. inspect_ai/_eval/task/results.py +12 -3
  22. inspect_ai/_eval/task/run.py +139 -36
  23. inspect_ai/_eval/task/sandbox.py +2 -1
  24. inspect_ai/_util/_async.py +30 -1
  25. inspect_ai/_util/file.py +31 -4
  26. inspect_ai/_util/html.py +3 -0
  27. inspect_ai/_util/logger.py +6 -5
  28. inspect_ai/_util/platform.py +5 -6
  29. inspect_ai/_util/registry.py +1 -1
  30. inspect_ai/_view/server.py +9 -9
  31. inspect_ai/_view/www/App.css +2 -2
  32. inspect_ai/_view/www/dist/assets/index.css +2 -2
  33. inspect_ai/_view/www/dist/assets/index.js +352 -294
  34. inspect_ai/_view/www/log-schema.json +13 -0
  35. inspect_ai/_view/www/package.json +1 -0
  36. inspect_ai/_view/www/src/components/MessageBand.mjs +1 -1
  37. inspect_ai/_view/www/src/components/Tools.mjs +16 -13
  38. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -3
  39. inspect_ai/_view/www/src/samples/SampleScoreView.mjs +52 -77
  40. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -13
  41. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +15 -2
  42. inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.mjs +4 -2
  43. inspect_ai/_view/www/src/types/log.d.ts +2 -0
  44. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +2 -0
  45. inspect_ai/_view/www/yarn.lock +9 -4
  46. inspect_ai/approval/__init__.py +1 -1
  47. inspect_ai/approval/_human/approver.py +35 -0
  48. inspect_ai/approval/_human/console.py +62 -0
  49. inspect_ai/approval/_human/manager.py +108 -0
  50. inspect_ai/approval/_human/panel.py +233 -0
  51. inspect_ai/approval/_human/util.py +51 -0
  52. inspect_ai/dataset/_sources/hf.py +2 -2
  53. inspect_ai/dataset/_sources/util.py +1 -1
  54. inspect_ai/log/_file.py +106 -36
  55. inspect_ai/log/_recorders/eval.py +226 -158
  56. inspect_ai/log/_recorders/file.py +9 -6
  57. inspect_ai/log/_recorders/json.py +35 -12
  58. inspect_ai/log/_recorders/recorder.py +15 -15
  59. inspect_ai/log/_samples.py +52 -0
  60. inspect_ai/model/_model.py +14 -0
  61. inspect_ai/model/_model_output.py +4 -0
  62. inspect_ai/model/_providers/azureai.py +1 -1
  63. inspect_ai/model/_providers/hf.py +106 -4
  64. inspect_ai/model/_providers/util/__init__.py +2 -0
  65. inspect_ai/model/_providers/util/hf_handler.py +200 -0
  66. inspect_ai/scorer/_common.py +1 -1
  67. inspect_ai/solver/_plan.py +0 -8
  68. inspect_ai/solver/_task_state.py +18 -1
  69. inspect_ai/solver/_use_tools.py +9 -1
  70. inspect_ai/tool/_tool_def.py +2 -2
  71. inspect_ai/tool/_tool_info.py +14 -2
  72. inspect_ai/tool/_tool_params.py +2 -1
  73. inspect_ai/tool/_tools/_execute.py +1 -1
  74. inspect_ai/tool/_tools/_web_browser/_web_browser.py +6 -0
  75. inspect_ai/util/__init__.py +5 -6
  76. inspect_ai/util/_panel.py +91 -0
  77. inspect_ai/util/_sandbox/__init__.py +2 -6
  78. inspect_ai/util/_sandbox/context.py +4 -3
  79. inspect_ai/util/_sandbox/docker/compose.py +12 -2
  80. inspect_ai/util/_sandbox/docker/docker.py +19 -9
  81. inspect_ai/util/_sandbox/docker/util.py +10 -2
  82. inspect_ai/util/_sandbox/environment.py +47 -41
  83. inspect_ai/util/_sandbox/local.py +15 -10
  84. inspect_ai/util/_subprocess.py +43 -3
  85. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/METADATA +2 -2
  86. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/RECORD +90 -82
  87. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  88. inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
  89. inspect_ai/approval/_human.py +0 -123
  90. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/LICENSE +0 -0
  91. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/WHEEL +0 -0
  92. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/entry_points.txt +0 -0
  93. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/top_level.txt +0 -0
inspect_ai/approval/_human/manager.py ADDED
@@ -0,0 +1,108 @@
+ import asyncio
+ import uuid
+ from asyncio import Future
+ from contextvars import ContextVar
+ from typing import Callable, Literal, NamedTuple, cast
+
+ from inspect_ai.solver._task_state import TaskState
+ from inspect_ai.tool._tool_call import ToolCall, ToolCallView
+
+ from .._approval import Approval, ApprovalDecision
+
+
+ class ApprovalRequest(NamedTuple):
+     message: str
+     call: ToolCall
+     view: ToolCallView
+     state: TaskState | None
+     choices: list[ApprovalDecision]
+
+
+ class PendingApprovalRequest(NamedTuple):
+     request: ApprovalRequest
+     task: str
+     model: str
+     id: int | str
+     epoch: int
+
+
+ class HumanApprovalManager:
+     def __init__(self) -> None:
+         self._approval_requests: dict[
+             str, tuple[PendingApprovalRequest, Future[Approval]]
+         ] = {}
+         self._change_callbacks: list[Callable[[Literal["add", "remove"]], None]] = []
+
+     def request_approval(self, request: ApprovalRequest) -> str:
+         from inspect_ai.log._samples import sample_active
+
+         id = str(uuid.uuid4())
+         future = cast(Future[Approval], asyncio.get_event_loop().create_future())
+         sample = sample_active()
+         assert sample
+         assert sample.sample.id
+         pending = PendingApprovalRequest(
+             request=request,
+             task=sample.task,
+             model=sample.model,
+             id=sample.sample.id,
+             epoch=sample.epoch,
+         )
+         self._approval_requests[id] = (pending, future)
+         self._notify_change("add")
+         return id
+
+     def withdraw_request(self, id: str) -> None:
+         del self._approval_requests[id]
+         self._notify_change("remove")
+
+     async def wait_for_approval(self, id: str) -> Approval:
+         _, future = self._approval_requests[id]
+         return await future
+
+     def on_change(
+         self, callback: Callable[[Literal["add", "remove"]], None]
+     ) -> Callable[[], None]:
+         self._change_callbacks.append(callback)
+
+         def unsubscribe() -> None:
+             if callback in self._change_callbacks:
+                 self._change_callbacks.remove(callback)
+
+         return unsubscribe
+
+     def approval_requests(self) -> list[tuple[str, PendingApprovalRequest]]:
+         return [(aid, data) for aid, (data, _) in self._approval_requests.items()]
+
+     def complete_approval(self, id: str, result: Approval) -> None:
+         if id in self._approval_requests:
+             _, future = self._approval_requests[id]
+             if not future.done():
+                 future.set_result(result)
+             del self._approval_requests[id]
+             self._notify_change("remove")
+
+     def fail_approval(self, id: str, error: Exception) -> None:
+         if id in self._approval_requests:
+             _, future = self._approval_requests[id]
+             if not future.done():
+                 future.set_exception(error)
+             del self._approval_requests[id]
+             self._notify_change("remove")
+
+     def _notify_change(self, action: Literal["add", "remove"]) -> None:
+         for callback in self._change_callbacks:
+             callback(action)
+
+
+ def human_approval_manager() -> HumanApprovalManager:
+     return _human_approval_manager.get()
+
+
+ def init_human_approval_manager() -> None:
+     _human_approval_manager.set(HumanApprovalManager())
+
+
+ _human_approval_manager: ContextVar[HumanApprovalManager] = ContextVar(
+     "_human_approval_manager"
+ )
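The new `HumanApprovalManager` hands each approval request off as an `asyncio.Future` that the requesting coroutine awaits and the UI later resolves. A minimal, self-contained sketch of that Future-based hand-off (a toy class, not the inspect_ai API):

```python
import asyncio
import uuid


class MiniApprovalManager:
    """Toy stand-in for HumanApprovalManager; names are illustrative."""

    def __init__(self) -> None:
        self._requests: dict[str, asyncio.Future[str]] = {}

    def request_approval(self) -> str:
        # register a Future that the requesting coroutine will await
        rid = str(uuid.uuid4())
        self._requests[rid] = asyncio.get_running_loop().create_future()
        return rid

    async def wait_for_approval(self, rid: str) -> str:
        return await self._requests[rid]

    def complete_approval(self, rid: str, decision: str) -> None:
        # resolving the Future wakes the awaiting requester
        future = self._requests.pop(rid)
        if not future.done():
            future.set_result(decision)


async def main() -> None:
    manager = MiniApprovalManager()
    rid = manager.request_approval()
    # a second actor (standing in for the approvals panel) resolves it
    asyncio.get_running_loop().call_later(
        0.1, manager.complete_approval, rid, "approve"
    )
    print(await manager.wait_for_approval(rid))  # -> approve


asyncio.run(main())
```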
inspect_ai/approval/_human/panel.py ADDED
@@ -0,0 +1,233 @@
+ from asyncio import CancelledError
+ from typing import Callable, Literal
+
+ from rich.console import RenderableType
+ from rich.text import Text
+ from textual.app import ComposeResult
+ from textual.containers import Horizontal, ScrollableContainer
+ from textual.reactive import reactive
+ from textual.widgets import Button, Static
+ from typing_extensions import override
+
+ from inspect_ai._util.registry import registry_unqualified_name
+ from inspect_ai.solver._task_state import TaskState
+ from inspect_ai.tool._tool_call import ToolCall, ToolCallView
+ from inspect_ai.util._panel import InputPanel, input_panel
+
+ from .._approval import Approval, ApprovalDecision
+ from .manager import ApprovalRequest, PendingApprovalRequest, human_approval_manager
+ from .util import (
+     HUMAN_APPROVED,
+     HUMAN_ESCALATED,
+     HUMAN_REJECTED,
+     HUMAN_TERMINATED,
+     render_tool_approval,
+ )
+
+ PANEL_TITLE = "Approvals"
+
+
+ async def panel_approval(
+     message: str,
+     call: ToolCall,
+     view: ToolCallView,
+     state: TaskState | None,
+     choices: list[ApprovalDecision],
+ ) -> Approval:
+     # ensure the approvals panel is shown
+     await input_panel(PANEL_TITLE, ApprovalInputPanel)
+
+     # submit to human approval manager (will be picked up by panel)
+     approvals = human_approval_manager()
+     id = approvals.request_approval(
+         ApprovalRequest(
+             message=message, call=call, view=view, state=state, choices=choices
+         )
+     )
+     try:
+         return await approvals.wait_for_approval(id)
+     except CancelledError:
+         approvals.withdraw_request(id)
+         raise
+
+
+ class ApprovalInputPanel(InputPanel):
+     DEFAULT_CSS = """
+     ApprovalInputPanel {
+         width: 1fr;
+         height: 1fr;
+         padding: 0 1 1 1;
+         layout: grid;
+         grid-size: 1 3;
+         grid-rows: auto 1fr auto;
+     }
+     """
+
+     _approvals: list[tuple[str, PendingApprovalRequest]] = []
+     _unsubscribe: Callable[[], None] | None = None
+
+     @override
+     def compose(self) -> ComposeResult:
+         yield ApprovalRequestHeading()
+         yield ApprovalRequestContent()
+         yield ApprovalRequestActions()
+
+     def on_mount(self) -> None:
+         self._unsubscribe = human_approval_manager().on_change(
+             self.on_approvals_changed
+         )
+
+     def on_unmount(self) -> None:
+         if self._unsubscribe is not None:
+             self._unsubscribe()
+
+     def on_approvals_changed(self, action: Literal["add", "remove"]) -> None:
+         heading = self.query_one(ApprovalRequestHeading)
+         content = self.query_one(ApprovalRequestContent)
+         actions = self.query_one(ApprovalRequestActions)
+         self._approvals = human_approval_manager().approval_requests()
+         if len(self._approvals) > 0:
+             approval_id, approval_request = self._approvals[0]
+             self.title = f"{PANEL_TITLE} ({len(self._approvals):,})"
+             heading.request = approval_request
+             content.approval = approval_request.request
+             actions.approval_request = approval_id, approval_request
+             if action == "add":
+                 self.activate()
+                 actions.activate()
+             self.visible = True
+         else:
+             self.title = PANEL_TITLE
+             heading.request = None
+             content.approval = None
+             actions.approval_request = None
+             self.deactivate()
+             self.visible = False
+
+
+ class ApprovalRequestHeading(Static):
+     DEFAULT_CSS = """
+     ApprovalRequestHeading {
+         width: 1fr;
+         background: $surface;
+         color: $secondary;
+         margin-left: 1;
+     }
+     """
+
+     request: reactive[PendingApprovalRequest | None] = reactive(None)
+
+     def render(self) -> RenderableType:
+         if self.request is not None:
+             return f"{registry_unqualified_name(self.request.task)} (id: {self.request.id}, epoch {self.request.epoch}): {self.request.model}"
+         else:
+             return ""
+
+
+ class ApprovalRequestContent(ScrollableContainer):
+     DEFAULT_CSS = """
+     ApprovalRequestContent {
+         scrollbar-size-vertical: 1;
+         scrollbar-gutter: stable;
+         border: solid $foreground 20%;
+         padding: 0 1 0 1;
+     }
+     """
+
+     approval: reactive[ApprovalRequest | None] = reactive(None)
+
+     async def watch_approval(self, approval: ApprovalRequest | None) -> None:
+         await self.remove_children()
+         if approval:
+             self.mount_all(
+                 Static(r) for r in render_tool_approval(approval.message, approval.view)
+             )
+             self.scroll_end(animate=False)
+
+
+ class ApprovalRequestActions(Horizontal):
+     APPROVE_TOOL_CALL = "approve-tool-call"
+     REJECT_TOOL_CALL = "reject-tool-call"
+     ESCALATE_TOOL_CALL = "escalate-tool-call"
+     TERMINATE_TOOL_CALL_SAMPLE = "terminate-tool-call-sample"
+
+     DEFAULT_CSS = f"""
+     ApprovalRequestActions Button {{
+         margin-right: 1;
+         min-width: 20;
+     }}
+     ApprovalRequestActions #{APPROVE_TOOL_CALL} {{
+         color: $success;
+     }}
+     ApprovalRequestActions #{REJECT_TOOL_CALL} {{
+         color: $warning-darken-3;
+     }}
+     ApprovalRequestActions #{ESCALATE_TOOL_CALL} {{
+         color: $primary-darken-3;
+         margin-left: 3;
+     }}
+     ApprovalRequestActions #{TERMINATE_TOOL_CALL_SAMPLE} {{
+         color: $error-darken-1;
+         margin-left: 3;
+     }}
+     """
+
+     approval_request: reactive[tuple[str, PendingApprovalRequest] | None] = reactive(
+         None
+     )
+
+     def compose(self) -> ComposeResult:
+         yield Button(
+             Text("Approve"),
+             id=self.APPROVE_TOOL_CALL,
+             tooltip="Approve the tool call.",
+         )
+         yield Button(
+             Text("Reject"),
+             id=self.REJECT_TOOL_CALL,
+             tooltip="Reject the tool call.",
+         )
+         yield Button(
+             Text("Escalate"),
+             id=self.ESCALATE_TOOL_CALL,
+             tooltip="Escalate the tool call to another approver.",
+         )
+         yield Button(
+             Text("Terminate"),
+             id=self.TERMINATE_TOOL_CALL_SAMPLE,
+             tooltip="Terminate the sample.",
+         )
+
+     def activate(self) -> None:
+         approve = self.query_one(f"#{self.APPROVE_TOOL_CALL}")
+         approve.focus()
+
+     def on_button_pressed(self, event: Button.Pressed) -> None:
+         if self.approval_request is not None:
+             id, _ = self.approval_request
+             if event.button.id == self.APPROVE_TOOL_CALL:
+                 approval = Approval(decision="approve", explanation=HUMAN_APPROVED)
+             elif event.button.id == self.REJECT_TOOL_CALL:
+                 approval = Approval(decision="reject", explanation=HUMAN_REJECTED)
+             elif event.button.id == self.ESCALATE_TOOL_CALL:
+                 approval = Approval(decision="escalate", explanation=HUMAN_ESCALATED)
+             elif event.button.id == self.TERMINATE_TOOL_CALL_SAMPLE:
+                 approval = Approval(decision="terminate", explanation=HUMAN_TERMINATED)
+             else:
+                 raise ValueError(f"Unexpected button id: {event.button.id}")
+             human_approval_manager().complete_approval(id, approval)
+
+     def watch_approval_request(
+         self, approval_request: tuple[str, PendingApprovalRequest] | None
+     ) -> None:
+         choices = (
+             approval_request[1].request.choices if approval_request is not None else []
+         )
+
+         def update_visible(id: str, choice: ApprovalDecision) -> None:
+             self.query_one(f"#{id}").display = choice in choices
+
+         update_visible(self.APPROVE_TOOL_CALL, "approve")
+         update_visible(self.REJECT_TOOL_CALL, "reject")
+         update_visible(self.ESCALATE_TOOL_CALL, "escalate")
+         update_visible(self.TERMINATE_TOOL_CALL_SAMPLE, "terminate")
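`ApprovalRequestActions` drives button visibility entirely through Textual's reactive/watch mechanism: assigning to the reactive attribute triggers the corresponding `watch_` method, which shows or hides child widgets. A standalone sketch of just that pattern (widget and attribute names are illustrative, not from the diff):

```python
from textual.app import App, ComposeResult
from textual.containers import Horizontal
from textual.reactive import reactive
from textual.widgets import Button


class Actions(Horizontal):
    # assigning to this reactive triggers watch_choices below
    choices: reactive[tuple[str, ...]] = reactive(("approve", "reject"))

    def compose(self) -> ComposeResult:
        yield Button("Approve", id="approve")
        yield Button("Reject", id="reject")

    def watch_choices(self, choices: tuple[str, ...]) -> None:
        # hide any button whose id is not among the allowed choices
        for button in self.query(Button):
            button.display = button.id in choices


class Demo(App[None]):
    BINDINGS = [("t", "toggle", "Toggle reject")]

    def compose(self) -> ComposeResult:
        yield Actions()

    def action_toggle(self) -> None:
        actions = self.query_one(Actions)
        actions.choices = (
            ("approve",) if "reject" in actions.choices else ("approve", "reject")
        )


if __name__ == "__main__":
    Demo().run()
```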
inspect_ai/approval/_human/util.py ADDED
@@ -0,0 +1,51 @@
+ from rich.console import RenderableType
+ from rich.highlighter import ReprHighlighter
+ from rich.rule import Rule
+ from rich.text import Text
+
+ from inspect_ai._util.transcript import transcript_markdown
+ from inspect_ai.tool._tool_call import ToolCallContent, ToolCallView
+ from inspect_ai.util._trace import trace_enabled
+
+ HUMAN_APPROVED = "Human operator approved tool call."
+ HUMAN_REJECTED = "Human operator rejected the tool call."
+ HUMAN_TERMINATED = "Human operator asked that the sample be terminated."
+ HUMAN_ESCALATED = "Human operator escalated the tool call approval."
+
+
+ def render_tool_approval(message: str, view: ToolCallView) -> list[RenderableType]:
+     renderables: list[RenderableType] = []
+     text_highlighter = ReprHighlighter()
+
+     # ignore content if trace enabled
+     message = message.strip() if not trace_enabled() else ""
+
+     def add_view_content(view_content: ToolCallContent) -> None:
+         if view_content.title:
+             renderables.append(Text.from_markup(f"[bold]{view_content.title}[/bold]\n"))
+         if view_content.format == "markdown":
+             renderables.append(transcript_markdown(view_content.content))
+         else:
+             text_content = text_highlighter(Text(view_content.content))
+             renderables.append(text_content)
+
+     # assistant content (don't add if trace_enabled as we already have it in that case)
+     if message:
+         renderables.append(Text.from_markup("[bold]Assistant[/bold]\n"))
+         renderables.append(Text(f"{message.strip()}"))
+
+     # extra context provided by tool view
+     if view.context:
+         renderables.append(Text())
+         add_view_content(view.context)
+         renderables.append(Text())
+
+     # tool call view
+     if view.call:
+         if message or view.context:
+             renderables.append(Rule("", style="#282c34", align="left", characters="․"))
+             renderables.append(Text())
+         add_view_content(view.call)
+         renderables.append(Text())
+
+     return renderables
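The renderables returned by `render_tool_approval` can be printed straight to a rich `Console`. A sketch (this imports private `_human`/`_tool_call` modules purely for illustration, and assumes `ToolCallView`/`ToolCallContent` accept the keyword fields used in the code above):

```python
from rich.console import Console

from inspect_ai.approval._human.util import render_tool_approval
from inspect_ai.tool._tool_call import ToolCallContent, ToolCallView

# a hypothetical tool call view for a shell command
view = ToolCallView(
    call=ToolCallContent(
        title="bash",
        format="markdown",
        content="```bash\nls -la /tmp\n```",
    )
)

console = Console()
for renderable in render_tool_approval("List the files in /tmp.", view):
    console.print(renderable)
```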
inspect_ai/dataset/_sources/hf.py CHANGED
@@ -21,9 +21,9 @@ from .._util import data_to_samples, record_to_sample_fn
 
  def hf_dataset(
      path: str,
+     split: str,
      name: str | None = None,
      data_dir: str | None = None,
-     split: str | None = None,
      revision: str | None = None,
      sample_fields: FieldSpec | RecordToSample | None = None,
      auto_id: bool = False,
@@ -44,10 +44,10 @@ def hf_dataset(
          builder that is used comes from a generic dataset script (JSON, CSV,
          Parquet, text etc.) or from the dataset script (a python file) inside
          the dataset directory.
+     split (str): Which split of the data to load.
      name (str | None): Name of the dataset configuration.
      data_dir (str | None): data_dir of the dataset configuration
          to read data from.
-     split (str | None): Which split of the data to load.
      revision (str | None): Specific revision to load (e.g. "main", a branch
          name, or a specific commit SHA). When using `revision` the `cached` option
          is ignored and datasets are revalidated on Hugging Face before loading.
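Note this is a breaking change: `split` moves from an optional keyword to a required parameter, so calls relying on the old default must now pass it explicitly. A call under the new signature (the dataset path and field names are illustrative):

```python
from inspect_ai.dataset import FieldSpec, hf_dataset

# "split" must now be supplied explicitly
dataset = hf_dataset(
    "hellaswag",
    split="validation",
    sample_fields=FieldSpec(input="ctx", target="label"),
)
```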
inspect_ai/dataset/_sources/util.py CHANGED
@@ -34,7 +34,7 @@ def resolve_sample_files(dataset: Dataset) -> None:
      # for each sample
      for sample in dataset:
          # check for sandbox config file
-         if sample.sandbox and sample.sandbox.config is not None:
+         if sample.sandbox and isinstance(sample.sandbox.config, str):
              sample.sandbox = SandboxEnvironmentSpec(
                  sample.sandbox.type, resolve_file(sample.sandbox.config)
              )
inspect_ai/log/_file.py CHANGED
@@ -1,17 +1,15 @@
- import asyncio
  import os
  import re
  from logging import getLogger
  from typing import Any, Callable, Generator, Literal, cast
 
- import fsspec  # type: ignore
- from fsspec.asyn import AsyncFileSystem  # type: ignore
- from fsspec.core import split_protocol  # type: ignore
  from pydantic_core import to_json
 
+ from inspect_ai._util._async import run_coroutine
  from inspect_ai._util.constants import ALL_LOG_FORMATS, EVAL_LOG_FORMAT
  from inspect_ai._util.file import (
      FileInfo,
+     async_fileystem,
      file,
      filesystem,
  )
@@ -110,25 +108,25 @@ async def list_eval_logs_async(
      # async filesystem if we can
      fs = filesystem(log_dir, fs_options)
      if fs.is_async():
-         async_fs = async_fileystem(log_dir, fs_options=fs_options)
-         if await async_fs._exists(log_dir):
-             # prevent caching of listings
-             async_fs.invalidate_cache(log_dir)
-             # list logs
-             if recursive:
-                 files: list[dict[str, Any]] = []
-                 async for _, _, filenames in async_fs._walk(log_dir, detail=True):
-                     files.extend(filenames.values())
+         async with async_fileystem(log_dir, fs_options=fs_options) as async_fs:
+             if await async_fs._exists(log_dir):
+                 # prevent caching of listings
+                 async_fs.invalidate_cache(log_dir)
+                 # list logs
+                 if recursive:
+                     files: list[dict[str, Any]] = []
+                     async for _, _, filenames in async_fs._walk(log_dir, detail=True):
+                         files.extend(filenames.values())
+                 else:
+                     files = cast(
+                         list[dict[str, Any]],
+                         await async_fs._ls(log_dir, detail=True),
+                     )
+                 logs = [fs._file_info(file) for file in files]
+                 # resolve to eval logs
+                 return log_files_from_ls(logs, formats, descending)
              else:
-                 files = cast(
-                     list[dict[str, Any]],
-                     async_fs._ls(log_dir, detail=True),
-                 )
-             logs = [fs._file_info(file) for file in files]
-             # resolve to eval logs
-             return log_files_from_ls(logs, formats, descending)
-         else:
-             return []
+                 return []
      else:
          return list_eval_logs(
              log_dir=log_dir,
@@ -146,6 +144,22 @@
  ) -> None:
      """Write an evaluation log.
 
+     Args:
+         log (EvalLog): Evaluation log to write.
+         location (str | FileInfo): Location to write log to.
+         format (Literal["eval", "json", "auto"]): Write to format
+             (defaults to 'auto' based on `log_file` extension)
+     """
+     run_coroutine(write_eval_log_async(log, location, format))
+
+
+ async def write_eval_log_async(
+     log: EvalLog,
+     location: str | FileInfo | None = None,
+     format: Literal["eval", "json", "auto"] = "auto",
+ ) -> None:
+     """Write an evaluation log.
+
      Args:
          log (EvalLog): Evaluation log to write.
          location (str | FileInfo): Location to write log to.
@@ -169,7 +183,7 @@ def write_eval_log(
          recorder_type = recorder_type_for_location(location)
      else:
          recorder_type = recorder_type_for_format(format)
-     recorder_type.write_log(location, log)
+     await recorder_type.write_log(location, log)
 
      logger.debug(f"Writing eval log to {location} completed")
 
@@ -224,6 +238,31 @@
  ) -> EvalLog:
      """Read an evaluation log.
 
+     Args:
+         log_file (str | FileInfo): Log file to read.
+         header_only (bool): Read only the header (i.e. exclude
+             the "samples" and "logging" fields). Defaults to False.
+         resolve_attachments (bool): Resolve attachments (e.g. images)
+             to their full content.
+         format (Literal["eval", "json", "auto"]): Read from format
+             (defaults to 'auto' based on `log_file` extension)
+
+     Returns:
+         EvalLog object read from file.
+     """
+     return run_coroutine(
+         read_eval_log_async(log_file, header_only, resolve_attachments, format)
+     )
+
+
+ async def read_eval_log_async(
+     log_file: str | FileInfo,
+     header_only: bool = False,
+     resolve_attachments: bool = False,
+     format: Literal["eval", "json", "auto"] = "auto",
+ ) -> EvalLog:
+     """Read an evaluation log.
+
      Args:
          log_file (str | FileInfo): Log file to read.
          header_only (bool): Read only the header (i.e. exclude
@@ -245,7 +284,7 @@
          recorder_type = recorder_type_for_location(log_file)
      else:
          recorder_type = recorder_type_for_format(format)
-     log = recorder_type.read_log(log_file, header_only)
+     log = await recorder_type.read_log(log_file, header_only)
 
      # resolve attachement if requested
      if resolve_attachments and log.samples:
@@ -267,7 +306,15 @@
  def read_eval_log_headers(
      log_files: list[str] | list[FileInfo] | list[EvalLogInfo],
  ) -> list[EvalLog]:
-     return [read_eval_log(log_file, header_only=True) for log_file in log_files]
+     return run_coroutine(read_eval_log_headers_async(log_files))
+
+
+ async def read_eval_log_headers_async(
+     log_files: list[str] | list[FileInfo] | list[EvalLogInfo],
+ ) -> list[EvalLog]:
+     return [
+         await read_eval_log_async(log_file, header_only=True) for log_file in log_files
+     ]
 
 
  def read_eval_log_sample(
@@ -279,6 +326,35 @@
  ) -> EvalSample:
      """Read a sample from an evaluation log.
 
+     Args:
+         log_file (str | FileInfo): Log file to read.
+         id (int | str): Sample id to read.
+         epoch (int): Epoch for sample id (defaults to 1)
+         resolve_attachments (bool): Resolve attachments (e.g. images)
+             to their full content.
+         format (Literal["eval", "json", "auto"]): Read from format
+             (defaults to 'auto' based on `log_file` extension)
+
+     Returns:
+         EvalSample object read from file.
+
+     Raises:
+         IndexError: If the passed id and epoch are not found.
+     """
+     return run_coroutine(
+         read_eval_log_sample_async(log_file, id, epoch, resolve_attachments, format)
+     )
+
+
+ async def read_eval_log_sample_async(
+     log_file: str | FileInfo,
+     id: int | str,
+     epoch: int = 1,
+     resolve_attachments: bool = False,
+     format: Literal["eval", "json", "auto"] = "auto",
+ ) -> EvalSample:
+     """Read a sample from an evaluation log.
+
      Args:
          log_file (str | FileInfo): Log file to read.
          id (int | str): Sample id to read.
@@ -301,7 +377,7 @@
          recorder_type = recorder_type_for_location(log_file)
      else:
          recorder_type = recorder_type_for_format(format)
-     sample = recorder_type.read_log_sample(log_file, id, epoch)
+     sample = await recorder_type.read_log_sample(log_file, id, epoch)
 
      if resolve_attachments:
          sample = resolve_sample_attachments(sample)
@@ -442,7 +518,7 @@ def log_file_info(info: FileInfo) -> "EvalLogInfo":
      )
 
 
- def eval_log_json(log: EvalLog) -> str:
+ def eval_log_json(log: EvalLog) -> bytes:
      # serialize to json (ignore values that are unserializable)
      # these values often result from solvers using metadata to
      # pass around 'live' objects -- this is fine to do and we
@@ -452,14 +528,8 @@ def eval_log_json(log: EvalLog) -> str:
          indent=2,
          exclude_none=True,
          fallback=lambda _x: None,
-     ).decode()
+     )
 
 
- def async_fileystem(log_file: str, fs_options: dict[str, Any] = {}) -> AsyncFileSystem:
-     # determine protocol
-     protocol, _ = split_protocol(log_file)
-     protocol = protocol or "file"
-     # create filesystem
-     fs_options = fs_options.copy()
-     fs_options.update({"asynchronous": True, "loop": asyncio.get_event_loop()})
-     return fsspec.filesystem(protocol, **fs_options)
+ def eval_log_json_str(log: EvalLog) -> str:
+     return eval_log_json(log).decode()
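With these changes the sync functions become thin `run_coroutine` wrappers, and callers already inside an event loop can use the new `*_async` variants directly. A sketch (log paths are illustrative; it assumes the async variants are re-exported from `inspect_ai.log` alongside their sync counterparts):

```python
import asyncio

from inspect_ai.log import read_eval_log_async, write_eval_log_async


async def main() -> None:
    # read just the header (cheap for large logs), then write a copy
    log = await read_eval_log_async("logs/example.eval", header_only=True)
    print(log.status, log.eval.model)
    await write_eval_log_async(log, "logs/example-copy.eval")


asyncio.run(main())
```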