inspect_ai-0.3.56-py3-none-any.whl → inspect_ai-0.3.57-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. inspect_ai/_display/core/panel.py +1 -1
  2. inspect_ai/_eval/run.py +16 -11
  3. inspect_ai/_util/datetime.py +1 -1
  4. inspect_ai/_util/deprecation.py +1 -1
  5. inspect_ai/_util/json.py +11 -1
  6. inspect_ai/_util/logger.py +2 -1
  7. inspect_ai/_util/trace.py +39 -3
  8. inspect_ai/_util/transcript.py +36 -7
  9. inspect_ai/_view/www/.prettierrc.js +12 -0
  10. inspect_ai/_view/www/dist/assets/index.js +286 -224
  11. inspect_ai/_view/www/log-schema.json +124 -125
  12. inspect_ai/_view/www/src/App.mjs +18 -9
  13. inspect_ai/_view/www/src/Types.mjs +0 -1
  14. inspect_ai/_view/www/src/api/Types.mjs +15 -4
  15. inspect_ai/_view/www/src/api/api-http.mjs +2 -0
  16. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +2 -2
  17. inspect_ai/_view/www/src/components/FindBand.mjs +5 -4
  18. inspect_ai/_view/www/src/components/LargeModal.mjs +1 -1
  19. inspect_ai/_view/www/src/components/MessageContent.mjs +1 -1
  20. inspect_ai/_view/www/src/components/TabSet.mjs +1 -1
  21. inspect_ai/_view/www/src/components/Tools.mjs +18 -3
  22. inspect_ai/_view/www/src/components/VirtualList.mjs +15 -17
  23. inspect_ai/_view/www/src/log/remoteLogFile.mjs +2 -1
  24. inspect_ai/_view/www/src/navbar/Navbar.mjs +44 -32
  25. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -2
  26. inspect_ai/_view/www/src/samples/SampleList.mjs +35 -4
  27. inspect_ai/_view/www/src/samples/SampleScoreView.mjs +13 -2
  28. inspect_ai/_view/www/src/samples/SampleScores.mjs +11 -2
  29. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +238 -178
  30. inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -2
  31. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +5 -5
  32. inspect_ai/_view/www/src/samples/tools/SelectScorer.mjs +7 -0
  33. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +3 -3
  34. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +1 -1
  35. inspect_ai/_view/www/src/types/log.d.ts +2 -8
  36. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
  37. inspect_ai/log/_log.py +25 -0
  38. inspect_ai/log/_recorders/eval.py +2 -0
  39. inspect_ai/model/_call_tools.py +27 -5
  40. inspect_ai/model/_providers/google.py +24 -6
  41. inspect_ai/model/_providers/openai.py +17 -3
  42. inspect_ai/model/_providers/openai_o1.py +10 -12
  43. inspect_ai/tool/_tool_info.py +2 -1
  44. inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py +9 -9
  45. inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -3
  46. inspect_ai/util/__init__.py +4 -0
  47. inspect_ai/util/_sandbox/docker/compose.py +1 -3
  48. inspect_ai/util/_sandbox/docker/util.py +2 -1
  49. inspect_ai/util/_sandbox/self_check.py +18 -18
  50. inspect_ai/util/_store.py +2 -2
  51. inspect_ai/util/_subprocess.py +3 -3
  52. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.57.dist-info}/METADATA +3 -3
  53. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.57.dist-info}/RECORD +57 -56
  54. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.57.dist-info}/WHEEL +1 -1
  55. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.57.dist-info}/LICENSE +0 -0
  56. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.57.dist-info}/entry_points.txt +0 -0
  57. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.57.dist-info}/top_level.txt +0 -0
