inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. inspect_ai/_cli/eval.py +14 -3
  2. inspect_ai/_cli/sandbox.py +3 -3
  3. inspect_ai/_cli/score.py +6 -4
  4. inspect_ai/_cli/trace.py +53 -6
  5. inspect_ai/_display/core/config.py +1 -1
  6. inspect_ai/_display/core/display.py +2 -1
  7. inspect_ai/_display/core/footer.py +6 -6
  8. inspect_ai/_display/plain/display.py +11 -6
  9. inspect_ai/_display/rich/display.py +23 -13
  10. inspect_ai/_display/textual/app.py +10 -9
  11. inspect_ai/_display/textual/display.py +2 -2
  12. inspect_ai/_display/textual/widgets/footer.py +4 -0
  13. inspect_ai/_display/textual/widgets/samples.py +14 -5
  14. inspect_ai/_eval/context.py +1 -2
  15. inspect_ai/_eval/eval.py +54 -41
  16. inspect_ai/_eval/loader.py +9 -2
  17. inspect_ai/_eval/run.py +148 -81
  18. inspect_ai/_eval/score.py +13 -8
  19. inspect_ai/_eval/task/images.py +31 -21
  20. inspect_ai/_eval/task/run.py +62 -59
  21. inspect_ai/_eval/task/rundir.py +16 -9
  22. inspect_ai/_eval/task/sandbox.py +7 -8
  23. inspect_ai/_eval/task/util.py +7 -0
  24. inspect_ai/_util/_async.py +118 -10
  25. inspect_ai/_util/constants.py +0 -2
  26. inspect_ai/_util/file.py +15 -29
  27. inspect_ai/_util/future.py +37 -0
  28. inspect_ai/_util/http.py +3 -99
  29. inspect_ai/_util/httpx.py +60 -0
  30. inspect_ai/_util/interrupt.py +2 -2
  31. inspect_ai/_util/json.py +5 -52
  32. inspect_ai/_util/logger.py +30 -86
  33. inspect_ai/_util/retry.py +10 -61
  34. inspect_ai/_util/trace.py +2 -2
  35. inspect_ai/_view/server.py +86 -3
  36. inspect_ai/_view/www/dist/assets/index.js +25837 -13269
  37. inspect_ai/_view/www/log-schema.json +253 -186
  38. inspect_ai/_view/www/package.json +2 -2
  39. inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
  40. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
  41. inspect_ai/_view/www/src/types/log.d.ts +122 -94
  42. inspect_ai/approval/_human/manager.py +6 -10
  43. inspect_ai/approval/_human/panel.py +2 -2
  44. inspect_ai/dataset/_sources/util.py +7 -6
  45. inspect_ai/log/__init__.py +4 -0
  46. inspect_ai/log/_file.py +35 -61
  47. inspect_ai/log/_log.py +18 -1
  48. inspect_ai/log/_recorders/eval.py +14 -23
  49. inspect_ai/log/_recorders/json.py +3 -18
  50. inspect_ai/log/_samples.py +27 -2
  51. inspect_ai/log/_transcript.py +8 -8
  52. inspect_ai/model/__init__.py +2 -1
  53. inspect_ai/model/_call_tools.py +60 -40
  54. inspect_ai/model/_chat_message.py +3 -2
  55. inspect_ai/model/_generate_config.py +25 -0
  56. inspect_ai/model/_model.py +74 -36
  57. inspect_ai/model/_openai.py +9 -1
  58. inspect_ai/model/_providers/anthropic.py +24 -26
  59. inspect_ai/model/_providers/azureai.py +11 -9
  60. inspect_ai/model/_providers/bedrock.py +33 -24
  61. inspect_ai/model/_providers/cloudflare.py +8 -9
  62. inspect_ai/model/_providers/goodfire.py +7 -3
  63. inspect_ai/model/_providers/google.py +47 -13
  64. inspect_ai/model/_providers/groq.py +15 -15
  65. inspect_ai/model/_providers/hf.py +24 -17
  66. inspect_ai/model/_providers/mistral.py +36 -20
  67. inspect_ai/model/_providers/openai.py +30 -25
  68. inspect_ai/model/_providers/openai_o1.py +1 -1
  69. inspect_ai/model/_providers/providers.py +1 -1
  70. inspect_ai/model/_providers/together.py +3 -4
  71. inspect_ai/model/_providers/util/__init__.py +2 -2
  72. inspect_ai/model/_providers/util/chatapi.py +6 -19
  73. inspect_ai/model/_providers/util/hooks.py +165 -0
  74. inspect_ai/model/_providers/vertex.py +20 -3
  75. inspect_ai/model/_providers/vllm.py +16 -19
  76. inspect_ai/scorer/_multi.py +5 -2
  77. inspect_ai/solver/_bridge/patch.py +31 -1
  78. inspect_ai/solver/_fork.py +5 -3
  79. inspect_ai/solver/_human_agent/agent.py +3 -2
  80. inspect_ai/tool/__init__.py +8 -2
  81. inspect_ai/tool/_tool_info.py +4 -90
  82. inspect_ai/tool/_tool_params.py +4 -34
  83. inspect_ai/tool/_tools/_web_search.py +30 -24
  84. inspect_ai/util/__init__.py +4 -0
  85. inspect_ai/util/_concurrency.py +5 -6
  86. inspect_ai/util/_display.py +6 -0
  87. inspect_ai/util/_json.py +170 -0
  88. inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
  89. inspect_ai/util/_sandbox/docker/docker.py +5 -0
  90. inspect_ai/util/_sandbox/environment.py +56 -9
  91. inspect_ai/util/_sandbox/service.py +12 -5
  92. inspect_ai/util/_subprocess.py +94 -113
  93. inspect_ai/util/_subtask.py +2 -4
  94. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
  95. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
  96. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
  97. inspect_ai/_util/timeouts.py +0 -160
  98. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  99. inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
  100. inspect_ai/model/_providers/util/tracker.py +0 -92
  101. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
  102. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
  103. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
