inspect-ai 0.3.49__py3-none-any.whl → 0.3.51__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- inspect_ai/_cli/info.py +2 -2
- inspect_ai/_cli/log.py +2 -2
- inspect_ai/_cli/score.py +2 -2
- inspect_ai/_display/core/display.py +19 -0
- inspect_ai/_display/core/panel.py +37 -7
- inspect_ai/_display/core/progress.py +29 -2
- inspect_ai/_display/core/results.py +79 -40
- inspect_ai/_display/core/textual.py +21 -0
- inspect_ai/_display/rich/display.py +28 -8
- inspect_ai/_display/textual/app.py +107 -1
- inspect_ai/_display/textual/display.py +1 -1
- inspect_ai/_display/textual/widgets/samples.py +132 -91
- inspect_ai/_display/textual/widgets/task_detail.py +236 -0
- inspect_ai/_display/textual/widgets/tasks.py +74 -6
- inspect_ai/_display/textual/widgets/toggle.py +32 -0
- inspect_ai/_eval/context.py +2 -0
- inspect_ai/_eval/eval.py +4 -3
- inspect_ai/_eval/loader.py +1 -1
- inspect_ai/_eval/run.py +35 -2
- inspect_ai/_eval/task/log.py +13 -11
- inspect_ai/_eval/task/results.py +12 -3
- inspect_ai/_eval/task/run.py +139 -36
- inspect_ai/_eval/task/sandbox.py +2 -1
- inspect_ai/_util/_async.py +30 -1
- inspect_ai/_util/file.py +31 -4
- inspect_ai/_util/html.py +3 -0
- inspect_ai/_util/logger.py +6 -5
- inspect_ai/_util/platform.py +5 -6
- inspect_ai/_util/registry.py +1 -1
- inspect_ai/_view/server.py +9 -9
- inspect_ai/_view/www/App.css +2 -2
- inspect_ai/_view/www/dist/assets/index.css +2 -2
- inspect_ai/_view/www/dist/assets/index.js +352 -294
- inspect_ai/_view/www/log-schema.json +13 -0
- inspect_ai/_view/www/package.json +1 -0
- inspect_ai/_view/www/src/components/MessageBand.mjs +1 -1
- inspect_ai/_view/www/src/components/Tools.mjs +16 -13
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -3
- inspect_ai/_view/www/src/samples/SampleScoreView.mjs +52 -77
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -13
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +15 -2
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.mjs +4 -2
- inspect_ai/_view/www/src/types/log.d.ts +2 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +2 -0
- inspect_ai/_view/www/yarn.lock +9 -4
- inspect_ai/approval/__init__.py +1 -1
- inspect_ai/approval/_human/approver.py +35 -0
- inspect_ai/approval/_human/console.py +62 -0
- inspect_ai/approval/_human/manager.py +108 -0
- inspect_ai/approval/_human/panel.py +233 -0
- inspect_ai/approval/_human/util.py +51 -0
- inspect_ai/dataset/_sources/hf.py +2 -2
- inspect_ai/dataset/_sources/util.py +1 -1
- inspect_ai/log/_file.py +106 -36
- inspect_ai/log/_recorders/eval.py +226 -158
- inspect_ai/log/_recorders/file.py +9 -6
- inspect_ai/log/_recorders/json.py +35 -12
- inspect_ai/log/_recorders/recorder.py +15 -15
- inspect_ai/log/_samples.py +52 -0
- inspect_ai/model/_model.py +14 -0
- inspect_ai/model/_model_output.py +4 -0
- inspect_ai/model/_providers/azureai.py +1 -1
- inspect_ai/model/_providers/hf.py +106 -4
- inspect_ai/model/_providers/util/__init__.py +2 -0
- inspect_ai/model/_providers/util/hf_handler.py +200 -0
- inspect_ai/scorer/_common.py +1 -1
- inspect_ai/solver/_plan.py +0 -8
- inspect_ai/solver/_task_state.py +18 -1
- inspect_ai/solver/_use_tools.py +9 -1
- inspect_ai/tool/_tool_def.py +2 -2
- inspect_ai/tool/_tool_info.py +14 -2
- inspect_ai/tool/_tool_params.py +2 -1
- inspect_ai/tool/_tools/_execute.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +6 -0
- inspect_ai/util/__init__.py +5 -6
- inspect_ai/util/_panel.py +91 -0
- inspect_ai/util/_sandbox/__init__.py +2 -6
- inspect_ai/util/_sandbox/context.py +4 -3
- inspect_ai/util/_sandbox/docker/compose.py +12 -2
- inspect_ai/util/_sandbox/docker/docker.py +19 -9
- inspect_ai/util/_sandbox/docker/util.py +10 -2
- inspect_ai/util/_sandbox/environment.py +47 -41
- inspect_ai/util/_sandbox/local.py +15 -10
- inspect_ai/util/_subprocess.py +43 -3
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/METADATA +2 -2
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/RECORD +90 -82
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
- inspect_ai/approval/_human.py +0 -123
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/top_level.txt +0 -0
inspect_ai/_eval/task/run.py
CHANGED
@@ -1,6 +1,7 @@
 import asyncio
 import contextlib
 import sys
+import time
 from copy import deepcopy
 from dataclasses import dataclass, field
 from logging import getLogger
@@ -16,6 +17,7 @@ from inspect_ai._display import (
     TaskSuccess,
     display,
 )
+from inspect_ai._display.core.display import TaskDisplay, TaskDisplayMetric
 from inspect_ai._util.constants import (
     DEFAULT_EPOCHS,
     DEFAULT_MAX_CONNECTIONS,
@@ -40,7 +42,7 @@ from inspect_ai.log import (
     EvalStats,
 )
 from inspect_ai.log._condense import condense_sample
-from inspect_ai.log._file import
+from inspect_ai.log._file import eval_log_json_str
 from inspect_ai.log._log import EvalSampleLimit, EvalSampleReductions, eval_error
 from inspect_ai.log._samples import active_sample
 from inspect_ai.log._transcript import (
@@ -60,7 +62,8 @@ from inspect_ai.model import (
 )
 from inspect_ai.model._model import init_sample_model_usage, sample_model_usage
 from inspect_ai.scorer import Scorer, Target
-from inspect_ai.scorer._metric import SampleScore, Score
+from inspect_ai.scorer._metric import Metric, SampleScore, Score
+from inspect_ai.scorer._reducer.types import ScoreReducer
 from inspect_ai.scorer._score import init_scoring_context
 from inspect_ai.scorer._scorer import unique_scorer_name
 from inspect_ai.solver import Generate, Plan, TaskState
@@ -92,6 +95,12 @@ py_logger = getLogger(__name__)

 EvalSampleSource = Callable[[int | str, int], EvalSample | None]

+# Units allocated for sample progress - the total units
+# represents the total units of progress for an individual sample
+# the remainder are increments of progress within a sample (and
+# must sum to the total_progress_units when the sample is complete)
+SAMPLE_TOTAL_PROGRESS_UNITS = 1
+

 @dataclass
 class TaskRunOptions:
@@ -135,8 +144,6 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
     results: EvalResults | None = None
     reductions: list[EvalSampleReductions] | None = None
     stats = EvalStats(started_at=iso_now())
-    error: EvalError | None = None
-    cancelled = False

     # handle sample errors (raise as required)
     sample_error_handler = SampleErrorHandler(
@@ -183,11 +190,6 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
         else ["(none)"]
     )

-    # compute steps (steps = samples * steps in plan + 1 for scorer)
-    steps = len(samples) * (
-        len(plan.steps) + (1 if plan.finish else 0) + (1)  # scorer
-    )
-
     # compute an eval directory relative log location if we can
     if PurePath(logger.location).is_relative_to(PurePath(eval_wd)):
         log_location = PurePath(logger.location).relative_to(eval_wd).as_posix()
@@ -202,7 +204,7 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
         dataset=task.dataset.name or "(samples)",
         scorer=", ".join(scorer_profiles),
         samples=len(samples),
-        steps=
+        steps=len(samples) * SAMPLE_TOTAL_PROGRESS_UNITS,
         eval_config=config,
         task_args=logger.eval.task_args,
         generate_config=generate_config,
@@ -213,12 +215,12 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
     with display().task(profile) as td:
         try:
             # start the log
-            log_start(logger, plan, generate_config)
+            await log_start(logger, plan, generate_config)

             with td.progress() as p:
                 # forward progress
-                def progress() -> None:
-                    p.update(
+                def progress(number: int) -> None:
+                    p.update(number)

                 # provide solvers a function that they can use to generate output
                 async def generate(
@@ -243,6 +245,28 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
                     config, generate_config, model.api
                 )

+                # track when samples complete and update progress as we go
+                progress_results: list[dict[str, SampleScore]] = []
+                update_metrics_display = update_metrics_display_fn(td)
+
+                def sample_complete(sample_score: dict[str, SampleScore]) -> None:
+                    # Capture the result
+                    progress_results.append(sample_score)
+
+                    # Increment the segment progress
+                    td.sample_complete(
+                        complete=len(progress_results), total=len(samples)
+                    )
+
+                    # Update metrics
+                    update_metrics_display(
+                        len(progress_results),
+                        progress_results,
+                        scorers,
+                        task.epochs_reducer,
+                        task.metrics,
+                    )
+
                 # create sample coroutines
                 sample_coroutines = [
                     task_run_sample(
@@ -259,6 +283,7 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
                         log_images=log_images,
                         sample_source=sample_source,
                         sample_error=sample_error_handler,
+                        sample_complete=sample_complete,
                         fails_on_error=(
                             config.fail_on_error is None
                             or config.fail_on_error is True
@@ -269,7 +294,18 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
                     for (sample, state) in zip(samples, states)
                 ]

-                #
+                # initial progress
+                td.sample_complete(complete=0, total=len(samples))
+
+                # Update metrics to empty state
+                update_metrics_display(
+                    len(progress_results),
+                    progress_results,
+                    scorers,
+                    task.epochs_reducer,
+                    task.metrics,
+                )
+
                 sample_results = await asyncio.gather(*sample_coroutines)

             # compute and record metrics if we have scores
@@ -291,6 +327,11 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
             # collect eval data
             collect_eval_data(stats)

+            # finish w/ success status
+            eval_log = await logger.log_finish(
+                "success", stats, results, reductions
+            )
+
             # display task summary
             td.complete(
                 TaskSuccess(
@@ -301,12 +342,14 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
             )

         except asyncio.CancelledError:
-            # flag as cancelled
-            cancelled = True
-
             # collect eval data
             collect_eval_data(stats)

+            # finish w/ cancelled status
+            eval_log = await logger.log_finish(
+                "cancelled", stats, results, reductions
+            )
+
             # display task cancelled
             td.complete(TaskCancelled(logger.samples_completed, stats))

@@ -325,25 +368,22 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
             # collect eval data
             collect_eval_data(stats)

+            # finish with error status
+            eval_log = await logger.log_finish(
+                "error", stats, results, reductions, error
+            )
+
             # display it
             td.complete(
                 TaskError(logger.samples_completed, type, value, traceback)
             )

-    # log as appropriate
-    if cancelled:
-        eval_log = logger.log_finish("cancelled", stats, results, reductions)
-    elif error:
-        eval_log = logger.log_finish("error", stats, results, reductions, error)
-    else:
-        eval_log = logger.log_finish("success", stats, results, reductions)
-
     # notify the view module that an eval just completed
     # (in case we have a view polling for new evals)
     view_notify_eval(logger.location)

     try:
-        await send_telemetry("eval_log",
+        await send_telemetry("eval_log", eval_log_json_str(eval_log))
     except Exception as ex:
         py_logger.warning(
             f"Error occurred sending telemetry: {exception_message(ex)}"
@@ -353,6 +393,63 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
     return eval_log


+def update_metrics_display_fn(
+    td: TaskDisplay, initial_interval: float = 0, min_interval: float = 0.9
+) -> Callable[
+    [
+        int,
+        list[dict[str, SampleScore]],
+        list[Scorer] | None,
+        ScoreReducer | list[ScoreReducer] | None,
+        list[Metric] | dict[str, list[Metric]] | None,
+    ],
+    None,
+]:
+    next_compute_time = time.perf_counter() + initial_interval
+
+    def compute(
+        sample_count: int,
+        sample_scores: list[dict[str, SampleScore]],
+        scorers: list[Scorer] | None,
+        reducers: ScoreReducer | list[ScoreReducer] | None,
+        metrics: list[Metric] | dict[str, list[Metric]] | None,
+    ) -> None:
+        nonlocal next_compute_time
+        time_start = time.perf_counter()
+        if time_start >= next_compute_time:
+            # compute metrics
+            results, reductions = eval_results(
+                samples=sample_count,
+                scores=sample_scores,
+                reducers=reducers,
+                scorers=scorers,
+                metrics=metrics,
+            )
+
+            # Name, reducer, value
+            task_metrics = []
+            if len(results.scores) > 0:
+                for score in results.scores:
+                    for key, metric in score.metrics.items():
+                        task_metrics.append(
+                            TaskDisplayMetric(
+                                scorer=score.name,
+                                name=metric.name,
+                                value=metric.value,
+                                reducer=score.reducer,
+                            )
+                        )
+                td.update_metrics(task_metrics)
+
+            # determine how long to wait before recomputing metrics
+            time_end = time.perf_counter()
+            elapsed_time = time_end - time_start
+            wait = max(min_interval, elapsed_time * 10)
+            next_compute_time = time_end + wait
+
+    return compute
+
+
 async def task_run_sample(
     task_name: str,
     sample: Sample,
@@ -362,11 +459,12 @@ async def task_run_sample(
     plan: Plan,
     scorers: list[Scorer] | None,
     generate: Generate,
-    progress: Callable[
+    progress: Callable[[int], None],
     logger: TaskLogger | None,
     log_images: bool,
     sample_source: EvalSampleSource | None,
     sample_error: Callable[[BaseException], EvalError],
+    sample_complete: Callable[[dict[str, SampleScore]], None],
     fails_on_error: bool,
     time_limit: int | None,
     semaphore: asyncio.Semaphore | None,
@@ -375,12 +473,12 @@ async def task_run_sample(
     if sample_source and sample.id is not None:
         previous_sample = sample_source(sample.id, state.epoch)
         if previous_sample:
-            # tick off progress
-
-
+            # tick off progress for this sample
+            progress(SAMPLE_TOTAL_PROGRESS_UNITS)
+
             # log if requested
             if logger:
-                logger.log_sample(previous_sample, flush=False)
+                await logger.log_sample(previous_sample, flush=False)

             # return score
             if previous_sample.scores:
@@ -436,6 +534,9 @@ async def task_run_sample(
             model=str(state.model),
             sample=sample,
             epoch=state.epoch,
+            message_limit=state.message_limit,
+            token_limit=state.token_limit,
+            time_limit=time_limit,
             fails_on_error=fails_on_error,
             transcript=sample_transcript,
         ) as active,
@@ -454,7 +555,6 @@ async def task_run_sample(
             )

             # set progress for plan then run it
-            plan.progress = progress
             state = await plan(state, generate)

         except TimeoutError:
@@ -562,7 +662,8 @@ async def task_run_sample(
             # handle error (this will throw if we've exceeded the limit)
             error = handle_error(ex)

-
+        # complete the sample
+        progress(SAMPLE_TOTAL_PROGRESS_UNITS)

         # log it
         if logger is not None:
@@ -576,7 +677,7 @@ async def task_run_sample(
             state = state_without_base64_images(state)

         # log the sample
-        log_sample(
+        await log_sample(
             logger=logger,
             sample=sample,
             state=state,
@@ -587,12 +688,14 @@ async def task_run_sample(

     # return
     if error is None:
+        if results is not None:
+            sample_complete(results)
         return results
     else:
         return None


-def log_sample(
+async def log_sample(
     logger: TaskLogger,
     sample: Sample,
     state: TaskState,
@@ -638,7 +741,7 @@ def log_sample(
         limit=limit,
     )

-    logger.log_sample(condense_sample(eval_sample, log_images), flush=True)
+    await logger.log_sample(condense_sample(eval_sample, log_images), flush=True)


 async def resolve_dataset(
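Note: the new update_metrics_display_fn above throttles how often interim metrics are recomputed while samples stream in, waiting at least min_interval seconds and backing off to ten times the cost of the last computation. The following is a minimal standalone sketch of that throttling pattern only; the throttled helper, the compute_fn parameter, and the running-mean example are illustrative and not part of inspect_ai.

    import time
    from typing import Callable, TypeVar

    T = TypeVar("T")


    def throttled(compute_fn: Callable[[], T], min_interval: float = 0.9) -> Callable[[], T | None]:
        """Wrap compute_fn so it runs at most once per interval.

        The wait before the next run is the larger of min_interval and
        10x the time the last computation took, so slow computations
        automatically back off (mirrors update_metrics_display_fn).
        """
        next_time = time.perf_counter()

        def maybe_compute() -> T | None:
            nonlocal next_time
            start = time.perf_counter()
            if start < next_time:
                return None  # too soon, skip this update
            result = compute_fn()
            end = time.perf_counter()
            next_time = end + max(min_interval, (end - start) * 10)
            return result

        return maybe_compute


    # hypothetical usage: recompute a running mean as results stream in
    scores: list[float] = []
    update = throttled(lambda: sum(scores) / len(scores) if scores else 0.0)

    for value in (0.2, 0.8, 0.5, 1.0):
        scores.append(value)
        mean = update()
        if mean is not None:  # only the first of these fast iterations prints
            print(f"interim mean: {mean:.2f}")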
inspect_ai/_eval/task/sandbox.py
CHANGED
@@ -15,6 +15,7 @@ from inspect_ai.util._sandbox.context import (
 )
 from inspect_ai.util._sandbox.environment import (
     SandboxEnvironment,
+    SandboxEnvironmentConfigType,
     SandboxEnvironmentSpec,
 )

@@ -129,7 +130,7 @@ def resolve_sandbox(
         and sample.sandbox.type == task_sandbox.type
         and sample.sandbox.config is not None
     ):
-        sandbox_config:
+        sandbox_config: SandboxEnvironmentConfigType | None = sample.sandbox.config
     else:
         sandbox_config = task_sandbox.config
     return SandboxEnvironmentSpec(task_sandbox.type, sandbox_config)
inspect_ai/_util/_async.py
CHANGED
@@ -1,5 +1,7 @@
 import asyncio
-from typing import Any
+from typing import Any, Coroutine, TypeVar
+
+import nest_asyncio  # type: ignore


 def is_callable_coroutine(func_or_cls: Any) -> bool:
@@ -8,3 +10,30 @@ def is_callable_coroutine(func_or_cls: Any) -> bool:
     elif callable(func_or_cls):
         return asyncio.iscoroutinefunction(func_or_cls.__call__)
     return False
+
+
+T = TypeVar("T")
+
+
+_initialised_nest_asyncio: bool = False
+
+
+def init_nest_asyncio() -> None:
+    global _initialised_nest_asyncio
+    if not _initialised_nest_asyncio:
+        nest_asyncio.apply()
+        _initialised_nest_asyncio = True
+
+
+def run_coroutine(coroutine: Coroutine[None, None, T]) -> T:
+    try:
+        # this will throw if there is no running loop
+        asyncio.get_running_loop()
+
+        # initialiase nest_asyncio then we are clear to run
+        init_nest_asyncio()
+        return asyncio.run(coroutine)
+
+    except RuntimeError:
+        # No running event loop so we are clear to run
+        return asyncio.run(coroutine)
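Note: run_coroutine lets synchronous call sites execute a coroutine whether or not an event loop is already running (for example inside Jupyter), by applying nest_asyncio only when a running loop is detected. A minimal sketch of the same pattern is below; the run_sync name and the fetch_value coroutine are illustrative, not part of the package.

    import asyncio
    from typing import Coroutine, TypeVar

    import nest_asyncio  # allows asyncio.run() inside an already running loop

    T = TypeVar("T")


    def run_sync(coroutine: Coroutine[None, None, T]) -> T:
        try:
            # raises RuntimeError if no loop is currently running
            asyncio.get_running_loop()
            # a loop is running (notebook/REPL case): patch it so
            # asyncio.run() can re-enter it, then run to completion
            nest_asyncio.apply()
            return asyncio.run(coroutine)
        except RuntimeError:
            # no running loop: plain asyncio.run() is safe
            return asyncio.run(coroutine)


    async def fetch_value() -> int:  # illustrative coroutine
        await asyncio.sleep(0.01)
        return 42


    if __name__ == "__main__":
        print(run_sync(fetch_value()))  # works from plain synchronous code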
inspect_ai/_util/file.py
CHANGED
@@ -1,3 +1,5 @@
+import asyncio
+import contextlib
 import datetime
 import io
 import os
@@ -7,11 +9,12 @@ import unicodedata
 from contextlib import contextmanager
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, BinaryIO, Iterator, Literal, cast, overload
+from typing import Any, AsyncIterator, BinaryIO, Iterator, Literal, cast, overload
 from urllib.parse import urlparse

-import fsspec  # type: ignore
-from fsspec.
+import fsspec  # type: ignore  # type: ignore
+from fsspec.asyn import AsyncFileSystem  # type: ignore
+from fsspec.core import split_protocol  # type: ignore  # type: ignore
 from fsspec.implementations.local import make_path_posix  # type: ignore
 from pydantic import BaseModel
 from s3fs import S3FileSystem  # type: ignore
@@ -277,10 +280,34 @@ def filesystem(path: str, fs_options: dict[str, Any] = {}) -> FileSystem:
     options.update(fs_options)

     # create filesystem
-    fs, path = fsspec.core.url_to_fs(path)
+    fs, path = fsspec.core.url_to_fs(path, **options)
     return FileSystem(fs)


+@contextlib.asynccontextmanager
+async def async_fileystem(
+    location: str, fs_options: dict[str, Any] = {}
+) -> AsyncIterator[AsyncFileSystem]:
+    # determine protocol
+    protocol, _ = split_protocol(location)
+    protocol = protocol or "file"
+
+    # build options
+    options = default_fs_options(location)
+    options.update(fs_options)
+
+    if protocol == "s3":
+        s3 = S3FileSystem(asynchronous=True, **options)
+        session = await s3.set_session()
+        try:
+            yield s3
+        finally:
+            await session.close()
+    else:
+        options.update({"asynchronous": True, "loop": asyncio.get_event_loop()})
+        yield fsspec.filesystem(protocol, **options)
+
+
 def absolute_file_path(file: str) -> str:
     # check for a relative dir, if we find one then resolve to absolute
     fs_scheme = urlparse(file).scheme
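Note: the new async_fileystem context manager (spelled that way in the source) yields an fsspec AsyncFileSystem, opening and closing an aiobotocore session for S3 URLs. A rough usage sketch follows, assuming fsspec's convention that coroutine methods carry a leading underscore (for example _ls); the bucket name is hypothetical, the helper is a private module that may change, and async methods are only available on async-capable backends such as S3.

    import asyncio

    # private helper added in this release; import path may change
    from inspect_ai._util.file import async_fileystem


    async def list_logs() -> None:
        # hypothetical S3 location; any fsspec URL should work here
        async with async_fileystem("s3://my-eval-logs") as fs:
            # fsspec async filesystems expose coroutine methods with a
            # leading underscore (e.g. _ls, _cat_file)
            entries = await fs._ls("my-eval-logs", detail=False)
            for entry in entries:
                print(entry)


    if __name__ == "__main__":
        asyncio.run(list_logs())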
inspect_ai/_util/html.py
ADDED
inspect_ai/_util/logger.py
CHANGED
@@ -1,5 +1,4 @@
 import os
-from contextvars import ContextVar
 from logging import (
     INFO,
     WARNING,
@@ -154,19 +153,21 @@ def notify_logger_record(record: LogRecord, write: bool) -> None:

     if write:
         transcript()._event(LoggerEvent(message=LoggingMessage.from_log_record(record)))
+    global _rate_limit_count
     if record.levelno <= INFO and "429" in record.getMessage():
-
+        _rate_limit_count = _rate_limit_count + 1


-
+_rate_limit_count = 0


 def init_http_rate_limit_count() -> None:
-
+    global _rate_limit_count
+    _rate_limit_count = 0


 def http_rate_limit_count() -> int:
-    return
+    return _rate_limit_count


 def warn_once(logger: Logger, message: str) -> None:
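Note: the logger change above swaps a ContextVar for a module-level counter that is bumped whenever a low-severity record mentions an HTTP 429. A self-contained sketch of the same counting idea using a standard logging.Filter; the logger name and message are made up.

    import logging

    _rate_limit_count = 0


    class RateLimitCounter(logging.Filter):
        """Count records that mention an HTTP 429 without suppressing them."""

        def filter(self, record: logging.LogRecord) -> bool:
            global _rate_limit_count
            if record.levelno <= logging.INFO and "429" in record.getMessage():
                _rate_limit_count += 1
            return True  # never drop the record


    logger = logging.getLogger("example.http")
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
    logger.addFilter(RateLimitCounter())

    logger.info("HTTP/1.1 429 Too Many Requests - retrying")
    print(f"rate limited {_rate_limit_count} time(s)")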
inspect_ai/_util/platform.py
CHANGED
@@ -1,6 +1,8 @@
 import importlib.util
 import os

+from inspect_ai._util._async import init_nest_asyncio
+
 from .error import set_exception_hook


@@ -21,7 +23,7 @@ def platform_init() -> None:
     # set exception hook if we haven't already
     set_exception_hook()

-    # if we are running in a notebook
+    # if we are running in a notebook...
     if running_in_notebook():
         # check for required packages
         if not have_package("ipywidgets"):
@@ -30,11 +32,8 @@ def platform_init() -> None:
                 + "pip install ipywidgets\n"
             )

-        #
-
-        import nest_asyncio  # type: ignore
-
-        nest_asyncio.apply()
+        # setup nested asyncio
+        init_nest_asyncio()


 def have_package(package: str) -> bool:
inspect_ai/_util/registry.py
CHANGED
@@ -84,7 +84,7 @@ def registry_tag(
         named_params[param] = registry_info(named_params[param]).name
     elif callable(named_params[param]) and hasattr(named_params[param], "__name__"):
         named_params[param] = getattr(named_params[param], "__name__")
-    elif isinstance(named_params[param], dict | list):
+    elif isinstance(named_params[param], dict | list | BaseModel):
         named_params[param] = to_jsonable_python(
             named_params[param], fallback=lambda x: getattr(x, "__name__", None)
         )
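Note: registry_tag now routes BaseModel values (not just dict/list) through to_jsonable_python, with a fallback to an object's __name__ for values that cannot be serialized. A small sketch of that behaviour, assuming pydantic v2's pydantic_core.to_jsonable_python (which this module appears to use); the model, values, and printed results are illustrative.

    from pydantic import BaseModel
    from pydantic_core import to_jsonable_python


    class SandboxConfig(BaseModel):  # illustrative model
        image: str
        timeout: int = 30


    def scorer_fn() -> None:  # stands in for a non-serializable callable
        ...


    # BaseModel instances are converted to plain, JSON-compatible values
    print(to_jsonable_python(SandboxConfig(image="python:3.11")))
    # expected to print roughly: {'image': 'python:3.11', 'timeout': 30}

    # anything the serializer cannot handle falls back to its __name__
    print(
        to_jsonable_python(
            {"scorer": scorer_fn},
            fallback=lambda x: getattr(x, "__name__", None),
        )
    )
    # expected to print roughly: {'scorer': 'scorer_fn'}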
inspect_ai/_view/server.py
CHANGED
@@ -19,8 +19,8 @@ from inspect_ai.log._file import (
     EvalLogInfo,
     eval_log_json,
     list_eval_logs_async,
-
-
+    read_eval_log_async,
+    read_eval_log_headers_async,
 )

 from .notify import view_last_eval_time
@@ -60,7 +60,7 @@ def view_server(

         # header_only is based on a size threshold
         header_only = request.query.get("header-only", None)
-        return log_file_response(file, header_only)
+        return await log_file_response(file, header_only)

     @routes.get("/api/log-size/{log}")
     async def api_log_size(request: web.Request) -> web.Response:
@@ -180,7 +180,7 @@ def log_listing_response(logs: list[EvalLogInfo], log_dir: str) -> web.Response:
     return web.json_response(response)


-def log_file_response(file: str, header_only_param: str | None) -> web.Response:
+async def log_file_response(file: str, header_only_param: str | None) -> web.Response:
     # resolve header_only
     header_only_mb = int(header_only_param) if header_only_param is not None else None
     header_only = resolve_header_only(file, header_only_mb)
@@ -189,8 +189,8 @@ def log_file_response(file: str, header_only_param: str | None) -> web.Response:
     contents: bytes | None = None
     if header_only:
         try:
-            log =
-            contents = eval_log_json(log)
+            log = await read_eval_log_async(file, header_only=True)
+            contents = eval_log_json(log)
         except ValueError as ex:
             logger.info(
                 f"Unable to read headers from log file {file}: {ex}. "
@@ -198,8 +198,8 @@ def log_file_response(file: str, header_only_param: str | None) -> web.Response:
             )

     if contents is None:  # normal read
-        log =
-        contents = eval_log_json(log)
+        log = await read_eval_log_async(file, header_only=False)
+        contents = eval_log_json(log)

     return web.Response(body=contents, content_type="application/json")

@@ -245,7 +245,7 @@ async def log_bytes_response(log_file: str, start: int, end: int) -> web.Respons


 async def log_headers_response(files: list[str]) -> web.Response:
-    headers =
+    headers = await read_eval_log_headers_async(files)
     return web.json_response(to_jsonable_python(headers, exclude_none=True))

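Note: the view server handlers that read logs are now coroutines end to end, so reads go through the *_async log APIs instead of blocking aiohttp's event loop. A minimal aiohttp sketch of the same shape; the route, the read_headers stand-in, the port, and the payload are hypothetical rather than the package's actual endpoints.

    from aiohttp import web


    async def read_headers(files: list[str]) -> list[dict[str, str]]:
        # stand-in for an async log reader; pretend real I/O happens here
        return [{"file": f, "status": "success"} for f in files]


    routes = web.RouteTableDef()


    @routes.get("/api/log-headers")
    async def api_log_headers(request: web.Request) -> web.Response:
        files = request.query.getall("file", [])
        headers = await read_headers(files)  # await instead of blocking
        return web.json_response(headers)


    def main() -> None:
        app = web.Application()
        app.add_routes(routes)
        web.run_app(app, port=7575)


    if __name__ == "__main__":
        main()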
inspect_ai/_view/www/App.css
CHANGED
@@ -64,8 +64,8 @@ body[class^="vscode-"] {
   --bs-secondary-bg: var(--vscode-list-inactiveSelectionBackground);
   --bs-border-color: var(--vscode-editorGroup-border);
   --bs-card-border-color: var(--vscode-editorGroup-border);
-  --bs-warning-bg-subtle: var(--vscode-
-  --bs-warning-text-emphasis: var(--vscode-
+  --bs-warning-bg-subtle: var(--vscode-inputValidation-warningBackground);
+  --bs-warning-text-emphasis: var(--vscode-input-foreground);
   --inspect-find-background: var(--vscode-editorWidget-background);
   --inspect-find-foreground: var(--vscode-editorWidget-foreground);
   --inspect-input-background: var(--vscode-input-background);
inspect_ai/_view/www/dist/assets/index.css
CHANGED
@@ -14337,8 +14337,8 @@ body[class^="vscode-"] {
   --bs-secondary-bg: var(--vscode-list-inactiveSelectionBackground);
   --bs-border-color: var(--vscode-editorGroup-border);
   --bs-card-border-color: var(--vscode-editorGroup-border);
-  --bs-warning-bg-subtle: var(--vscode-
-  --bs-warning-text-emphasis: var(--vscode-
+  --bs-warning-bg-subtle: var(--vscode-inputValidation-warningBackground);
+  --bs-warning-text-emphasis: var(--vscode-input-foreground);
   --inspect-find-background: var(--vscode-editorWidget-background);
   --inspect-find-foreground: var(--vscode-editorWidget-foreground);
   --inspect-input-background: var(--vscode-input-background);