inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +14 -3
- inspect_ai/_cli/sandbox.py +3 -3
- inspect_ai/_cli/score.py +6 -4
- inspect_ai/_cli/trace.py +53 -6
- inspect_ai/_display/core/config.py +1 -1
- inspect_ai/_display/core/display.py +2 -1
- inspect_ai/_display/core/footer.py +6 -6
- inspect_ai/_display/plain/display.py +11 -6
- inspect_ai/_display/rich/display.py +23 -13
- inspect_ai/_display/textual/app.py +10 -9
- inspect_ai/_display/textual/display.py +2 -2
- inspect_ai/_display/textual/widgets/footer.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +14 -5
- inspect_ai/_eval/context.py +1 -2
- inspect_ai/_eval/eval.py +54 -41
- inspect_ai/_eval/loader.py +9 -2
- inspect_ai/_eval/run.py +148 -81
- inspect_ai/_eval/score.py +13 -8
- inspect_ai/_eval/task/images.py +31 -21
- inspect_ai/_eval/task/run.py +62 -59
- inspect_ai/_eval/task/rundir.py +16 -9
- inspect_ai/_eval/task/sandbox.py +7 -8
- inspect_ai/_eval/task/util.py +7 -0
- inspect_ai/_util/_async.py +118 -10
- inspect_ai/_util/constants.py +0 -2
- inspect_ai/_util/file.py +15 -29
- inspect_ai/_util/future.py +37 -0
- inspect_ai/_util/http.py +3 -99
- inspect_ai/_util/httpx.py +60 -0
- inspect_ai/_util/interrupt.py +2 -2
- inspect_ai/_util/json.py +5 -52
- inspect_ai/_util/logger.py +30 -86
- inspect_ai/_util/retry.py +10 -61
- inspect_ai/_util/trace.py +2 -2
- inspect_ai/_view/server.py +86 -3
- inspect_ai/_view/www/dist/assets/index.js +25837 -13269
- inspect_ai/_view/www/log-schema.json +253 -186
- inspect_ai/_view/www/package.json +2 -2
- inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
- inspect_ai/_view/www/src/types/log.d.ts +122 -94
- inspect_ai/approval/_human/manager.py +6 -10
- inspect_ai/approval/_human/panel.py +2 -2
- inspect_ai/dataset/_sources/util.py +7 -6
- inspect_ai/log/__init__.py +4 -0
- inspect_ai/log/_file.py +35 -61
- inspect_ai/log/_log.py +18 -1
- inspect_ai/log/_recorders/eval.py +14 -23
- inspect_ai/log/_recorders/json.py +3 -18
- inspect_ai/log/_samples.py +27 -2
- inspect_ai/log/_transcript.py +8 -8
- inspect_ai/model/__init__.py +2 -1
- inspect_ai/model/_call_tools.py +60 -40
- inspect_ai/model/_chat_message.py +3 -2
- inspect_ai/model/_generate_config.py +25 -0
- inspect_ai/model/_model.py +74 -36
- inspect_ai/model/_openai.py +9 -1
- inspect_ai/model/_providers/anthropic.py +24 -26
- inspect_ai/model/_providers/azureai.py +11 -9
- inspect_ai/model/_providers/bedrock.py +33 -24
- inspect_ai/model/_providers/cloudflare.py +8 -9
- inspect_ai/model/_providers/goodfire.py +7 -3
- inspect_ai/model/_providers/google.py +47 -13
- inspect_ai/model/_providers/groq.py +15 -15
- inspect_ai/model/_providers/hf.py +24 -17
- inspect_ai/model/_providers/mistral.py +36 -20
- inspect_ai/model/_providers/openai.py +30 -25
- inspect_ai/model/_providers/openai_o1.py +1 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +3 -4
- inspect_ai/model/_providers/util/__init__.py +2 -2
- inspect_ai/model/_providers/util/chatapi.py +6 -19
- inspect_ai/model/_providers/util/hooks.py +165 -0
- inspect_ai/model/_providers/vertex.py +20 -3
- inspect_ai/model/_providers/vllm.py +16 -19
- inspect_ai/scorer/_multi.py +5 -2
- inspect_ai/solver/_bridge/patch.py +31 -1
- inspect_ai/solver/_fork.py +5 -3
- inspect_ai/solver/_human_agent/agent.py +3 -2
- inspect_ai/tool/__init__.py +8 -2
- inspect_ai/tool/_tool_info.py +4 -90
- inspect_ai/tool/_tool_params.py +4 -34
- inspect_ai/tool/_tools/_web_search.py +30 -24
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_concurrency.py +5 -6
- inspect_ai/util/_display.py +6 -0
- inspect_ai/util/_json.py +170 -0
- inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
- inspect_ai/util/_sandbox/docker/docker.py +5 -0
- inspect_ai/util/_sandbox/environment.py +56 -9
- inspect_ai/util/_sandbox/service.py +12 -5
- inspect_ai/util/_subprocess.py +94 -113
- inspect_ai/util/_subtask.py +2 -4
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
- inspect_ai/_util/timeouts.py +0 -160
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
- inspect_ai/model/_providers/util/tracker.py +0 -92
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/_eval/task/run.py
CHANGED
@@ -1,5 +1,5 @@
-import asyncio
 import contextlib
+import functools
 import sys
 import time
 from copy import deepcopy
@@ -9,6 +9,7 @@ from logging import getLogger
 from pathlib import PurePath
 from typing import Callable, Literal

+import anyio
 from typing_extensions import Unpack

 from inspect_ai._display import (
@@ -19,6 +20,7 @@ from inspect_ai._display import (
     display,
 )
 from inspect_ai._display.core.display import TaskDisplay, TaskDisplayMetric
+from inspect_ai._util._async import tg_collect
 from inspect_ai._util.constants import (
     DEFAULT_EPOCHS,
     DEFAULT_MAX_CONNECTIONS,
@@ -32,7 +34,6 @@ from inspect_ai._util.registry import (
     registry_log_name,
     registry_unqualified_name,
 )
-from inspect_ai._util.timeouts import Timeout, timeout
 from inspect_ai._util.working import (
     init_sample_working_limit,
     sample_waiting_time,
@@ -95,9 +96,9 @@ from .images import (
 )
 from .log import TaskLogger, collect_eval_data, log_start
 from .results import eval_results
-from .rundir import
+from .rundir import set_task_chdir
 from .sandbox import sandboxenv_context
-from .util import sample_messages, slice_dataset
+from .util import sample_messages, slice_dataset

 py_logger = getLogger(__name__)

@@ -147,8 +148,8 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
     # init task context
     init_task_context(model, options.task.approval, generate_config)

-    # establish
-    with
+    # establish chdir for duration of execution (if a task has chdir=True)
+    with set_task_chdir(task):
         # track stats and error
         results: EvalResults | None = None
         reductions: list[EvalSampleReductions] | None = None
@@ -286,35 +287,6 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
             task.metrics,
         )

-        # create sample coroutines
-        sample_coroutines = [
-            task_run_sample(
-                task_name=task.name,
-                sample=sample,
-                state=state,
-                sandbox=sandbox,
-                max_sandboxes=config.max_sandboxes,
-                sandbox_cleanup=sandbox_cleanup,
-                plan=plan,
-                scorers=scorers,
-                generate=generate,
-                progress=progress,
-                logger=logger if log_samples else None,
-                log_images=log_images,
-                sample_source=sample_source,
-                sample_error=sample_error_handler,
-                sample_complete=sample_complete,
-                fails_on_error=(
-                    config.fail_on_error is None
-                    or config.fail_on_error is True
-                ),
-                time_limit=config.time_limit,
-                working_limit=config.working_limit,
-                semaphore=sample_semaphore,
-            )
-            for (sample, state) in zip(samples, states)
-        ]
-
         # initial progress
         td.sample_complete(complete=0, total=len(samples))

@@ -327,7 +299,36 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
                 task.metrics,
             )

-        sample_results = await
+        sample_results = await tg_collect(
+            [
+                functools.partial(
+                    task_run_sample,
+                    task_name=task.name,
+                    sample=sample,
+                    state=state,
+                    sandbox=sandbox,
+                    max_sandboxes=config.max_sandboxes,
+                    sandbox_cleanup=sandbox_cleanup,
+                    plan=plan,
+                    scorers=scorers,
+                    generate=generate,
+                    progress=progress,
+                    logger=logger if log_samples else None,
+                    log_images=log_images,
+                    sample_source=sample_source,
+                    sample_error=sample_error_handler,
+                    sample_complete=sample_complete,
+                    fails_on_error=(
+                        config.fail_on_error is None
+                        or config.fail_on_error is True
+                    ),
+                    time_limit=config.time_limit,
+                    working_limit=config.working_limit,
+                    semaphore=sample_semaphore,
+                )
+                for (sample, state) in zip(samples, states)
+            ]
+        )

         # compute and record metrics if we have scores
         completed_scores = [
@@ -362,17 +363,18 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
                 )
             )

-        except
-
-
+        except anyio.get_cancelled_exc_class():
+            with anyio.CancelScope(shield=True):
+                # collect eval data
+                collect_eval_data(stats)

-
-
-
-
+                # finish w/ cancelled status
+                eval_log = await logger.log_finish(
+                    "cancelled", stats, results, reductions
+                )

-
-
+                # display task cancelled
+                td.complete(TaskCancelled(logger.samples_completed, stats))

         except BaseException as ex:
             if options.debug_errors:
@@ -503,7 +505,7 @@ async def task_run_sample(
     fails_on_error: bool,
     time_limit: int | None,
     working_limit: int | None,
-    semaphore:
+    semaphore: anyio.Semaphore | None,
 ) -> dict[str, SampleScore] | None:
     # if there is an existing sample then tick off its progress, log it, and return it
     if sample_source and sample.id is not None:
@@ -533,7 +535,7 @@ async def task_run_sample(
         return sample_scores

     # use semaphore if provided
-    semaphore_cm:
+    semaphore_cm: anyio.Semaphore | contextlib.AbstractAsyncContextManager[None] = (
         semaphore if semaphore else contextlib.nullcontext()
     )

@@ -606,7 +608,7 @@ async def task_run_sample(

     # initialise timeout context manager
     timeout_cm = (
-
+        anyio.fail_after(time_limit)
         if time_limit is not None
         else contextlib.nullcontext()
     )
@@ -616,7 +618,7 @@ async def task_run_sample(
     init_sample_working_limit(start_time, working_limit)

     # run sample w/ optional timeout
-
+    with timeout_cm:
         # mark started
         active.started = datetime.now().timestamp()

@@ -640,9 +642,9 @@ async def task_run_sample(
         # capture most recent state for scoring
         state = sample_state() or state

-    except
+    except anyio.get_cancelled_exc_class() as ex:
         if active.interrupt_action:
-            # record
+            # record event
             transcript()._event(
                 SampleLimitEvent(
                     type="operator",
@@ -660,6 +662,8 @@ async def task_run_sample(
             error, raise_error = handle_error(ex)

         else:
+            # task group provided by tg_collect will automatically
+            # handle the cancel exception
             raise

     except SampleLimitExceededError as ex:
@@ -687,9 +691,8 @@ async def task_run_sample(
         # the cause of the timeout is a hung container and scoring requires
         # interacting with the container). as a middle ground we use half
         # of the original timeout value for scoring.
-        if
-
-            timeout_cm = timeout(time_limit / 2)
+        if time_limit is not None:
+            timeout_cm = anyio.fail_after(time_limit / 2)

         # turn off message and token limits
         state.message_limit = None
@@ -699,7 +702,7 @@ async def task_run_sample(
         # scoring
         try:
             # timeout during scoring will result in an ordinary sample error
-
+            with timeout_cm:
                 if error is None:
                     for scorer in scorers or []:
                         scorer_name = unique_scorer_name(
@@ -740,7 +743,7 @@ async def task_run_sample(
                 # propagate results into scores
                 state.scores = {k: v.score for k, v in results.items()}

-        except
+        except anyio.get_cancelled_exc_class():
             if active.interrupt_action:
                 transcript()._event(
                     SampleLimitEvent(
@@ -970,10 +973,10 @@ def create_sample_semaphore(
     config: EvalConfig,
     generate_config: GenerateConfig,
     modelapi: ModelAPI | None = None,
-) ->
+) -> anyio.Semaphore:
     # if the user set max_samples then use that
     if config.max_samples is not None:
-        return
+        return anyio.Semaphore(config.max_samples)

     # use max_connections
     max_samples = (
@@ -985,4 +988,4 @@ def create_sample_semaphore(
     )

     # return the semaphore
-    return
+    return anyio.Semaphore(max_samples)
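The change above swaps the pre-built list of sample coroutines for an anyio task group driven by tg_collect, passing functools.partial thunks so each coroutine is created only when the task group starts it. A minimal sketch of that pattern, assuming inspect_ai 0.3.73 is installed (process is a hypothetical stand-in for task_run_sample):

import functools

import anyio

from inspect_ai._util._async import tg_collect  # helper added in 0.3.73


async def process(index: int, delay: float) -> int:
    # hypothetical stand-in for task_run_sample
    await anyio.sleep(delay)
    return index


async def main() -> None:
    # functools.partial defers coroutine creation to the task group inside tg_collect
    results = await tg_collect(
        [functools.partial(process, i, delay=0.01) for i in range(5)]
    )
    print(results)  # order matches the input list: [0, 1, 2, 3, 4]


anyio.run(main)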
inspect_ai/_eval/task/rundir.py
CHANGED
@@ -6,9 +6,12 @@ from contextvars import ContextVar
 from functools import wraps
 from typing import Any, Callable, Iterator, TypeVar

+from inspect_ai._eval.task.task import Task
+from inspect_ai._eval.task.util import task_chdir
+
 TASK_DIRECTORY_ATTRIB = "task_directory"

-
+_task_chdir = ContextVar[str | None]("_task_chdir", default=None)

 T = TypeVar("T", bound="asyncio.BaseEventLoop")

@@ -46,12 +49,16 @@ def task_run_dir_switching() -> Iterator[None]:


 @contextmanager
-def
-
-
+def set_task_chdir(task: Task) -> Iterator[None]:
+    chdir = task_chdir(task)
+    if chdir is not None:
+        token = _task_chdir.set(chdir)
+        try:
+            yield
+        finally:
+            _task_chdir.reset(token)
+    else:
         yield
-    finally:
-        _task_run_dir.reset(token)


 if sys.platform == "win32":
@@ -63,9 +70,9 @@ else:
     def _wrap_callback(callback: Callable[..., Any]) -> Callable[..., Any]:
         @wraps(callback)
         def wrapper(*args: Any, **kwargs: Any) -> Any:
-
-            if
-                os.chdir(
+            chdir = _task_chdir.get(None)
+            if chdir is not None and chdir != os.getcwd():
+                os.chdir(chdir)
             return callback(*args, **kwargs)

         return wrapper
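The rewritten rundir.py records the task's directory in a ContextVar and lets the wrapped event-loop callbacks chdir lazily before each callback runs. A standalone sketch of that pattern (demo names only, not the library code):

import os
import tempfile
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator

_demo_chdir: ContextVar[str | None] = ContextVar("_demo_chdir", default=None)


@contextmanager
def demo_set_chdir(directory: str | None) -> Iterator[None]:
    # mirror set_task_chdir: record the directory for the duration of the block
    if directory is not None:
        token = _demo_chdir.set(directory)
        try:
            yield
        finally:
            _demo_chdir.reset(token)
    else:
        yield


with demo_set_chdir(tempfile.gettempdir()):
    wanted = _demo_chdir.get()
    if wanted is not None and wanted != os.getcwd():
        os.chdir(wanted)  # the same lazy check _wrap_callback performs per callback
    print(os.getcwd())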
inspect_ai/_eval/task/sandbox.py
CHANGED
@@ -1,9 +1,9 @@
-import asyncio
 import base64
 import contextlib
 from random import random
 from typing import AsyncGenerator, Callable, NamedTuple, cast

+import anyio
 import httpx
 from tenacity import (
     retry,
@@ -15,10 +15,9 @@ from tenacity import (

 from inspect_ai._eval.task.task import Task
 from inspect_ai._eval.task.util import task_run_dir
-from inspect_ai._util.constants import DEFAULT_MAX_RETRIES, DEFAULT_TIMEOUT
 from inspect_ai._util.file import file, filesystem
+from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
 from inspect_ai._util.registry import registry_unqualified_name
-from inspect_ai._util.retry import httpx_should_retry, log_retry_attempt
 from inspect_ai._util.url import data_uri_to_base64, is_data_uri, is_http_url
 from inspect_ai.dataset import Sample
 from inspect_ai.util._concurrency import concurrency
@@ -62,7 +61,7 @@ async def sandboxenv_context(
     # in and grab all of the sandboxes). Therefore, in this case we wait a random
     # delay so that all tasks/samples have an equal shot at getting scheduled.
     if max_sandboxes is not None:
-        await
+        await anyio.sleep(random())

     # enforce concurrency if required
     sandboxes_cm = (
@@ -103,7 +102,7 @@
         # run sample
         yield

-    except
+    except anyio.get_cancelled_exc_class() as ex:
         interrupted = True
         raise ex

@@ -186,14 +185,14 @@ async def _retrying_httpx_get(
     url: str,
     client: httpx.AsyncClient = httpx.AsyncClient(),
     timeout: int = 30,  # per-attempt timeout
-    max_retries: int =
-    total_timeout: int =
+    max_retries: int = 10,
+    total_timeout: int = 120,  # timeout for the whole retry loop. not for an individual attempt
 ) -> bytes:
     @retry(
         wait=wait_exponential_jitter(),
         stop=(stop_after_attempt(max_retries) | stop_after_delay(total_timeout)),
         retry=retry_if_exception(httpx_should_retry),
-        before_sleep=
+        before_sleep=log_httpx_retry_attempt(url),
     )
     async def do_get() -> bytes:
         response = await client.get(
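For reference, the tenacity wiring in _retrying_httpx_get looks roughly like the sketch below; should_retry is a simplified, hypothetical stand-in for httpx_should_retry, and the before_sleep logging hook is omitted:

import httpx
from tenacity import (
    retry,
    retry_if_exception,
    stop_after_attempt,
    stop_after_delay,
    wait_exponential_jitter,
)


def should_retry(ex: BaseException) -> bool:
    # hypothetical predicate standing in for httpx_should_retry
    return isinstance(ex, (httpx.TimeoutException, httpx.ConnectError))


async def fetch_bytes(url: str, client: httpx.AsyncClient) -> bytes:
    @retry(
        wait=wait_exponential_jitter(),
        stop=(stop_after_attempt(10) | stop_after_delay(120)),
        retry=retry_if_exception(should_retry),
    )
    async def do_get() -> bytes:
        response = await client.get(url, timeout=30)
        response.raise_for_status()
        return response.content

    return await do_get()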
inspect_ai/_eval/task/util.py
CHANGED
@@ -25,6 +25,13 @@ def task_run_dir(task: Task) -> str:
     return getattr(task, TASK_RUN_DIR_ATTR, os.getcwd())


+def task_chdir(task: Task) -> str | None:
+    if task.attribs.get("chdir", False) is True:
+        return task_run_dir(task)
+    else:
+        return None
+
+
 def task_file(task: Task, relative: bool = False) -> str | None:
     file = cast(str | None, getattr(task, TASK_FILE_ATTR, None))
     if file:
inspect_ai/_util/_async.py
CHANGED
@@ -1,20 +1,100 @@
 import asyncio
-
+import inspect
+import os
+import sys
+from logging import Logger
+from typing import Any, Awaitable, Callable, Coroutine, Literal, TypeVar, cast

+import anyio
 import nest_asyncio  # type: ignore
+import sniffio
+
+if sys.version_info >= (3, 11):
+    from typing import TypeVarTuple, Unpack
+else:
+    from exceptiongroup import ExceptionGroup
+    from typing_extensions import TypeVarTuple, Unpack
+
+
+PosArgsT = TypeVarTuple("PosArgsT")


 def is_callable_coroutine(func_or_cls: Any) -> bool:
-    if
+    if inspect.iscoroutinefunction(func_or_cls):
         return True
     elif callable(func_or_cls):
-        return
+        return inspect.iscoroutinefunction(func_or_cls.__call__)
     return False


 T = TypeVar("T")


+async def tg_collect(
+    funcs: list[Callable[[], Awaitable[T]]], exception_group: bool = False
+) -> list[T]:
+    """Runs all of the pased async functions and collects their results.
+
+    The results will be returned in the same order as the input `funcs`.
+
+    Args:
+        funcs: List of async functions.
+        exception_group: `True` to raise an ExceptionGroup or
+            `False` (the default) to raise only the first exception.
+
+    Returns:
+        List of results of type T.
+
+    Raises:
+        Exception: First exception occurring in failed tasks
+            (for `exception_group == False`, the default)
+        ExceptionGroup: Exceptions that occurred in failed tasks
+            (for `exception_group == True`)
+    """
+    try:
+        results: list[tuple[int, T]] = []
+
+        async with anyio.create_task_group() as tg:
+
+            async def run_task(index: int) -> None:
+                result = await funcs[index]()
+                results.append((index, result))
+
+            for i in range(0, len(funcs)):
+                tg.start_soon(run_task, i)
+
+        # sort results by original index and return just the values
+        return [r for _, r in sorted(results)]
+    except ExceptionGroup as ex:
+        if exception_group:
+            raise
+        else:
+            raise ex.exceptions[0]
+
+
+async def coro_print_exceptions(
+    context: str,
+    func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
+    *args: Unpack[PosArgsT],
+) -> None:
+    try:
+        await func(*args)
+    except Exception as ex:
+        print(f"Error {context}: {ex}")
+
+
+async def coro_log_exceptions(
+    logger: Logger,
+    context: str,
+    func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
+    *args: Unpack[PosArgsT],
+) -> None:
+    try:
+        await func(*args)
+    except Exception as ex:
+        logger.warning(f"Error {context}: {ex}")
+
+
 _initialised_nest_asyncio: bool = False

@@ -26,14 +106,42 @@ def init_nest_asyncio() -> None:


 def run_coroutine(coroutine: Coroutine[None, None, T]) -> T:
-
-
-
+    from inspect_ai._util.platform import running_in_notebook
+
+    if current_async_backend() == "trio":
+        raise RuntimeError("run_coroutine cannot be used with trio")

-
+    if running_in_notebook():
         init_nest_asyncio()
         return asyncio.run(coroutine)
+    else:
+        try:
+            # this will throw if there is no running loop
+            asyncio.get_running_loop()

-
-
-
+            # initialiase nest_asyncio then we are clear to run
+            init_nest_asyncio()
+            return asyncio.run(coroutine)
+
+        except RuntimeError:
+            # No running event loop so we are clear to run
+            return asyncio.run(coroutine)
+
+
+def current_async_backend() -> Literal["asyncio", "trio"] | None:
+    try:
+        return _validate_backend(sniffio.current_async_library().lower())
+    except sniffio.AsyncLibraryNotFoundError:
+        return None
+
+
+def configured_async_backend() -> Literal["asyncio", "trio"]:
+    backend = os.environ.get("INSPECT_ASYNC_BACKEND", "asyncio").lower()
+    return _validate_backend(backend)
+
+
+def _validate_backend(backend: str) -> Literal["asyncio", "trio"]:
+    if backend in ["asyncio", "trio"]:
+        return cast(Literal["asyncio", "trio"], backend)
+    else:
+        raise RuntimeError(f"Unknown async backend: {backend}")
inspect_ai/_util/constants.py
CHANGED
@@ -6,8 +6,6 @@ PKG_AUTHOR_DIR = "UK-AISI"
 PKG_NAME = Path(__file__).parent.parent.stem
 PKG_PATH = Path(__file__).parent.parent
 DEFAULT_EPOCHS = 1
-DEFAULT_MAX_RETRIES = 5
-DEFAULT_TIMEOUT = 120
 DEFAULT_MAX_CONNECTIONS = 10
 DEFAULT_MAX_TOKENS = 2048
 DEFAULT_VIEW_PORT = 7575
inspect_ai/_util/file.py
CHANGED
@@ -1,5 +1,3 @@
-import asyncio
-import contextlib
 import datetime
 import io
 import os
@@ -9,17 +7,19 @@ import unicodedata
 from contextlib import contextmanager
 from copy import deepcopy
 from pathlib import Path
-from typing import Any,
+from typing import Any, BinaryIO, Iterator, Literal, cast, overload
 from urllib.parse import urlparse

 import fsspec  # type: ignore # type: ignore
-from fsspec.asyn import AsyncFileSystem  # type: ignore
 from fsspec.core import split_protocol  # type: ignore # type: ignore
 from fsspec.implementations.local import make_path_posix  # type: ignore
 from pydantic import BaseModel
 from s3fs import S3FileSystem  # type: ignore
 from shortuuid import uuid

+from inspect_ai._util._async import configured_async_backend, current_async_backend
+from inspect_ai._util.error import PrerequisiteError
+
 # https://filesystem-spec.readthedocs.io/en/latest/_modules/fsspec/spec.html#AbstractFileSystem
 # https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.generic.GenericFileSystem

@@ -298,30 +298,6 @@ def filesystem(path: str, fs_options: dict[str, Any] = {}) -> FileSystem:
     return FileSystem(fs)


-@contextlib.asynccontextmanager
-async def async_fileystem(
-    location: str, fs_options: dict[str, Any] = {}
-) -> AsyncIterator[AsyncFileSystem]:
-    # determine protocol
-    protocol, _ = split_protocol(location)
-    protocol = protocol or "file"
-
-    # build options
-    options = default_fs_options(location)
-    options.update(fs_options)
-
-    if protocol == "s3":
-        s3 = S3FileSystem(asynchronous=True, **options)
-        session = await s3.set_session()
-        try:
-            yield s3
-        finally:
-            await session.close()
-    else:
-        options.update({"asynchronous": True, "loop": asyncio.get_event_loop()})
-        yield fsspec.filesystem(protocol, **options)
-
-
 def absolute_file_path(file: str) -> str:
     # check for a relative dir, if we find one then resolve to absolute
     fs_scheme = urlparse(file).scheme
@@ -331,7 +307,17 @@ def absolute_file_path(file: str) -> str:


 def default_fs_options(file: str) -> dict[str, Any]:
-
+    scheme = urlparse(file).scheme
+    if (
+        scheme == "s3"
+        and configured_async_backend() == "trio"
+        and current_async_backend() == "trio"
+    ):
+        raise PrerequisiteError(
+            "ERROR: The s3 interface is not supported when running under the trio async backend."
+        )
+
+    options = deepcopy(DEFAULT_FS_OPTIONS.get(scheme, {}))
     # disable caching for all filesystems
     options.update(
         dict(
inspect_ai/_util/future.py
ADDED
@@ -0,0 +1,37 @@
+from typing import Generic, TypeVar
+
+import anyio
+
+T = TypeVar("T")
+
+
+class Future(Generic[T]):
+    def __init__(self) -> None:
+        self._result: T | None = None
+        self._ex: Exception | None = None
+        self._event = anyio.Event()
+
+    def set_result(self, result: T) -> None:
+        self._result = result
+        self._event.set()
+
+    def set_exception(self, ex: Exception) -> None:
+        self._ex = ex
+        self._event.set()
+
+    async def result(self) -> T:
+        await self._event.wait()
+        if self._result is not None:
+            return self._result
+        elif self._ex is not None:
+            raise self._ex
+        else:
+            raise RuntimeError("Future completed without a result or error")
+
+    @staticmethod
+    def set_future_result(future: "Future[T]", result: T) -> None:
+        future.set_result(result)
+
+    @staticmethod
+    def set_future_exception(future: "Future[T]", error: Exception) -> None:
+        future.set_exception(error)