- import asyncio
2
1
  import contextlib
2
+ import functools
3
3
  import sys
4
4
  import time
5
5
  from copy import deepcopy
@@ -9,6 +9,7 @@ from logging import getLogger
9
9
  from pathlib import PurePath
10
10
  from typing import Callable, Literal
11
11
 
12
+ import anyio
12
13
  from typing_extensions import Unpack
13
14
 
14
15
  from inspect_ai._display import (
@@ -19,6 +20,7 @@ from inspect_ai._display import (
19
20
  display,
20
21
  )
21
22
  from inspect_ai._display.core.display import TaskDisplay, TaskDisplayMetric
23
+ from inspect_ai._util._async import tg_collect
22
24
  from inspect_ai._util.constants import (
23
25
  DEFAULT_EPOCHS,
24
26
  DEFAULT_MAX_CONNECTIONS,
@@ -32,7 +34,6 @@ from inspect_ai._util.registry import (
32
34
  registry_log_name,
33
35
  registry_unqualified_name,
34
36
  )
35
- from inspect_ai._util.timeouts import Timeout, timeout
36
37
  from inspect_ai._util.working import (
37
38
  init_sample_working_limit,
38
39
  sample_waiting_time,
@@ -95,9 +96,9 @@ from .images import (
95
96
  )
96
97
  from .log import TaskLogger, collect_eval_data, log_start
97
98
  from .results import eval_results
98
- from .rundir import set_task_run_dir
99
+ from .rundir import set_task_chdir
99
100
  from .sandbox import sandboxenv_context
100
- from .util import sample_messages, slice_dataset, task_run_dir
101
+ from .util import sample_messages, slice_dataset
101
102
 
102
103
  py_logger = getLogger(__name__)
103
104
 
@@ -147,8 +148,8 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
147
148
  # init task context
148
149
  init_task_context(model, options.task.approval, generate_config)
149
150
 
150
- # establish run_dir for duration of execution
151
- with set_task_run_dir(task_run_dir(task)):
151
+ # establish chdir for duration of execution (if a task has chdir=True)
152
+ with set_task_chdir(task):
152
153
  # track stats and error
153
154
  results: EvalResults | None = None
154
155
  reductions: list[EvalSampleReductions] | None = None
@@ -286,35 +287,6 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
286
287
  task.metrics,
287
288
  )
288
289
 
289
- # create sample coroutines
290
- sample_coroutines = [
291
- task_run_sample(
292
- task_name=task.name,
293
- sample=sample,
294
- state=state,
295
- sandbox=sandbox,
296
- max_sandboxes=config.max_sandboxes,
297
- sandbox_cleanup=sandbox_cleanup,
298
- plan=plan,
299
- scorers=scorers,
300
- generate=generate,
301
- progress=progress,
302
- logger=logger if log_samples else None,
303
- log_images=log_images,
304
- sample_source=sample_source,
305
- sample_error=sample_error_handler,
306
- sample_complete=sample_complete,
307
- fails_on_error=(
308
- config.fail_on_error is None
309
- or config.fail_on_error is True
310
- ),
311
- time_limit=config.time_limit,
312
- working_limit=config.working_limit,
313
- semaphore=sample_semaphore,
314
- )
315
- for (sample, state) in zip(samples, states)
316
- ]
317
-
318
290
  # initial progress
319
291
  td.sample_complete(complete=0, total=len(samples))
320
292
 
@@ -327,7 +299,36 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
327
299
  task.metrics,
328
300
  )
329
301
 
330
- sample_results = await asyncio.gather(*sample_coroutines)
302
+ sample_results = await tg_collect(
303
+ [
304
+ functools.partial(
305
+ task_run_sample,
306
+ task_name=task.name,
307
+ sample=sample,
308
+ state=state,
309
+ sandbox=sandbox,
310
+ max_sandboxes=config.max_sandboxes,
311
+ sandbox_cleanup=sandbox_cleanup,
312
+ plan=plan,
313
+ scorers=scorers,
314
+ generate=generate,
315
+ progress=progress,
316
+ logger=logger if log_samples else None,
317
+ log_images=log_images,
318
+ sample_source=sample_source,
319
+ sample_error=sample_error_handler,
320
+ sample_complete=sample_complete,
321
+ fails_on_error=(
322
+ config.fail_on_error is None
323
+ or config.fail_on_error is True
324
+ ),
325
+ time_limit=config.time_limit,
326
+ working_limit=config.working_limit,
327
+ semaphore=sample_semaphore,
328
+ )
329
+ for (sample, state) in zip(samples, states)
330
+ ]
331
+ )
331
332
 
332
333
  # compute and record metrics if we have scores
333
334
  completed_scores = [
@@ -362,17 +363,18 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
362
363
  )
363
364
  )
