inspect-ai 0.3.49__py3-none-any.whl → 0.3.51__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. inspect_ai/_cli/info.py +2 -2
  2. inspect_ai/_cli/log.py +2 -2
  3. inspect_ai/_cli/score.py +2 -2
  4. inspect_ai/_display/core/display.py +19 -0
  5. inspect_ai/_display/core/panel.py +37 -7
  6. inspect_ai/_display/core/progress.py +29 -2
  7. inspect_ai/_display/core/results.py +79 -40
  8. inspect_ai/_display/core/textual.py +21 -0
  9. inspect_ai/_display/rich/display.py +28 -8
  10. inspect_ai/_display/textual/app.py +107 -1
  11. inspect_ai/_display/textual/display.py +1 -1
  12. inspect_ai/_display/textual/widgets/samples.py +132 -91
  13. inspect_ai/_display/textual/widgets/task_detail.py +236 -0
  14. inspect_ai/_display/textual/widgets/tasks.py +74 -6
  15. inspect_ai/_display/textual/widgets/toggle.py +32 -0
  16. inspect_ai/_eval/context.py +2 -0
  17. inspect_ai/_eval/eval.py +4 -3
  18. inspect_ai/_eval/loader.py +1 -1
  19. inspect_ai/_eval/run.py +35 -2
  20. inspect_ai/_eval/task/log.py +13 -11
  21. inspect_ai/_eval/task/results.py +12 -3
  22. inspect_ai/_eval/task/run.py +139 -36
  23. inspect_ai/_eval/task/sandbox.py +2 -1
  24. inspect_ai/_util/_async.py +30 -1
  25. inspect_ai/_util/file.py +31 -4
  26. inspect_ai/_util/html.py +3 -0
  27. inspect_ai/_util/logger.py +6 -5
  28. inspect_ai/_util/platform.py +5 -6
  29. inspect_ai/_util/registry.py +1 -1
  30. inspect_ai/_view/server.py +9 -9
  31. inspect_ai/_view/www/App.css +2 -2
  32. inspect_ai/_view/www/dist/assets/index.css +2 -2
  33. inspect_ai/_view/www/dist/assets/index.js +352 -294
  34. inspect_ai/_view/www/log-schema.json +13 -0
  35. inspect_ai/_view/www/package.json +1 -0
  36. inspect_ai/_view/www/src/components/MessageBand.mjs +1 -1
  37. inspect_ai/_view/www/src/components/Tools.mjs +16 -13
  38. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -3
  39. inspect_ai/_view/www/src/samples/SampleScoreView.mjs +52 -77
  40. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -13
  41. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +15 -2
  42. inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.mjs +4 -2
  43. inspect_ai/_view/www/src/types/log.d.ts +2 -0
  44. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +2 -0
  45. inspect_ai/_view/www/yarn.lock +9 -4
  46. inspect_ai/approval/__init__.py +1 -1
  47. inspect_ai/approval/_human/approver.py +35 -0
  48. inspect_ai/approval/_human/console.py +62 -0
  49. inspect_ai/approval/_human/manager.py +108 -0
  50. inspect_ai/approval/_human/panel.py +233 -0
  51. inspect_ai/approval/_human/util.py +51 -0
  52. inspect_ai/dataset/_sources/hf.py +2 -2
  53. inspect_ai/dataset/_sources/util.py +1 -1
  54. inspect_ai/log/_file.py +106 -36
  55. inspect_ai/log/_recorders/eval.py +226 -158
  56. inspect_ai/log/_recorders/file.py +9 -6
  57. inspect_ai/log/_recorders/json.py +35 -12
  58. inspect_ai/log/_recorders/recorder.py +15 -15
  59. inspect_ai/log/_samples.py +52 -0
  60. inspect_ai/model/_model.py +14 -0
  61. inspect_ai/model/_model_output.py +4 -0
  62. inspect_ai/model/_providers/azureai.py +1 -1
  63. inspect_ai/model/_providers/hf.py +106 -4
  64. inspect_ai/model/_providers/util/__init__.py +2 -0
  65. inspect_ai/model/_providers/util/hf_handler.py +200 -0
  66. inspect_ai/scorer/_common.py +1 -1
  67. inspect_ai/solver/_plan.py +0 -8
  68. inspect_ai/solver/_task_state.py +18 -1
  69. inspect_ai/solver/_use_tools.py +9 -1
  70. inspect_ai/tool/_tool_def.py +2 -2
  71. inspect_ai/tool/_tool_info.py +14 -2
  72. inspect_ai/tool/_tool_params.py +2 -1
  73. inspect_ai/tool/_tools/_execute.py +1 -1
  74. inspect_ai/tool/_tools/_web_browser/_web_browser.py +6 -0
  75. inspect_ai/util/__init__.py +5 -6
  76. inspect_ai/util/_panel.py +91 -0
  77. inspect_ai/util/_sandbox/__init__.py +2 -6
  78. inspect_ai/util/_sandbox/context.py +4 -3
  79. inspect_ai/util/_sandbox/docker/compose.py +12 -2
  80. inspect_ai/util/_sandbox/docker/docker.py +19 -9
  81. inspect_ai/util/_sandbox/docker/util.py +10 -2
  82. inspect_ai/util/_sandbox/environment.py +47 -41
  83. inspect_ai/util/_sandbox/local.py +15 -10
  84. inspect_ai/util/_subprocess.py +43 -3
  85. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/METADATA +2 -2
  86. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/RECORD +90 -82
  87. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  88. inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
  89. inspect_ai/approval/_human.py +0 -123
  90. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/LICENSE +0 -0
  91. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/WHEEL +0 -0
  92. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/entry_points.txt +0 -0
  93. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/top_level.txt +0 -0
inspect_ai/_cli/info.py CHANGED
@@ -5,7 +5,7 @@ import click
 from inspect_ai import __version__
 from inspect_ai._util.constants import PKG_PATH
 from inspect_ai._view.server import resolve_header_only
-from inspect_ai.log._file import eval_log_json, read_eval_log
+from inspect_ai.log._file import eval_log_json_str, read_eval_log

 from .log import headers, schema, types

@@ -46,7 +46,7 @@ def log(path: str, header_only: int) -> None:
     header_only = resolve_header_only(path, header_only)

     log = read_eval_log(path, header_only=header_only)
-    print(eval_log_json(log))
+    print(eval_log_json_str(log))


 @info_command.command("log-file-headers", hidden=True)
inspect_ai/_cli/log.py CHANGED
@@ -14,7 +14,7 @@ from inspect_ai._util.constants import PKG_PATH
 from inspect_ai.log import list_eval_logs
 from inspect_ai.log._convert import convert_eval_logs
 from inspect_ai.log._file import (
-    eval_log_json,
+    eval_log_json_str,
     read_eval_log,
     read_eval_log_headers,
 )
@@ -127,7 +127,7 @@ def list_command(
 def dump_command(path: str, header_only: bool) -> None:
     """Print log file contents as JSON."""
     log = read_eval_log(path, header_only=header_only)
-    print(eval_log_json(log))
+    print(eval_log_json_str(log))


 @log_command.command("convert")
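Both CLI commands above now call eval_log_json_str, which returns the serialized log as a str (replacing the eval_log_json name at these call sites). A minimal usage sketch of the renamed helper; the log path is hypothetical:

from inspect_ai.log._file import eval_log_json_str, read_eval_log

# read just the header and print the log as JSON text
log = read_eval_log("./logs/example.eval", header_only=True)
print(eval_log_json_str(log))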
inspect_ai/_cli/score.py CHANGED
@@ -61,7 +61,7 @@ async def score(

     # read the eval log
     recorder = create_recorder_for_location(log_file, log_dir)
-    eval_log = recorder.read_log(log_file)
+    eval_log = await recorder.read_log(log_file)

     # check that there are samples therein
     if eval_log.samples is None or len(eval_log.samples) == 0:
@@ -88,7 +88,7 @@
     scored = f"{SCORED_SUFFIX}{ext}"
     if not overwrite and not log_file.endswith(scored):
         log_file = log_file.removesuffix(ext) + scored
-    recorder.write_log(log_file, eval_log)
+    await recorder.write_log(log_file, eval_log)

     # print results
     display().print(f"\n{eval_log.eval.task}")
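The recorder API used here is now async: read_log and write_log are coroutines and must be awaited. A sketch of the calling pattern, assuming a recorder and log file path are already in hand (the rescore wrapper itself is hypothetical):

async def rescore(recorder, log_file: str) -> None:
    eval_log = await recorder.read_log(log_file)   # was a plain call in 0.3.49
    # ... scoring of eval_log.samples happens here in the CLI command ...
    await recorder.write_log(log_file, eval_log)   # also awaited now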
inspect_ai/_display/core/display.py CHANGED
@@ -19,6 +19,8 @@ from rich.console import Console
 from inspect_ai.log import EvalConfig, EvalResults, EvalStats
 from inspect_ai.model import GenerateConfig, ModelName

+from ...util._panel import InputPanel
+

 @runtime_checkable
 class Progress(Protocol):
@@ -81,6 +83,8 @@ class TaskWithResult:

 TR = TypeVar("TR")

+TP = TypeVar("TP", bound=InputPanel)
+

 class TaskScreen(contextlib.AbstractContextManager["TaskScreen"]):
     def __exit__(self, *excinfo: Any) -> None:
@@ -95,12 +99,27 @@ class TaskScreen(contextlib.AbstractContextManager["TaskScreen"]):
     ) -> Iterator[Console]:
         yield rich.get_console()

+    async def input_panel(self, title: str, panel: type[TP]) -> TP:
+        raise NotImplementedError("input_panel not implemented by current display")
+
+
+@dataclass
+class TaskDisplayMetric:
+    scorer: str
+    name: str
+    value: float | int
+    reducer: str | None
+

 @runtime_checkable
 class TaskDisplay(Protocol):
     @contextlib.contextmanager
     def progress(self) -> Iterator[Progress]: ...

+    def sample_complete(self, complete: int, total: int) -> None: ...
+
+    def update_metrics(self, scores: list[TaskDisplayMetric]) -> None: ...
+
     def complete(self, result: TaskResult) -> None: ...

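TaskDisplay implementations now receive per-sample progress and interim scores through the two new protocol methods, with metrics carried by the TaskDisplayMetric dataclass. A sketch of how the eval runner might drive a display; the metric values and the task_display variable are illustrative:

from inspect_ai._display.core.display import TaskDisplayMetric

# task_display is any TaskDisplay implementation (e.g. the rich or textual one)
task_display.sample_complete(complete=5, total=50)   # updates the [complete/total] cell
task_display.update_metrics(
    [TaskDisplayMetric(scorer="match", name="accuracy", value=0.82, reducer=None)]
)  # updates the headline score cell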
inspect_ai/_display/core/panel.py CHANGED
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import rich
 from rich.console import RenderableType
 from rich.panel import Panel
@@ -16,6 +18,10 @@ def task_panel(
     profile: TaskProfile,
     show_model: bool,
     body: RenderableType,
+    subtitle: RenderableType
+    | str
+    | Tuple[RenderableType | str, RenderableType | str]
+    | None,
     footer: RenderableType | tuple[RenderableType, RenderableType] | None,
     log_location: str | None,
 ) -> Panel:
@@ -25,22 +31,39 @@
     width = CONSOLE_DISPLAY_WIDTH if is_vscode_notebook(console) else None
     jupyter = console.is_jupyter

-    # setup table
+    # root table
     table = Table.grid(expand=True)
     table.add_column()
-    table.add_column(justify="right")
+
+    # setup table
+    if subtitle is not None:
+        subtitle_table = Table.grid(expand=True)
+        subtitle_table.add_column()
+        if isinstance(subtitle, tuple):
+            subtitle_table.add_column(justify="right")
+            subtitle_table.add_row(
+                to_renderable(subtitle[0]), to_renderable(subtitle[1], style=theme.meta)
+            )
+        else:
+            subtitle_table.add_row(to_renderable(subtitle))
+
+        table.add_row(subtitle_table)

     # main progress and task info
-    targets = Text.from_markup(task_targets(profile), style=theme.meta)
-    table.add_row(body, targets)
+    table.add_row()
+    table.add_row(body)
+    table.add_row()

     # footer if specified
     if footer:
-        table.add_row()
+        footer_table = Table.grid(expand=True)
+        footer_table.add_column()
         if isinstance(footer, tuple):
-            table.add_row(footer[0], footer[1])
+            footer_table.add_column(justify="right")
+            footer_table.add_row(footer[0], footer[1])
         else:
-            table.add_row(footer)
+            footer_table.add_row(footer)
+        table.add_row(footer_table)

     # enclose in outer table for log link footer
     root = table
@@ -75,6 +98,13 @@
     return panel


+def to_renderable(item: RenderableType | str, style: str = "") -> RenderableType:
+    if isinstance(item, str):
+        return Text.from_markup(item, style=style)
+    else:
+        return item
+
+
 def tasks_title(completed: int, total: int) -> str:
     return f"{completed}/{total} tasks complete"

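task_panel now takes a subtitle argument that is rendered in its own two-column grid above the body (configuration on the left, targets right-aligned in the meta style). A sketch of a call using the new parameter, mirroring the call sites in results.py below and assuming a TaskProfile (profile) and a body renderable are in scope:

from inspect_ai._display.core.config import task_config
from inspect_ai._display.core.panel import task_panel, task_targets

panel = task_panel(
    profile=profile,
    show_model=True,
    body=body,
    subtitle=(task_config(profile), task_targets(profile)),
    footer=None,
    log_location=None,
)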
inspect_ai/_display/core/progress.py CHANGED
@@ -32,13 +32,20 @@ class RichProgress(Progress):
         model: str = "",
         status: Callable[[], str] | None = None,
         on_update: Callable[[], None] | None = None,
+        count: str = "",
+        score: str = "",
     ) -> None:
         self.total = total
         self.progress = progress
         self.status = status if status else lambda: ""
         self.on_update = on_update
         self.task_id = progress.add_task(
-            description, total=PROGRESS_TOTAL, model=model, status=self.status()
+            description,
+            total=PROGRESS_TOTAL,
+            model=model,
+            status=self.status(),
+            count=count,
+            score=score,
         )

     @override
@@ -56,6 +63,16 @@ class RichProgress(Progress):
             task_id=self.task_id, completed=PROGRESS_TOTAL, status=self.status()
         )

+    def update_count(self, complete: int, total: int) -> None:
+        self.progress.update(
+            task_id=self.task_id, count=progress_count(complete, total), refresh=True
+        )
+        if self.on_update:
+            self.on_update()
+
+    def update_score(self, score: str) -> None:
+        self.progress.update(task_id=self.task_id, score=score)
+

 def rich_progress() -> RProgress:
     console = rich.get_console()
@@ -65,10 +82,12 @@
         TextColumn("{task.fields[model]}"),
         BarColumn(bar_width=40 if is_vscode_notebook(console) else None),
         TaskProgressColumn(),
+        TextColumn("{task.fields[count]}"),
+        TextColumn("{task.fields[score]}"),
         TimeElapsedColumn(),
         transient=True,
         console=console,
-        expand=not is_vscode_notebook(console),
+        expand=True,
     )


@@ -109,3 +128,11 @@ def progress_time(time: float) -> str:
     minutes, seconds = divmod(time, 60)
     hours, minutes = divmod(minutes, 60)
     return f"{hours:2.0f}:{minutes:02.0f}:{seconds:02.0f}"
+
+
+def progress_count(complete: int, total: int) -> str:
+    # Pad the display to keep it stable
+    total_str = f"{total:,}"
+    complete_str = f"{complete:,}"
+    padding = max(0, len(total_str) - len(complete_str))
+    return " " * padding + f"[{complete_str}/{total_str}]"
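progress_count left-pads the completed count so the bracketed count cell keeps a stable width as samples finish; for example:

from inspect_ai._display.core.progress import progress_count

progress_count(7, 1000)     # "    [7/1,000]"
progress_count(999, 1000)   # "  [999/1,000]"
progress_count(1000, 1000)  # "[1,000/1,000]"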
inspect_ai/_display/core/results.py CHANGED
@@ -1,22 +1,24 @@
 from datetime import datetime
 from typing import Sequence, Set

+import numpy as np
 from rich.console import Group, RenderableType
 from rich.table import Table
 from rich.text import Text

 from inspect_ai.log import EvalStats
-from inspect_ai.log._log import rich_traceback
+from inspect_ai.log._log import EvalScore, rich_traceback

 from .config import task_config, task_dict
 from .display import (
     TaskCancelled,
+    TaskDisplayMetric,
     TaskError,
     TaskProfile,
     TaskSuccess,
     TaskWithResult,
 )
-from .panel import task_panel
+from .panel import task_panel, task_targets
 from .rich import rich_theme

@@ -37,10 +39,18 @@ def tasks_results(tasks: Sequence[TaskWithResult]) -> RenderableType:
 def task_result_cancelled(
     profile: TaskProfile, cancelled: TaskCancelled
 ) -> RenderableType:
+    # The contents of the panel
+    config = task_config(profile)
+    targets = task_targets(profile)
+    subtitle = config, targets
+    body = task_stats(cancelled.stats)
+
+    # The panel
     return task_panel(
         profile=profile,
         show_model=True,
-        body=task_stats(profile, cancelled.stats),
+        body=body,
+        subtitle=subtitle,
         footer=task_interrupted(profile, cancelled.samples_completed),
         log_location=profile.log_location,
     )
@@ -50,36 +60,7 @@ def task_results(profile: TaskProfile, success: TaskSuccess) -> RenderableType:
     theme = rich_theme()

     # do we have more than one scorer name?
-    results = success.results
-    scorer_names: Set[str] = {score.name for score in results.scores}
-    reducer_names: Set[str] = {
-        score.reducer for score in results.scores if score.reducer is not None
-    }
-    show_reducer = len(reducer_names) > 1 or "avg" not in reducer_names
-    output: dict[str, str] = {}
-    for score in results.scores:
-        for name, metric in score.metrics.items():
-            value = (
-                "1.0"
-                if metric.value == 1
-                else (
-                    str(metric.value)
-                    if isinstance(metric.value, int)
-                    else f"{metric.value:.3g}"
-                )
-            )
-            name = (
-                rf"{name}\[{score.reducer}]"
-                if show_reducer and score.reducer is not None
-                else name
-            )
-            key = f"{score.name}/{name}" if (len(scorer_names) > 1) else name
-            output[key] = value
-
-    if output:
-        message = f"[{theme.metric}]{task_dict(output, True)}[/{theme.metric}]"
-    else:
-        message = ""
+    message = task_metrics(success.results.scores)

     # note if some of our samples had errors
     if success.samples_completed < profile.samples:
@@ -93,10 +74,18 @@
 

 def task_result_summary(profile: TaskProfile, success: TaskSuccess) -> RenderableType:
+    # The contents of the panel
+    config = task_config(profile)
+    targets = task_targets(profile)
+    subtitle = config, targets
+    body = task_stats(success.stats)
+
+    # the panel
     return task_panel(
         profile=profile,
         show_model=True,
-        body=task_stats(profile, success.stats),
+        body=body,
+        subtitle=subtitle,
         footer=task_results(profile, success),
         log_location=profile.log_location,
     )
@@ -107,20 +96,17 @@ def task_result_error(profile: TaskProfile, error: TaskError) -> RenderableType:
         profile=profile,
         show_model=True,
         body=rich_traceback(error.exc_type, error.exc_value, error.traceback),
+        subtitle=None,
         footer=task_interrupted(profile, error.samples_completed),
         log_location=profile.log_location,
     )


-def task_stats(profile: TaskProfile, stats: EvalStats) -> RenderableType:
+def task_stats(stats: EvalStats) -> RenderableType:
     theme = rich_theme()
     panel = Table.grid(expand=True)
     panel.add_column()
-    config = task_config(profile)
-    if config:
-        panel.add_row(config)
-        panel.add_row()
-    elif len(stats.model_usage) < 2:
+    if len(stats.model_usage) < 2:
         panel.add_row()

     table = Table.grid(expand=True)
@@ -178,3 +164,56 @@ def task_interrupted(profile: TaskProfile, samples_completed: int) -> Renderable
     )

     return message
+
+
+def task_metric(metrics: list[TaskDisplayMetric]) -> str:
+    reducer_names: Set[str] = {
+        metric.reducer for metric in metrics if metric.reducer is not None
+    }
+    show_reducer = len(reducer_names) > 1 or (
+        len(reducer_names) == 1 and "avg" not in reducer_names
+    )
+
+    metric = metrics[0]
+    if np.isnan(metric.value):
+        value = " n/a"
+    else:
+        value = f"{metric.value:.2f}"
+
+    if show_reducer:
+        return f"{metric.name}/{metric.reducer}: {value}"
+    else:
+        return f"{metric.name}: {value}"
+
+
+def task_metrics(scores: list[EvalScore]) -> str:
+    theme = rich_theme()
+    scorer_names: Set[str] = {score.name for score in scores}
+    reducer_names: Set[str] = {
+        score.reducer for score in scores if score.reducer is not None
+    }
+    show_reducer = len(reducer_names) > 1 or "avg" not in reducer_names
+    output: dict[str, str] = {}
+    for score in scores:
+        for name, metric in score.metrics.items():
+            value = (
+                "1.0"
+                if metric.value == 1
+                else (
+                    str(metric.value)
+                    if isinstance(metric.value, int)
+                    else f"{metric.value:.3g}"
+                )
+            )
+            name = (
+                rf"{name}\[{score.reducer}]"
+                if show_reducer and score.reducer is not None
+                else name
+            )
+            key = f"{score.name}/{name}" if (len(scorer_names) > 1) else name
+            output[key] = value
+
+    if output:
+        return f"[{theme.metric}]{task_dict(output, True)}[/{theme.metric}]"
+    else:
+        return ""
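task_metric formats just the first interim metric for the compact score cell shown while a task runs, while task_metrics reproduces the fuller summary previously built inline in task_results. For example:

from inspect_ai._display.core.display import TaskDisplayMetric
from inspect_ai._display.core.results import task_metric

task_metric([TaskDisplayMetric(scorer="match", name="accuracy", value=0.8234, reducer=None)])
# -> "accuracy: 0.82"
task_metric([TaskDisplayMetric(scorer="match", name="accuracy", value=0.8234, reducer="median")])
# -> "accuracy/median: 0.82"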
inspect_ai/_display/core/textual.py ADDED
@@ -0,0 +1,21 @@
+from logging import getLogger
+
+from textual.driver import Driver
+
+logger = getLogger(__name__)
+
+
+# force mouse support for textual -- this works around an issue where
+# mouse events are disabled after a reload of the vs code ide, see:
+# https://github.com/Textualize/textual/issues/5380
+# ansi codes for enabling mouse support are idempotent so it is fine
+# to do this even in cases where mouse support is already enabled.
+# we try/catch since we aren't 100% sure there aren't cases where doing
+# this won't raise and we'd rather not fail hard in in these case
+def textual_enable_mouse_support(driver: Driver) -> None:
+    enable_mouse_support = getattr(driver, "_enable_mouse_support", None)
+    if enable_mouse_support:
+        try:
+            enable_mouse_support()
+        except Exception as ex:
+            logger.warning(f"Error enabling mouse support: {ex}")
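A sketch of how this workaround could be invoked from a running Textual app (the call site and the App._driver access are assumptions; this diff only shows the helper itself):

from inspect_ai._display.core.textual import textual_enable_mouse_support

def reenable_mouse(app) -> None:
    # app._driver is Textual's active Driver instance (a private attribute)
    if app._driver is not None:
        textual_enable_mouse_support(app._driver)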
inspect_ai/_display/rich/display.py CHANGED
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from typing import Any, AsyncIterator, Callable, Coroutine, Iterator

 import rich
-from rich.console import Console, Group, RenderableType
+from rich.console import Console, RenderableType
 from rich.live import Live
 from rich.panel import Panel
 from rich.progress import Progress as RProgress
@@ -23,6 +23,7 @@ from ..core.display import (
     Display,
     Progress,
     TaskDisplay,
+    TaskDisplayMetric,
     TaskProfile,
     TaskResult,
     TaskScreen,
@@ -30,7 +31,7 @@ from ..core.display import (
     TaskWithResult,
 )
 from ..core.footer import task_footer
-from ..core.panel import task_panel, task_title, tasks_title
+from ..core.panel import task_panel, task_targets, task_title, tasks_title
 from ..core.progress import (
     RichProgress,
     progress_description,
@@ -38,7 +39,7 @@ from ..core.progress import (
     progress_status_icon,
     rich_progress,
 )
-from ..core.results import tasks_results
+from ..core.results import task_metric, tasks_results
 from ..core.rich import (
     is_vscode_notebook,
     record_console_input,
@@ -275,6 +276,15 @@ class RichTaskDisplay(TaskDisplay):
     def progress(self) -> Iterator[Progress]:
         yield self.p

+    @override
+    def sample_complete(self, complete: int, total: int) -> None:
+        self.p.update_count(complete, total)
+
+    @override
+    def update_metrics(self, metrics: list[TaskDisplayMetric]) -> None:
+        if len(metrics) > 0:
+            self.p.update_score(task_metric(metrics))
+
     @override
     def complete(self, result: TaskResult) -> None:
         self.status.result = result
@@ -283,15 +293,18 @@

 def task_live_status(tasks: list[TaskStatus], progress: RProgress) -> RenderableType:
     theme = rich_theme()
-    body: list[RenderableType] = ["", progress]
+
+    # the panel contents
     config = task_config(tasks[0].profile, style=theme.light)
-    if config:
-        body = [config] + body
+    targets = task_targets(tasks[0].profile)
+    subtitle = config, targets

+    # the panel
     return task_panel(
         profile=tasks[0].profile,
         show_model=len(tasks) == 1,
-        body=Group(*body),
+        body=progress,
+        subtitle=subtitle,
         footer=task_footer(theme.light),
         log_location=None,
     )
@@ -321,9 +334,16 @@ def tasks_live_status(
     footer_table.add_row()
     footer_table.add_row(footer[0], footer[1])

+    # build a layout table
+    layout_table = Table.grid(expand=True)
+    layout_table.add_column()
+    layout_table.add_row(config)
+    layout_table.add_row(progress)
+    layout_table.add_row(footer_table)
+
     # create panel w/ title
     panel = Panel(
-        Group(config, progress, footer_table, fit=False),
+        layout_table,
         title=f"[bold][{theme.meta}]{tasks_title(completed, total_tasks)}[/{theme.meta}][/bold]",
         title_align="left",
         width=width,
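Taken together, the hunks above wire per-sample progress into the two new progress columns: RichTaskDisplay.sample_complete feeds RichProgress.update_count (the "{task.fields[count]}" cell), and update_metrics formats the first metric via task_metric into the "{task.fields[score]}" cell. An illustrative sketch, assuming display is a RichTaskDisplay created for a running task:

display.sample_complete(complete=12, total=100)  # count cell -> " [12/100]"
display.update_metrics(
    [TaskDisplayMetric(scorer="match", name="accuracy", value=0.75, reducer=None)]
)  # score cell -> "accuracy: 0.75"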