@@ -37,7 +37,7 @@ export const ToolEventView = ({ id, event, style, depth }) => {
37
37
  functionCall=${functionCall}
38
38
  input=${input}
39
39
  inputType=${inputType}
40
- output=${event.result}
40
+ output=${event.error?.message || event.result}
41
41
  mode="compact"
42
42
  view=${event.view}
43
43
  />
@@ -396,7 +396,7 @@ export type Answer1 = string | null;
396
396
  export type Explanation2 = string | null;
397
397
  export type Metadata8 = {} | null;
398
398
  export type SampleId1 = string | number | null;
399
- export type Samples2 = SampleScore[];
399
+ export type Samples2 = EvalSampleScore[];
400
400
  export type Location1 = string;
401
401
 
402
402
  export interface EvalLog {
@@ -1034,13 +1034,7 @@ export interface EvalSampleReductions {
1034
1034
  reducer: Reducer1;
1035
1035
  samples: Samples2;
1036
1036
  }
1037
- /**
1038
- * Score for a Sample
1039
- *
1040
- * Args:
1041
- * sample_id: (str | int | None) Unique id of a sample
1042
- */
1043
- export interface SampleScore {
1037
+ export interface EvalSampleScore {
1044
1038
  value: Value2;
1045
1039
  answer: Answer1;
1046
1040
  explanation: Explanation2;
@@ -150,7 +150,7 @@ export const WorkSpace = ({
150
150
 
151
151
  // The samples tab
152
152
  // Currently only appears when the result is successful
153
- if (evalStatus !== "error" && sampleMode !== "none") {
153
+ if (sampleMode !== "none") {
154
154
  resolvedTabs.samples = {
155
155
  id: kEvalWorkspaceTabId,
156
156
  scrollable: samples.length === 1,
inspect_ai/log/_log.py CHANGED
@@ -16,6 +16,7 @@ from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH, PKG_NAME
16
16
  from inspect_ai._util.error import EvalError, exception_message
17
17
  from inspect_ai._util.logger import warn_once
18
18
  from inspect_ai.approval._policy import ApprovalPolicyConfig
19
+ from inspect_ai.dataset._dataset import MT, metadata_as
19
20
  from inspect_ai.model import (
20
21
  ChatMessage,
21
22
  GenerateConfig,
@@ -24,6 +25,8 @@ from inspect_ai.model import (
24
25
  )
25
26
  from inspect_ai.scorer import Score
26
27
  from inspect_ai.util._sandbox.environment import SandboxEnvironmentSpec
28
+ from inspect_ai.util._store import Store
29
+ from inspect_ai.util._store_model import SMT
27
30
 
28
31
  from ._transcript import Event
29
32
 
@@ -158,9 +161,31 @@ class EvalSample(BaseModel):
158
161
  metadata: dict[str, Any]
159
162
  """Additional sample metadata."""
160
163
 
164
+ def metadata_as(self, metadata_cls: Type[MT]) -> MT:
165
+ """Pydantic model interface to metadata.
166
+
167
+ Args:
168
+ metadata_cls: Pydantic model type
169
+
170
+ Returns:
171
+ BaseModel: Instance of metadata_cls bound to sample metadata.
172
+ """
173
+ return metadata_as(self.metadata, metadata_cls)
174
+
161
175
  store: dict[str, Any] = Field(default_factory=dict)
162
176
  """State at end of sample execution."""
163
177
 
178
+ def store_as(self, model_cls: Type[SMT]) -> SMT:
179
+ """Pydantic model interface to the store.
180
+
181
+ Args:
182
+ model_cls: Pydantic model type (must derive from StoreModel)
183
+
184
+ Returns:
185
+ StoreModel: Instance of model_cls bound to sample store data.
186
+ """
187
+ return model_cls(store=Store(self.store))
188
+
164
189
  events: list[Event] = Field(default_factory=list)
165
190
  """Events that occurred during sample execution."""
166
191
 
@@ -252,6 +252,8 @@ def text_inputs(inputs: str | list[ChatMessage]) -> str | list[ChatMessage]:
252
252
  filtered_content.append(ContentText(text="(Image)"))
253
253
  message.content = filtered_content
254
254
  input.append(message)
255
+ else:
256
+ input.append(message)
255
257
 
256
258
  return input
257
259
  else:
@@ -1,15 +1,20 @@
1
1
  import asyncio
2
2
  import inspect
3
+ import types
3
4
  from dataclasses import is_dataclass
4
5
  from logging import getLogger
5
6
  from textwrap import dedent
7
+ from types import UnionType
6
8
  from typing import (
7
9
  Any,
8
10
  Callable,
9
11
  Dict,
10
12
  List,
11
13
  NamedTuple,
14
+ Optional,
15
+ Tuple,
12
16
  Type,
17
+ Union,
13
18
  get_args,
14
19
  get_origin,
15
20
  get_type_hints,
@@ -25,10 +30,7 @@ from inspect_ai._util.text import truncate_string_to_bytes
25
30
  from inspect_ai._util.trace import trace_action
26
31
  from inspect_ai.model._trace import trace_tool_mesage
27
32
  from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
28
- from inspect_ai.tool._tool import (
29
- ToolApprovalError,
30
- ToolParsingError,
31
- )
33
+ from inspect_ai.tool._tool import ToolApprovalError, ToolParsingError
32
34
  from inspect_ai.tool._tool_call import ToolCallContent, ToolCallError
33
35
  from inspect_ai.tool._tool_def import ToolDef, tool_defs
34
36
  from inspect_ai.tool._tool_info import parse_docstring
@@ -268,6 +270,16 @@ def disable_parallel_tools(
268
270
  return False
269
271
 
270
272
 
273
+ def type_hint_includes_none(type_hint: Type[Any] | None) -> bool:
274
+ origin = get_origin(type_hint)
275
+
276
+ if origin in {Union, UnionType}:
277
+ return type(None) in get_args(type_hint)
278
+ elif origin is Optional:
279
+ return True
280
+ return False
281
+
282
+
271
283
  def tool_params(input: dict[str, Any], func: Callable[..., Any]) -> dict[str, Any]:
272
284
  # parse function typeinfo
273
285
  signature = inspect.signature(func)
@@ -296,7 +308,7 @@ def tool_params(input: dict[str, Any], func: Callable[..., Any]) -> dict[str, An
296
308
  # yield parameter (fail if not passed and there is no default)
297
309
  if param_name in input:
298
310
  params[param_name] = tool_param(type_hint, input.get(param_name))
299
- elif param.default is not None:
311
+ elif param.default is not None or type_hint_includes_none(type_hint):
300
312
  params[param_name] = param.default
301
313
  else:
302
314
  raise ToolParsingError(
@@ -339,11 +351,21 @@ def tool_param(type_hint: Type[Any], input: Any) -> Any:
339
351
  return [tool_param(args[0], x) for x in input]
340
352
  else:
341
353
  return input
354
+ elif origin is tuple or origin is Tuple:
355
+ if args:
356
+ return tuple([tool_param(args[0], x) for x in input])
357
+ else:
358
+ return tuple(input)
342
359
  elif origin is dict or origin is Dict:
343
360
  if args and len(args) > 1:
344
361
  return {k: tool_param(args[1], v) for k, v in input}
345
362
  else:
346
363
  return input
364
+ elif origin is Union or origin is types.UnionType:
365
+ if args[1] is type(None):
366
+ return tool_param(args[0], input)
367
+ else:
368
+ return input
347
369
  else:
348
370
  return input
349
371
 
@@ -194,7 +194,9 @@ class GoogleAPI(ModelAPI):
194
194
  model=self.model_name, content=ex.message, stop_reason="model_length"
195
195
  )
196
196
  else:
197
- raise ex
197
+ return ModelOutput.from_content(
198
+ model=self.model_name, content=ex.message, stop_reason="unknown"
199
+ )
198
200
 
199
201
  @override
200
202
  def is_rate_limit(self, ex: BaseException) -> bool:
@@ -408,25 +410,34 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
408
410
  # https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
409
411
 
410
412
 
411
- def schema_from_param(param: ToolParam | ToolParams) -> Schema:
413
+ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
412
414
  if isinstance(param, ToolParams):
413
415
  param = ToolParam(
414
416
  type=param.type, properties=param.properties, required=param.required
415
417
  )
416
418
 
417
419
  if param.type == "number":
418
- return Schema(type=Type.NUMBER, description=param.description)
420
+ return Schema(
421
+ type=Type.NUMBER, description=param.description, nullable=nullable
422
+ )
419
423
  elif param.type == "integer":
420
- return Schema(type=Type.INTEGER, description=param.description)
424
+ return Schema(
425
+ type=Type.INTEGER, description=param.description, nullable=nullable
426
+ )
421
427
  elif param.type == "boolean":
422
- return Schema(type=Type.BOOLEAN, description=param.description)
428
+ return Schema(
429
+ type=Type.BOOLEAN, description=param.description, nullable=nullable
430
+ )
423
431
  elif param.type == "string":
424
- return Schema(type=Type.STRING, description=param.description)
432
+ return Schema(
433
+ type=Type.STRING, description=param.description, nullable=nullable
434
+ )
425
435
  elif param.type == "array":
426
436
  return Schema(
427
437
  type=Type.ARRAY,
428
438
  description=param.description,
429
439
  items=schema_from_param(param.items) if param.items else None,
440
+ nullable=nullable,
430
441
  )
431
442
  elif param.type == "object":
432
443
  return Schema(
@@ -436,7 +447,14 @@ def schema_from_param(param: ToolParam | ToolParams) -> Schema:
436
447
  if param.properties is not None
437
448
  else None,
438
449
  required=param.required,
450
+ nullable=nullable,
439
451
  )
452
+ # convert unions to optional params if the second type is 'null'
453
+ elif param.anyOf:
454
+ if len(param.anyOf) == 2 and param.anyOf[1].type == "null":
455
+ return schema_from_param(param.anyOf[0], nullable=True)
456
+ else:
457
+ return Schema(type=Type.TYPE_UNSPECIFIED)
440
458
  else:
441
459
  return Schema(type=Type.TYPE_UNSPECIFIED)
442
460
 
@@ -51,6 +51,7 @@ from .._model_output import (
51
51
  Logprobs,
52
52
  ModelOutput,
53
53
  ModelUsage,
54
+ StopReason,
54
55
  )
55
56
  from .openai_o1 import generate_o1
56
57
  from .util import (
@@ -262,7 +263,10 @@ class OpenAIAPI(ModelAPI):
262
263
  model=self.model_name,
263
264
  )
264
265
  if config.max_tokens is not None:
265
- params["max_tokens"] = config.max_tokens
266
+ if self.is_o1():
267
+ params["max_completion_tokens"] = config.max_tokens
268
+ else:
269
+ params["max_tokens"] = config.max_tokens
266
270
  if config.frequency_penalty is not None:
267
271
  params["frequency_penalty"] = config.frequency_penalty
268
272
  if config.stop_seqs is not None:
@@ -303,13 +307,23 @@ class OpenAIAPI(ModelAPI):
303
307
 
304
308
  # convert some well known bad request errors into ModelOutput
305
309
  def handle_bad_request(self, e: BadRequestError) -> ModelOutput:
306
- if e.status_code == 400 and e.code == "context_length_exceeded":
310
+ if e.status_code == 400:
311
+ # extract message
307
312
  if isinstance(e.body, dict) and "message" in e.body.keys():
308
313
  content = str(e.body.get("message"))
309
314
  else:
310
315
  content = e.message
316
+
317
+ # narrow stop_reason
318
+ if e.code == "context_length_exceeded":
319
+ stop_reason: StopReason = "model_length"
320
+ elif e.code == "invalid_prompt":
321
+ stop_reason = "content_filter"
322
+ else:
323
+ stop_reason = "unknown"
324
+
311
325
  return ModelOutput.from_content(
312
- model=self.model_name, content=content, stop_reason="model_length"
326
+ model=self.model_name, content=content, stop_reason=stop_reason
313
327
  )
314
328
  else:
315
329
  raise e
@@ -25,7 +25,7 @@ from inspect_ai.model import (
25
25
  from inspect_ai.tool import ToolCall, ToolInfo
26
26
 
27
27
  from .._model_call import ModelCall
28
- from .._model_output import ModelUsage
28
+ from .._model_output import ModelUsage, StopReason
29
29
  from .._providers.util import (
30
30
  ChatAPIHandler,
31
31
  ChatAPIMessage,
@@ -48,12 +48,6 @@ async def generate_o1(
48
48
  # create chatapi handler
49
49
  handler = O1PreviewChatAPIHandler()
50
50
 
51
- # map max_tokens => max_completion_tokens
52
- max_tokens = params.get("max_tokens", None)
53
- if max_tokens:
54
- params["max_completion_tokens"] = max_tokens
55
- del params["max_tokens"]
56
-
57
51
  # call model
58
52
  request = dict(
59
53
  model=model,
@@ -89,12 +83,16 @@ async def generate_o1(
89
83
 
90
84
 
91
85
  def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput:
92
- if ex.code == "invalid_prompt":
93
- return ModelOutput.from_content(
94
- model=model, content=str(ex), stop_reason="content_filter"
95
- )
86
+ if ex.code == "context_length_exceeded":
87
+ stop_reason: StopReason = "model_length"
88
+ elif ex.code == "invalid_prompt":
89
+ stop_reason = "content_filter"
96
90
  else:
97
- raise ex
91
+ stop_reason = "unknown"
92
+
93
+ return ModelOutput.from_content(
94
+ model=model, content=str(ex), stop_reason=stop_reason
95
+ )
98
96
 
99
97
 
100
98
  def chat_messages(
@@ -8,6 +8,7 @@ from typing import (
8
8
  Dict,
9
9
  List,
10
10
  Optional,
11
+ Tuple,
11
12
  Type,
12
13
  Union,
13
14
  get_args,
@@ -155,7 +156,7 @@ def parse_type(type_hint: Type[Any]) -> ToolParam:
155
156
  return ToolParam(type="null")
156
157
  else:
157
158
  return ToolParam()
158
- elif origin is list or origin is List:
159
+ elif origin is list or origin is List or origin is tuple or origin is Tuple:
159
160
  return ToolParam(
160
161
  type="array", items=parse_type(args[0]) if args else ToolParam()
161
162
  )
@@ -38,9 +38,9 @@ class EnvironmentSpec:
38
38
  for i, obs_spec in enumerate(env_obs_spec.values()):
39
39
  self.observation_spec[i + 1] = convert(obs_spec)
40
40
 
41
- assert isinstance(
42
- env.action_spec(), specs.Array
43
- ), "Only a single action type is supported."
41
+ assert isinstance(env.action_spec(), specs.Array), (
42
+ "Only a single action type is supported."
43
+ )
44
44
  self.action_spec = {1: convert(env.action_spec())}
45
45
 
46
46
  self.observation_manager = spec_manager.SpecManager(self.observation_spec)
@@ -234,12 +234,12 @@ class EnvironmentService(dm_env_rpc_pb2_grpc.EnvironmentServicer):
234
234
  observations.
235
235
  """
236
236
  with self._lock:
237
- assert (
238
- cur_world in self._envs
239
- ), "Current world does not have an assosiated environment"
240
- assert (
241
- cur_world in self._joined_worlds
242
- ), "Please join world before calling step."
237
+ assert cur_world in self._envs, (
238
+ "Current world does not have an assosiated environment"
239
+ )
240
+ assert cur_world in self._joined_worlds, (
241
+ "Please join world before calling step."
242
+ )
243
243
  env = self._envs[cur_world]
244
244
  spec = self._specs[cur_world]
245
245
 
@@ -372,7 +372,9 @@ async def web_browser_cmd(cmd: str, *args: str) -> str:
372
372
  )
373
373
  else:
374
374
  response = parse_web_browser_output(result.stdout)
375
- if "web_at" in response:
375
+ if "error" in response and response.get("error", "").strip() != "":
376
+ raise ToolError(str(response.get("error")) or "(unknown error)")
377
+ elif "web_at" in response:
376
378
  web_at = (
377
379
  str(response.get("web_at")) or "(no web accessiblity tree available)"
378
380
  )
@@ -384,8 +386,6 @@ async def web_browser_cmd(cmd: str, *args: str) -> str:
384
386
  web_at = "\n".join(web_at_lines)
385
387
  store_as(WebBrowserStore).web_at = web_at
386
388
  return web_at
387
- elif "error" in response:
388
- raise ToolError(str(response.get("error")) or "(unknown error)")
389
389
  else:
390
390
  raise RuntimeError(
391
391
  f"web_browser output must contain either 'error' or 'web_at' field: {result.stdout}"
@@ -1,3 +1,5 @@
1
+ from inspect_ai._util.trace import trace_action, trace_message
2
+
1
3
  from ._concurrency import concurrency
2
4
  from ._console import input_screen
3
5
  from ._display import DisplayType, display_type
@@ -56,4 +58,6 @@ __all__ = [
56
58
  "throttle",
57
59
  "trace_enabled",
58
60
  "trace_panel",
61
+ "trace_action",
62
+ "trace_message",
59
63
  ]
@@ -33,9 +33,7 @@ async def compose_up(project: ComposeProject) -> None:
33
33
  timeout=300,
34
34
  )
35
35
  if not result.success:
36
- msg = (
37
- f"Failed to start docker services for {project.config}: " f"{result.stderr}"
38
- )
36
+ msg = f"Failed to start docker services for {project.config}: {result.stderr}"
39
37
  raise RuntimeError(msg)
40
38
 
41
39
 
@@ -84,7 +84,8 @@ def task_project_name(task: str) -> str:
84
84
  if len(task) == 0:
85
85
  task = "task"
86
86
 
87
- return f"inspect-{task[:12]}-i{uuid().lower()[:6]}"
87
+ # _- breaks docker project name constraints so we strip trailing underscores.
88
+ return f"inspect-{task[:12].rstrip('_')}-i{uuid().lower()[:6]}"
88
89
 
89
90
 
90
91
  inspect_project_pattern = r"^inspect-[a-z\d\-_]*-i[a-z\d]{6,}$"
@@ -75,9 +75,9 @@ async def test_read_and_write_file_text(sandbox_env: SandboxEnvironment) -> None
75
75
  written_file_string = await sandbox_env.read_file(
76
76
  "test_read_and_write_file_text.file", text=True
77
77
  )
78
- assert (
79
- "great #content\nincluding newlines" == written_file_string
80
- ), f"unexpected content: [{written_file_string}]"
78
+ assert "great #content\nincluding newlines" == written_file_string, (
79
+ f"unexpected content: [{written_file_string}]"
80
+ )
81
81
  await _cleanup_file(sandbox_env, "test_read_and_write_file_text.file")
82
82
 
83
83
 
@@ -219,9 +219,9 @@ async def test_exec_output(sandbox_env: SandboxEnvironment) -> None:
219
219
  exec_result = await sandbox_env.exec(["sh", "-c", "echo foo; echo bar"])
220
220
  expected = "foo\nbar\n"
221
221
  # in the assertion message, we show the actual bytes to help debug newline issues
222
- assert (
223
- exec_result.stdout == expected
224
- ), f"Unexpected output:expected {expected.encode('UTF-8')!r}; got {exec_result.stdout.encode('UTF-8')!r}"
222
+ assert exec_result.stdout == expected, (
223
+ f"Unexpected output:expected {expected.encode('UTF-8')!r}; got {exec_result.stdout.encode('UTF-8')!r}"
224
+ )
225
225
 
226
226
 
227
227
  async def test_exec_timeout(sandbox_env: SandboxEnvironment) -> None:
@@ -248,13 +248,13 @@ async def test_exec_as_user(sandbox_env: SandboxEnvironment) -> None:
248
248
 
249
249
  # Test exec as different users
250
250
  root_result = await sandbox_env.exec(["whoami"], user="root")
251
- assert (
252
- root_result.stdout.strip() == "root"
253
- ), f"Expected 'root', got '{root_result.stdout.strip()}'"
251
+ assert root_result.stdout.strip() == "root", (
252
+ f"Expected 'root', got '{root_result.stdout.strip()}'"
253
+ )
254
254
  myuser_result = await sandbox_env.exec(["whoami"], user=username)
255
- assert (
256
- myuser_result.stdout.strip() == username
257
- ), f"Expected '{username}', got '{myuser_result.stdout.strip()}'"
255
+ assert myuser_result.stdout.strip() == username, (
256
+ f"Expected '{username}', got '{myuser_result.stdout.strip()}'"
257
+ )
258
258
  finally:
259
259
  # Clean up
260
260
  await sandbox_env.exec(["userdel", "-r", username], user="root")
@@ -266,9 +266,9 @@ async def test_exec_as_nonexistent_user(sandbox_env: SandboxEnvironment) -> None
266
266
  expected_error = (
267
267
  "unable to find user nonexistent: no matching entries in passwd file"
268
268
  )
269
- assert (
270
- expected_error in result.stdout
271
- ), f"Error string '{expected_error}' not found in error output: '{result.stdout}'"
269
+ assert expected_error in result.stdout, (
270
+ f"Error string '{expected_error}' not found in error output: '{result.stdout}'"
271
+ )
272
272
 
273
273
 
274
274
  async def test_cwd_unspecified(sandbox_env: SandboxEnvironment) -> None:
@@ -291,9 +291,9 @@ async def test_cwd_relative(sandbox_env: SandboxEnvironment) -> None:
291
291
  file_path = cwd_subdirectory + "/" + file_name
292
292
  await sandbox_env.write_file(file_path, "ls me plz")
293
293
  current_dir_contents = (await sandbox_env.exec(["ls"], cwd=cwd_subdirectory)).stdout
294
- assert (
295
- file_name in current_dir_contents
296
- ), f"{file_name} not found in {current_dir_contents}"
294
+ assert file_name in current_dir_contents, (
295
+ f"{file_name} not found in {current_dir_contents}"
296
+ )
297
297
  await _cleanup_file(sandbox_env, file_path)
298
298
 
299
299
 
inspect_ai/util/_store.py CHANGED
@@ -34,8 +34,8 @@ class Store:
34
34
  inheriting from Pydantic `BaseModel`)
35
35
  """
36
36
 
37
- def __init__(self) -> None:
38
- self._data: dict[str, Any] = {}
37
+ def __init__(self, data: dict[str, Any] | None = None) -> None:
38
+ self._data = deepcopy(data) if data else {}
39
39
 
40
40
  @overload
41
41
  def get(self, key: str, default: None = None) -> Any: ...
@@ -101,9 +101,9 @@ async def subprocess(
101
101
  input = input.encode() if isinstance(input, str) else input
102
102
 
103
103
  # function to run command (we may or may not run it w/ concurrency)
104
- async def run_command() -> (
105
- AsyncGenerator[Union[Process, ExecResult[str], ExecResult[bytes]], None]
106
- ):
104
+ async def run_command() -> AsyncGenerator[
105
+ Union[Process, ExecResult[str], ExecResult[bytes]], None
106
+ ]:
107
107
  if isinstance(args, str):
108
108
  proc = await asyncio.create_subprocess_shell(
109
109
  args,
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: inspect_ai
3
- Version: 0.3.56
3
+ Version: 0.3.57
4
4
  Summary: Framework for large language model evaluations
5
5
  Author: UK AI Safety Institute
6
6
  License: MIT License
@@ -67,7 +67,7 @@ Requires-Dist: pytest-asyncio; extra == "dev"
67
67
  Requires-Dist: pytest-cov; extra == "dev"
68
68
  Requires-Dist: pytest-dotenv; extra == "dev"
69
69
  Requires-Dist: pytest-xdist; extra == "dev"
70
- Requires-Dist: ruff==0.8.4; extra == "dev"
70
+ Requires-Dist: ruff==0.9.0; extra == "dev"
71
71
  Requires-Dist: textual-dev>=0.86.2; extra == "dev"
72
72
  Requires-Dist: types-PyYAML; extra == "dev"
73
73
  Requires-Dist: types-beautifulsoup4; extra == "dev"