364
365
 
365
- except asyncio.CancelledError:
366
- # collect eval data
367
- collect_eval_data(stats)
366
+ except anyio.get_cancelled_exc_class():
367
+ with anyio.CancelScope(shield=True):
368
+ # collect eval data
369
+ collect_eval_data(stats)
368
370
 
369
- # finish w/ cancelled status
370
- eval_log = await logger.log_finish(
371
- "cancelled", stats, results, reductions
372
- )
371
+ # finish w/ cancelled status
372
+ eval_log = await logger.log_finish(
373
+ "cancelled", stats, results, reductions
374
+ )
373
375
 
374
- # display task cancelled
375
- td.complete(TaskCancelled(logger.samples_completed, stats))
376
+ # display task cancelled
377
+ td.complete(TaskCancelled(logger.samples_completed, stats))
376
378
 
377
379
  except BaseException as ex:
378
380
  if options.debug_errors:
@@ -503,7 +505,7 @@ async def task_run_sample(
503
505
  fails_on_error: bool,
504
506
  time_limit: int | None,
505
507
  working_limit: int | None,
506
- semaphore: asyncio.Semaphore | None,
508
+ semaphore: anyio.Semaphore | None,
507
509
  ) -> dict[str, SampleScore] | None:
508
510
  # if there is an existing sample then tick off its progress, log it, and return it
509
511
  if sample_source and sample.id is not None:
@@ -533,7 +535,7 @@ async def task_run_sample(
533
535
  return sample_scores
534
536
 
535
537
  # use semaphore if provided
536
- semaphore_cm: asyncio.Semaphore | contextlib.AbstractAsyncContextManager[None] = (
538
+ semaphore_cm: anyio.Semaphore | contextlib.AbstractAsyncContextManager[None] = (
537
539
  semaphore if semaphore else contextlib.nullcontext()
538
540
  )
539
541
 
@@ -606,7 +608,7 @@ async def task_run_sample(
606
608
 
607
609
  # initialise timeout context manager
608
610
  timeout_cm = (
609
- timeout(time_limit)
611
+ anyio.fail_after(time_limit)
610
612
  if time_limit is not None
611
613
  else contextlib.nullcontext()
612
614
  )
@@ -616,7 +618,7 @@ async def task_run_sample(
616
618
  init_sample_working_limit(start_time, working_limit)
617
619
 
618
620
  # run sample w/ optional timeout
619
- async with timeout_cm:
621
+ with timeout_cm:
620
622
  # mark started
621
623
  active.started = datetime.now().timestamp()
622
624
 
@@ -640,9 +642,9 @@ async def task_run_sample(
640
642
  # capture most recent state for scoring
641
643
  state = sample_state() or state
642
644
 
643
- except asyncio.CancelledError as ex:
645
+ except anyio.get_cancelled_exc_class() as ex:
644
646
  if active.interrupt_action:
645
- # record eve t
647
+ # record event
646
648
  transcript()._event(
647
649
  SampleLimitEvent(
648
650
  type="operator",
@@ -660,6 +662,8 @@ async def task_run_sample(
660
662
  error, raise_error = handle_error(ex)
661
663
 
662
664
  else:
665
+ # task group provided by tg_collect will automatically
666
+ # handle the cancel exception
663
667
  raise
664
668
 
665
669
  except SampleLimitExceededError as ex:
@@ -687,9 +691,8 @@ async def task_run_sample(
687
691
  # the cause of the timeout is a hung container and scoring requires
688
692
  # interacting with the container). as a middle ground we use half
689
693
  # of the original timeout value for scoring.
690
- if isinstance(timeout_cm, Timeout):
691
- assert time_limit
692
- timeout_cm = timeout(time_limit / 2)
694
+ if time_limit is not None:
695
+ timeout_cm = anyio.fail_after(time_limit / 2)
693
696
 
694
697
  # turn off message and token limits
695
698
  state.message_limit = None
@@ -699,7 +702,7 @@ async def task_run_sample(
699
702
  # scoring
700
703
  try:
701
704
  # timeout during scoring will result in an ordinary sample error
702
- async with timeout_cm:
705
+ with timeout_cm:
703
706
  if error is None:
704
707
  for scorer in scorers or []:
705
708
  scorer_name = unique_scorer_name(
@@ -740,7 +743,7 @@ async def task_run_sample(
740
743
  # propagate results into scores
741
744
  state.scores = {k: v.score for k, v in results.items()}
742
745
 
743
- except asyncio.CancelledError:
746
+ except anyio.get_cancelled_exc_class():
744
747
  if active.interrupt_action:
745
748
  transcript()._event(
746
749
  SampleLimitEvent(
@@ -970,10 +973,10 @@ def create_sample_semaphore(
970
973
  config: EvalConfig,
971
974
  generate_config: GenerateConfig,
972
975
  modelapi: ModelAPI | None = None,
973
- ) -> asyncio.Semaphore:
976
+ ) -> anyio.Semaphore:
974
977
  # if the user set max_samples then use that
975
978
  if config.max_samples is not None:
976
- return asyncio.Semaphore(config.max_samples)
979
+ return anyio.Semaphore(config.max_samples)
977
980
 
978
981
  # use max_connections
979
982
  max_samples = (
@@ -985,4 +988,4 @@ def create_sample_semaphore(
985
988
  )
986
989
 
987
990
  # return the semaphore
988
- return asyncio.Semaphore(max_samples)
991
+ return anyio.Semaphore(max_samples)
@@ -6,9 +6,12 @@ from contextvars import ContextVar
6
6
  from functools import wraps
7
7
  from typing import Any, Callable, Iterator, TypeVar
8
8
 
9
+ from inspect_ai._eval.task.task import Task
10
+ from inspect_ai._eval.task.util import task_chdir
11
+
9
12
  TASK_DIRECTORY_ATTRIB = "task_directory"
10
13
 
11
- _task_run_dir = ContextVar[str]("task_run_dir")
14
+ _task_chdir = ContextVar[str | None]("_task_chdir", default=None)
12
15
 
13
16
  T = TypeVar("T", bound="asyncio.BaseEventLoop")
14
17
 
@@ -46,12 +49,16 @@ def task_run_dir_switching() -> Iterator[None]:
46
49
 
47
50
 
48
51
  @contextmanager
49
- def set_task_run_dir(run_dir: str) -> Iterator[None]:
50
- token = _task_run_dir.set(run_dir)
51
- try:
52
+ def set_task_chdir(task: Task) -> Iterator[None]:
53
+ chdir = task_chdir(task)
54
+ if chdir is not None:
55
+ token = _task_chdir.set(chdir)
56
+ try:
57
+ yield
58
+ finally:
59
+ _task_chdir.reset(token)
60
+ else:
52
61
  yield
53
- finally:
54
- _task_run_dir.reset(token)
55
62
 
56
63
 
57
64
  if sys.platform == "win32":
@@ -63,9 +70,9 @@ else:
63
70
  def _wrap_callback(callback: Callable[..., Any]) -> Callable[..., Any]:
64
71
  @wraps(callback)
65
72
  def wrapper(*args: Any, **kwargs: Any) -> Any:
66
- run_dir = _task_run_dir.get(None)
67
- if run_dir is not None and run_dir != os.getcwd():
68
- os.chdir(run_dir)
73
+ chdir = _task_chdir.get(None)
74
+ if chdir is not None and chdir != os.getcwd():
75
+ os.chdir(chdir)
69
76
  return callback(*args, **kwargs)
70
77
 
71
78
  return wrapper
@@ -1,9 +1,9 @@
1
- import asyncio
2
1
  import base64
3
2
  import contextlib
4
3
  from random import random
5
4
  from typing import AsyncGenerator, Callable, NamedTuple, cast
6
5
 
6
+ import anyio
7
7
  import httpx
8
8
  from tenacity import (
9
9
  retry,
@@ -15,10 +15,9 @@ from tenacity import (
15
15
 
16
16
  from inspect_ai._eval.task.task import Task
17
17
  from inspect_ai._eval.task.util import task_run_dir
18
- from inspect_ai._util.constants import DEFAULT_MAX_RETRIES, DEFAULT_TIMEOUT
19
18
  from inspect_ai._util.file import file, filesystem
19
+ from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
20
20
  from inspect_ai._util.registry import registry_unqualified_name
21
- from inspect_ai._util.retry import httpx_should_retry, log_retry_attempt
22
21
  from inspect_ai._util.url import data_uri_to_base64, is_data_uri, is_http_url
23
22
  from inspect_ai.dataset import Sample
24
23
  from inspect_ai.util._concurrency import concurrency
@@ -62,7 +61,7 @@ async def sandboxenv_context(
62
61
  # in and grab all of the sandboxes). Therefore, in this case we wait a random
63
62
  # delay so that all tasks/samples have an equal shot at getting scheduled.
64
63
  if max_sandboxes is not None:
65
- await asyncio.sleep(random())
64
+ await anyio.sleep(random())
66
65
 
67
66
  # enforce concurrency if required
68
67
  sandboxes_cm = (
@@ -103,7 +102,7 @@ async def sandboxenv_context(
103
102
  # run sample
104
103
  yield
105
104
 
106
- except asyncio.CancelledError as ex:
105
+ except anyio.get_cancelled_exc_class() as ex:
107
106
  interrupted = True
108
107
  raise ex
109
108
 
@@ -186,14 +185,14 @@ async def _retrying_httpx_get(
186
185
  url: str,
187
186
  client: httpx.AsyncClient = httpx.AsyncClient(),
188
187
  timeout: int = 30, # per-attempt timeout
189
- max_retries: int = DEFAULT_MAX_RETRIES,
190
- total_timeout: int = DEFAULT_TIMEOUT, # timeout for the whole retry loop. not for an individual attempt
188
+ max_retries: int = 10,
189
+ total_timeout: int = 120, # timeout for the whole retry loop. not for an individual attempt
191
190
  ) -> bytes:
192
191
  @retry(
193
192
  wait=wait_exponential_jitter(),
194
193
  stop=(stop_after_attempt(max_retries) | stop_after_delay(total_timeout)),
195
194
  retry=retry_if_exception(httpx_should_retry),
196
- before_sleep=log_retry_attempt(url),
195
+ before_sleep=log_httpx_retry_attempt(url),
197
196
  )
198
197
  async def do_get() -> bytes:
199
198
  response = await client.get(
@@ -25,6 +25,13 @@ def task_run_dir(task: Task) -> str:
25
25
  return getattr(task, TASK_RUN_DIR_ATTR, os.getcwd())
26
26
 
27
27
 
28
+ def task_chdir(task: Task) -> str | None:
29
+ if task.attribs.get("chdir", False) is True:
30
+ return task_run_dir(task)
31
+ else:
32
+ return None
33
+
34
+
28
35
  def task_file(task: Task, relative: bool = False) -> str | None:
29
36
  file = cast(str | None, getattr(task, TASK_FILE_ATTR, None))
30
37
  if file:
@@ -1,20 +1,100 @@
1
1
  import asyncio
2
- from typing import Any, Coroutine, TypeVar
2
+ import inspect
3
+ import os
4
+ import sys
5
+ from logging import Logger
6
+ from typing import Any, Awaitable, Callable, Coroutine, Literal, TypeVar, cast
3
7
 
8
+ import anyio
4
9
  import nest_asyncio # type: ignore
10
+ import sniffio
11
+
12
+ if sys.version_info >= (3, 11):
13
+ from typing import TypeVarTuple, Unpack
14
+ else:
15
+ from exceptiongroup import ExceptionGroup
16
+ from typing_extensions import TypeVarTuple, Unpack
17
+
18
+
19
+ PosArgsT = TypeVarTuple("PosArgsT")
5
20
 
6
21
 
7
22
  def is_callable_coroutine(func_or_cls: Any) -> bool:
8
- if asyncio.iscoroutinefunction(func_or_cls):
23
+ if inspect.iscoroutinefunction(func_or_cls):
9
24
  return True
10
25
  elif callable(func_or_cls):
11
- return asyncio.iscoroutinefunction(func_or_cls.__call__)
26
+ return inspect.iscoroutinefunction(func_or_cls.__call__)
12
27
  return False
13
28
 
14
29
 
15
30
  T = TypeVar("T")
16
31
 
17
32
 
33
+ async def tg_collect(
34
+ funcs: list[Callable[[], Awaitable[T]]], exception_group: bool = False
35
+ ) -> list[T]:
36
+ """Runs all of the pased async functions and collects their results.
37
+
38
+ The results will be returned in the same order as the input `funcs`.
39
+
40
+ Args:
41
+ funcs: List of async functions.
42
+ exception_group: `True` to raise an ExceptionGroup or
43
+ `False` (the default) to raise only the first exception.
44
+
45
+ Returns:
46
+ List of results of type T.
47
+
48
+ Raises:
49
+ Exception: First exception occurring in failed tasks
50
+ (for `exception_group == False`, the default)
51
+ ExceptionGroup: Exceptions that occurred in failed tasks
52
+ (for `exception_group == True`)
53
+ """
54
+ try:
55
+ results: list[tuple[int, T]] = []
56
+
57
+ async with anyio.create_task_group() as tg:
58
+
59
+ async def run_task(index: int) -> None:
60
+ result = await funcs[index]()
61
+ results.append((index, result))
62
+
63
+ for i in range(0, len(funcs)):
64
+ tg.start_soon(run_task, i)
65
+
66
+ # sort results by original index and return just the values
67
+ return [r for _, r in sorted(results)]
68
+ except ExceptionGroup as ex:
69
+ if exception_group:
70
+ raise
71
+ else:
72
+ raise ex.exceptions[0]
73
+
74
+
75
+ async def coro_print_exceptions(
76
+ context: str,
77
+ func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
78
+ *args: Unpack[PosArgsT],
79
+ ) -> None:
80
+ try:
81
+ await func(*args)
82
+ except Exception as ex:
83
+ print(f"Error {context}: {ex}")
84
+
85
+
86
+ async def coro_log_exceptions(
87
+ logger: Logger,
88
+ context: str,
89
+ func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
90
+ *args: Unpack[PosArgsT],
91
+ ) -> None:
92
+ try:
93
+ await func(*args)
94
+ except Exception as ex:
95
+ logger.warning(f"Error {context}: {ex}")
96
+
97
+
18
98
  _initialised_nest_asyncio: bool = False
19
99
 
20
100
 
@@ -26,14 +106,42 @@ def init_nest_asyncio() -> None:
26
106
 
27
107
 
28
108
  def run_coroutine(coroutine: Coroutine[None, None, T]) -> T:
29
- try:
30
- # this will throw if there is no running loop
31
- asyncio.get_running_loop()
109
+ from inspect_ai._util.platform import running_in_notebook
110
+
111
+ if current_async_backend() == "trio":
112
+ raise RuntimeError("run_coroutine cannot be used with trio")
32
113
 
33
- # initialiase nest_asyncio then we are clear to run
114
+ if running_in_notebook():
34
115
  init_nest_asyncio()
35
116
  return asyncio.run(coroutine)
117
+ else:
118
+ try:
119
+ # this will throw if there is no running loop
120
+ asyncio.get_running_loop()
36
121
 
37
- except RuntimeError:
38
- # No running event loop so we are clear to run
39
- return asyncio.run(coroutine)
122
+ # initialiase nest_asyncio then we are clear to run
123
+ init_nest_asyncio()
124
+ return asyncio.run(coroutine)
125
+
126
+ except RuntimeError:
127
+ # No running event loop so we are clear to run
128
+ return asyncio.run(coroutine)
129
+
130
+
131
+ def current_async_backend() -> Literal["asyncio", "trio"] | None:
132
+ try:
133
+ return _validate_backend(sniffio.current_async_library().lower())
134
+ except sniffio.AsyncLibraryNotFoundError:
135
+ return None
136
+
137
+
138
+ def configured_async_backend() -> Literal["asyncio", "trio"]:
139
+ backend = os.environ.get("INSPECT_ASYNC_BACKEND", "asyncio").lower()
140
+ return _validate_backend(backend)
141
+
142
+
143
+ def _validate_backend(backend: str) -> Literal["asyncio", "trio"]:
144
+ if backend in ["asyncio", "trio"]:
145
+ return cast(Literal["asyncio", "trio"], backend)
146
+ else:
147
+ raise RuntimeError(f"Unknown async backend: {backend}")
@@ -6,8 +6,6 @@ PKG_AUTHOR_DIR = "UK-AISI"
6
6
  PKG_NAME = Path(__file__).parent.parent.stem
7
7
  PKG_PATH = Path(__file__).parent.parent
8
8
  DEFAULT_EPOCHS = 1
9
- DEFAULT_MAX_RETRIES = 5
10
- DEFAULT_TIMEOUT = 120
11
9
  DEFAULT_MAX_CONNECTIONS = 10
12
10
  DEFAULT_MAX_TOKENS = 2048
13
11
  DEFAULT_VIEW_PORT = 7575
inspect_ai/_util/file.py CHANGED
@@ -1,5 +1,3 @@
1
- import asyncio
2
- import contextlib
3
1
  import datetime
4
2
  import io
5
3
  import os
@@ -9,17 +7,19 @@ import unicodedata
9
7
  from contextlib import contextmanager
10
8
  from copy import deepcopy
11
9
  from pathlib import Path
12
- from typing import Any, AsyncIterator, BinaryIO, Iterator, Literal, cast, overload
10
+ from typing import Any, BinaryIO, Iterator, Literal, cast, overload
13
11
  from urllib.parse import urlparse
14
12
 
15
13
  import fsspec # type: ignore # type: ignore
16
- from fsspec.asyn import AsyncFileSystem # type: ignore
17
14
  from fsspec.core import split_protocol # type: ignore # type: ignore
18
15
  from fsspec.implementations.local import make_path_posix # type: ignore
19
16
  from pydantic import BaseModel
20
17
  from s3fs import S3FileSystem # type: ignore
21
18
  from shortuuid import uuid
22
19
 
20
+ from inspect_ai._util._async import configured_async_backend, current_async_backend
21
+ from inspect_ai._util.error import PrerequisiteError
22
+
23
23
  # https://filesystem-spec.readthedocs.io/en/latest/_modules/fsspec/spec.html#AbstractFileSystem
24
24
  # https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.generic.GenericFileSystem
25
25
 
@@ -298,30 +298,6 @@ def filesystem(path: str, fs_options: dict[str, Any] = {}) -> FileSystem:
298
298
  return FileSystem(fs)
299
299
 
300
300
 
301
- @contextlib.asynccontextmanager
302
- async def async_fileystem(
303
- location: str, fs_options: dict[str, Any] = {}
304
- ) -> AsyncIterator[AsyncFileSystem]:
305
- # determine protocol
306
- protocol, _ = split_protocol(location)
307
- protocol = protocol or "file"
308
-
309
- # build options
310
- options = default_fs_options(location)
311
- options.update(fs_options)
312
-
313
- if protocol == "s3":
314
- s3 = S3FileSystem(asynchronous=True, **options)
315
- session = await s3.set_session()
316
- try:
317
- yield s3
318
- finally:
319
- await session.close()
320
- else:
321
- options.update({"asynchronous": True, "loop": asyncio.get_event_loop()})
322
- yield fsspec.filesystem(protocol, **options)
323
-
324
-
325
301
  def absolute_file_path(file: str) -> str:
326
302
  # check for a relative dir, if we find one then resolve to absolute
327
303
  fs_scheme = urlparse(file).scheme
@@ -331,7 +307,17 @@ def absolute_file_path(file: str) -> str:
331
307
 
332
308
 
333
309
  def default_fs_options(file: str) -> dict[str, Any]:
334
- options = deepcopy(DEFAULT_FS_OPTIONS.get(urlparse(file).scheme, {}))
310
+ scheme = urlparse(file).scheme
311
+ if (
312
+ scheme == "s3"
313
+ and configured_async_backend() == "trio"
314
+ and current_async_backend() == "trio"
315
+ ):
316
+ raise PrerequisiteError(
317
+ "ERROR: The s3 interface is not supported when running under the trio async backend."
318
+ )
319
+
320
+ options = deepcopy(DEFAULT_FS_OPTIONS.get(scheme, {}))
335
321
  # disable caching for all filesystems
336
322
  options.update(
337
323
  dict(
@@ -0,0 +1,37 @@
1
+ from typing import Generic, TypeVar
2
+
3
+ import anyio
4
+
5
+ T = TypeVar("T")
6
+
7
+
8
+ class Future(Generic[T]):
9
+ def __init__(self) -> None:
10
+ self._result: T | None = None
11
+ self._ex: Exception | None = None
12
+ self._event = anyio.Event()
13
+
14
+ def set_result(self, result: T) -> None:
15
+ self._result = result
16
+ self._event.set()
17
+
18
+ def set_exception(self, ex: Exception) -> None:
19
+ self._ex = ex
20
+ self._event.set()
21
+
22
+ async def result(self) -> T:
23
+ await self._event.wait()
24
+ if self._result is not None:
25
+ return self._result
26
+ elif self._ex is not None:
27
+ raise self._ex
28
+ else:
29
+ raise RuntimeError("Future completed without a result or error")
30
+
31
+ @staticmethod
32
+ def set_future_result(future: "Future[T]", result: T) -> None:
33
+ future.set_result(result)
34
+
35
+ @staticmethod
36
+ def set_future_exception(future: "Future[T]", error: Exception) -> None:
37
+ future.set_exception(error)