inspect-ai 0.3.65__py3-none-any.whl → 0.3.67__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. inspect_ai/_display/core/config.py +4 -0
  2. inspect_ai/_display/textual/app.py +13 -5
  3. inspect_ai/_display/textual/widgets/footer.py +2 -2
  4. inspect_ai/_display/textual/widgets/sandbox.py +1 -1
  5. inspect_ai/_display/textual/widgets/task_detail.py +7 -5
  6. inspect_ai/_display/textual/widgets/tasks.py +8 -6
  7. inspect_ai/_display/textual/widgets/transcript.py +1 -1
  8. inspect_ai/_eval/task/run.py +5 -3
  9. inspect_ai/_eval/task/task.py +9 -1
  10. inspect_ai/_util/format.py +58 -0
  11. inspect_ai/_view/www/dist/assets/index.css +29 -9
  12. inspect_ai/_view/www/dist/assets/index.js +368 -304
  13. inspect_ai/_view/www/src/samples/error/FlatSampleErrorView.tsx +1 -1
  14. inspect_ai/_view/www/src/samples/sample-tools/filters.ts +41 -20
  15. inspect_ai/_view/www/src/samples/sample-tools/sample-filter/SampleFilter.tsx +2 -1
  16. inspect_ai/_view/www/src/samples/sample-tools/sample-filter/completions.ts +28 -6
  17. inspect_ai/_view/www/src/samples/sample-tools/sample-filter/language.ts +5 -0
  18. inspect_ai/_view/www/src/samples/transcript/LoggerEventView.tsx +1 -3
  19. inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.tsx +31 -16
  20. inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +4 -1
  21. inspect_ai/_view/www/src/workspace/navbar/StatusPanel.module.css +1 -0
  22. inspect_ai/_view/www/src/workspace/navbar/StatusPanel.tsx +2 -2
  23. inspect_ai/model/_model.py +89 -2
  24. inspect_ai/model/_providers/anthropic.py +4 -0
  25. inspect_ai/model/_providers/azureai.py +5 -0
  26. inspect_ai/model/_providers/bedrock.py +5 -0
  27. inspect_ai/model/_providers/cloudflare.py +4 -0
  28. inspect_ai/model/_providers/goodfire.py +5 -0
  29. inspect_ai/model/_providers/google.py +16 -3
  30. inspect_ai/model/_providers/groq.py +4 -0
  31. inspect_ai/model/_providers/hf.py +7 -0
  32. inspect_ai/model/_providers/mistral.py +4 -0
  33. inspect_ai/model/_providers/openai.py +4 -0
  34. inspect_ai/model/_providers/vertex.py +5 -0
  35. inspect_ai/model/_providers/vllm.py +7 -0
  36. inspect_ai/solver/__init__.py +8 -1
  37. inspect_ai/solver/_human_agent/panel.py +11 -5
  38. inspect_ai/solver/_prompt.py +38 -5
  39. inspect_ai/util/_sandbox/docker/config.py +4 -1
  40. inspect_ai/util/_sandbox/docker/util.py +2 -1
  41. {inspect_ai-0.3.65.dist-info → inspect_ai-0.3.67.dist-info}/METADATA +3 -2
  42. {inspect_ai-0.3.65.dist-info → inspect_ai-0.3.67.dist-info}/RECORD +46 -46
  43. {inspect_ai-0.3.65.dist-info → inspect_ai-0.3.67.dist-info}/LICENSE +0 -0
  44. {inspect_ai-0.3.65.dist-info → inspect_ai-0.3.67.dist-info}/WHEEL +0 -0
  45. {inspect_ai-0.3.65.dist-info → inspect_ai-0.3.67.dist-info}/entry_points.txt +0 -0
  46. {inspect_ai-0.3.65.dist-info → inspect_ai-0.3.67.dist-info}/top_level.txt +0 -0
inspect_ai/_view/www/src/samples/error/FlatSampleErrorView.tsx
@@ -1,7 +1,7 @@
 import { ApplicationIcons } from "../../appearance/icons";

 import clsx from "clsx";
-import styles from "./SampleErrorView.module.css";
+import styles from "./FlatSampleErrorView.module.css";
 import { errorType } from "./error";

 interface FlatSampleErrorViewProps {
inspect_ai/_view/www/src/samples/sample-tools/filters.ts
@@ -35,7 +35,7 @@ const coerceValue = (value: unknown, descriptor: ScoreDescriptor): unknown => {

 // Whether a particular value is filter-able
 const isFilteringSupportedForValue = (value: unknown): boolean =>
-  ["string", "number", "boolean"].includes(typeof value);
+  ["string", "number", "boolean"].includes(typeof value) || value === null;

 /**
  * Returns the names of scores that are not allowed to be used as short names in
@@ -56,20 +56,26 @@ const bannedShortScoreNames = (scores: ScoreLabel[]): Set<string> => {
   return banned;
 };

+// Pseudo-variables added to all filter expressions. These are not needed in most cases.
+// Normally one could check a boolean value `foo` by simply typing `foo` or `not foo`.
+// However, some evals use tristate values that can be true, false or null. This is where
+// these constants come in handy.
+const filterExpressionConstants: Record<string, unknown> = {
+  True: true,
+  False: false,
+  None: null,
+};
+
 /**
  * Generates a dictionary of variables that can be used in the filter expression.
  * High-level scorer metrics can be accessed by name directly.
  * Child metrics are accessed using dot notation (e.g. `scorer_name.score_name`) or
  * directly by name when it is unique.
- *
- * @param {import("../../samples/descriptor/samplesDescriptor").EvalDescriptor} evalDescriptor
- * @param {import("../../types/log").Scores1} sampleScores
- * @returns {Object<string, any>}
  */
 const scoreVariables = (
   evalDescriptor: EvalDescriptor,
   sampleScores: Scores1,
-) => {
+): Record<string, unknown> => {
   const bannedShortNames = bannedShortScoreNames(evalDescriptor.scores);
   const variables: Record<string, unknown> = {};

@@ -77,7 +83,7 @@ const scoreVariables = (
     variableName: string,
     scoreLabel: ScoreLabel,
     value: unknown,
-  ) => {
+  ): void => {
     const coercedValue = coerceValue(
       value,
       evalDescriptor.scoreDescriptor(scoreLabel),
@@ -101,6 +107,12 @@ const scoreVariables = (
   return variables;
 };

+const sampleVariables = (sample: SampleSummary): Record<string, unknown> => {
+  return {
+    has_error: !!sample.error,
+  };
+};
+
 /**
  * Generates a dictionary of variables that can be used in the filter expression.
  * High-level scorer metrics can be accessed by name directly.
@@ -115,11 +127,6 @@ export const scoreFilterItems = (
   const valueToString = (value: unknown) =>
     typeof value === "string" ? `"${value}"` : String(value);

-  /**
-   * @param {string | undefined} shortName
-   * @param {string | undefined} qualifiedName
-   * @param {import("../../types").ScoreLabel} scoreLabel
-   */
   const addScore = (
     scoreLabel: ScoreLabel,
     shortName?: string,
@@ -196,13 +203,33 @@ export const filterExpression = (
       : [sample.target];
     return targets.some((target) => target.match(new RegExp(regex, "i")));
   };
+  const errorContains = (regex: string): boolean => {
+    return !!sample.error?.match(new RegExp(regex, "i"));
+  };

   const extraFunctions = {
     input_contains: inputContains,
     target_contains: targetContains,
+    error_contains: errorContains,
+  };
+  const mySampleVariables = sampleVariables(sample);
+  const vars = {
+    ...mySampleVariables,
+    ...scoreVariables(evalDescriptor, sample.scores),
+  };
+  const resolveVariable = (name: string, get: (name: string) => any) => {
+    // Sample variables (like has_error) always exist.
+    if (name in mySampleVariables) {
+      return get(name);
+    }
+    // Score variables exist only if the sample completed successfully.
+    return sample.error ? undefined : get(name);
   };
-  const expression = compileExpression(filterValue, { extraFunctions });
-  const vars = scoreVariables(evalDescriptor, sample.scores);
+  const expression = compileExpression(filterValue, {
+    extraFunctions,
+    constants: filterExpressionConstants,
+    customProp: resolveVariable,
+  });
   const result = expression(vars);
   if (typeof result === "boolean") {
     return { matches: result, error: undefined };
@@ -263,12 +290,6 @@ export const filterExpression = (
   }
 };

-/**
- * @param {import("../../samples/descriptor/samplesDescriptor").EvalDescriptor} evalDescriptor
- * @param {import("../../api/types").SampleSummary[]} samples
- * @param {string} filterValue
- * @returns {}
- */
 export const filterSamples = (
   evalDescriptor: EvalDescriptor,
   samples: SampleSummary[],
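Note: taken together, the filters.ts changes make errored samples first-class in filter expressions. For example (the score name `accuracy` is hypothetical), `has_error or error_contains("timeout")` matches samples that failed, while `accuracy == None` distinguishes a score that was never produced from one that is false; the `True`/`False`/`None` constants and the `resolveVariable` fallback let such expressions evaluate cleanly instead of erroring on missing score variables.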
inspect_ai/_view/www/src/samples/sample-tools/sample-filter/SampleFilter.tsx
@@ -39,7 +39,8 @@ interface SampleFilterProps {
 const FILTER_TOOLTIP = `
 Filter samples by:
   • Scores
-  • Input and target regex search: input_contains, target_contains
+  • Samples with errors: has_error
+  • Input, target and error regex search: input_contains, target_contains, error_contains

 Supported expressions:
   • Arithmetic: +, -, *, /, mod, ^
inspect_ai/_view/www/src/samples/sample-tools/sample-filter/completions.ts
@@ -13,7 +13,12 @@ import {
   kScoreTypePassFail,
 } from "../../../constants";
 import { ScoreFilterItem } from "../filters";
-import { KEYWORDS, MATH_FUNCTIONS, SAMPLE_FUNCTIONS } from "./language";
+import {
+  KEYWORDS,
+  MATH_FUNCTIONS,
+  SAMPLE_FUNCTIONS,
+  SAMPLE_VARIABLES,
+} from "./language";
 import { Token, tokenize } from "./tokenize";

 interface CompletionOptions {
@@ -76,10 +81,20 @@ const makeSampleFunctionCompletion = ([label, info]: [
   boost: 0,
 });

+const makeSampleVariableCompletion = ([label, info]: [
+  string,
+  string,
+]): Completion => ({
+  label,
+  type: "variable",
+  info,
+  boost: 10,
+});
+
 const makeLiteralCompletion = (k: string): Completion => ({
   label: k,
   type: "text",
-  boost: 10,
+  boost: 20,
 });

 const makeCanonicalNameCompletion = (
@@ -89,14 +104,14 @@ const makeCanonicalNameCompletion = (
   label: item.canonicalName + (autoSpaceIf(item) ? " " : ""),
   type: "variable",
   info: item.tooltip,
-  boost: 20,
+  boost: 30,
 });

 const makeMemberAccessCompletion = (item: ScoreFilterItem): Completion => ({
   label: item.qualifiedName?.split(".")[1] || "",
   type: "variable",
   info: item.tooltip,
-  boost: 20,
+  boost: 40,
 });

 const getMemberScoreItems = (
@@ -130,6 +145,9 @@ export function getCompletions(
   const sampleFunctionCompletionItems = SAMPLE_FUNCTIONS.map(
     makeSampleFunctionCompletion,
   );
+  const sampleVariableCompletionItems = SAMPLE_VARIABLES.map(
+    makeSampleVariableCompletion,
+  );
   const variableCompletionItems = filterItems.map((item) =>
     makeCanonicalNameCompletion(item),
   );
@@ -138,6 +156,7 @@ export function getCompletions(
     ...keywordCompletionItems,
     ...mathFunctionCompletionItems,
     ...sampleFunctionCompletionItems,
+    ...sampleVariableCompletionItems,
     ...variableCompletionItems,
   ];

@@ -218,9 +237,11 @@ export function getCompletions(
     },
   };

-  const priorityLabels = new Set(priorityCompletions.map((c) => c.label));
+  const priorityLabels = new Set(
+    priorityCompletions.map((c) => c.label.trim()),
+  );
   const defaultCompletionsAdjusted = defaultCompletionItems
-    .filter((c) => !priorityLabels.has(c.label))
+    .filter((c) => !priorityLabels.has(c.label.trim()))
     .map((c) => ({ ...c, section: miscSection }));

   return {
@@ -240,6 +261,7 @@ export function getCompletions(
         completingAtEnd && item.scoreType !== kScoreTypeBoolean,
       }),
     ),
+    ...sampleVariableCompletionItems,
     ...sampleFunctionCompletionItems,
   ]);
inspect_ai/_view/www/src/samples/sample-tools/sample-filter/language.ts
@@ -13,7 +13,12 @@ export const MATH_FUNCTIONS: [string, string][] = [
   ["log10", "Base 10 logarithm"],
 ];

+export const SAMPLE_VARIABLES: [string, string][] = [
+  ["has_error", "Checks if the sample has an error"],
+];
+
 export const SAMPLE_FUNCTIONS: [string, string][] = [
   ["input_contains", "Checks if input contains a regular expression"],
   ["target_contains", "Checks if target contains a regular expression"],
+  ["error_contains", "Checks if error contains a regular expression"],
 ];
inspect_ai/_view/www/src/samples/transcript/LoggerEventView.tsx
@@ -24,9 +24,7 @@ export const LoggerEventView: React.FC<LoggerEventViewProps> = ({
       icon={ApplicationIcons.logging[event.message.level.toLowerCase()]}
     >
       <div className={clsx("text-size-base", styles.grid)}>
-        <div className={clsx("text-size-smaller")}>
-          ${event.message.message}
-        </div>
+        <div className={clsx("text-size-smaller")}>{event.message.message}</div>
         <div className={clsx("text-size-smaller", "text-style-secondary")}>
           {event.message.filename}:{event.message.lineno}
         </div>
inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.tsx
@@ -30,19 +30,6 @@ export const SubtaskEventView: React.FC<SubtaskEventViewProps> = ({
   className,
 }) => {
   // Render Forks specially
-
-  const transcript =
-    event.events.length > 0 ? (
-      <TranscriptView
-        id={`${id}-subtask`}
-        data-name="Transcript"
-        events={event.events}
-        depth={depth + 1}
-      />
-    ) : (
-      ""
-    );
-
   const body =
     event.type === "fork" ? (
       <div title="Summary" className={clsx(styles.summary)}>
@@ -51,7 +38,16 @@ export const SubtaskEventView: React.FC<SubtaskEventViewProps> = ({
         <Rendered values={event.input} />
       </div>
       <div className={clsx("text-style-label")}>Transcript</div>
-      {transcript}
+      {event.events.length > 0 ? (
+        <TranscriptView
+          id={`${id}-subtask`}
+          data-name="Transcript"
+          events={event.events}
+          depth={depth + 1}
+        />
+      ) : (
+        <None />
+      )}
     </div>
   ) : (
     <Fragment>
@@ -60,7 +56,14 @@ export const SubtaskEventView: React.FC<SubtaskEventViewProps> = ({
         input={event.input}
         result={event.result}
       />
-      {transcript}
+      {event.events.length > 0 ? (
+        <TranscriptView
+          id={`${id}-subtask`}
+          data-name="Transcript"
+          events={event.events}
+          depth={depth + 1}
+        />
+      ) : undefined}
     </Fragment>
   );

@@ -126,8 +129,20 @@ const Rendered: React.FC<RenderedProps> = ({ values }) => {
       return <Rendered values={val} />;
     });
   } else if (values && typeof values === "object") {
-    return <MetaDataView entries={values as Record<string, unknown>} />;
+    if (Object.keys(values).length === 0) {
+      return <None />;
+    } else {
+      return <MetaDataView entries={values as Record<string, unknown>} />;
+    }
   } else {
     return values;
   }
 };
+
+const None: React.FC = () => {
+  return (
+    <span className={clsx("text-size-small", "text-style-secondary")}>
+      [None]
+    </span>
+  );
+};
inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx
@@ -387,7 +387,10 @@ const fixupEventStream = (events: Events) => {
   });
   const initEvent = events[initEventIndex];

-  const fixedUp = [...events];
+  // Filter pending events
+  const finalEvents = events.filter((e) => !e.pending);
+
+  const fixedUp = [...finalEvents];
   if (initEvent) {
     fixedUp.splice(initEventIndex, 0, {
       timestamp: initEvent.timestamp,
inspect_ai/_view/www/src/workspace/navbar/StatusPanel.module.css
@@ -5,6 +5,7 @@
   font-size: var(--inspect-font-size-smaller);
   display: grid;
   grid-template-columns: auto auto;
+  justify-content: end;
 }

 .statusIcon {
inspect_ai/_view/www/src/workspace/navbar/StatusPanel.tsx
@@ -51,9 +51,9 @@ const StatusPanel: React.FC<StatusPanelProps> = ({
     <div className={styles.statusPanel}>
       <i className={clsx(icon, styles.statusIcon)} style={{}} />
       <div>
-        <div>${status}</div>
+        <div>{status}</div>
         <div>
-          (${sampleCount} ${sampleCount === 1 ? "sample" : "samples"})
+          ({sampleCount} {sampleCount === 1 ? "sample" : "samples"})
        </div>
      </div>
    </div>
inspect_ai/model/_model.py
@@ -7,8 +7,10 @@ import os
 import time
 from contextvars import ContextVar
 from copy import deepcopy
+from types import TracebackType
 from typing import Any, Callable, Literal, Type, cast

+from pydantic_core import to_jsonable_python
 from tenacity import (
     retry,
     retry_if_exception,
@@ -109,6 +111,10 @@ class ModelAPI(abc.ABC):
         # set any explicitly specified api key
         self.api_key = api_key

+    async def close(self) -> None:
+        """Close method for closing any client allocated for the model."""
+        pass
+
     @abc.abstractmethod
     async def generate(
         self,
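The provider hunks later in this diff implement the new `close()` hook in one of two ways: awaiting a long-lived client's `close()`/`aclose()`, or a no-op when a client is created per `generate()` call. A minimal sketch of the first pattern (the provider class and its client are illustrative, not from this diff):

```python
from typing import Any

import httpx
from typing_extensions import override

from inspect_ai.model import ModelAPI


class ExampleAPI(ModelAPI):  # hypothetical provider (generate() omitted)
    def __init__(self, model_name: str, **kwargs: Any) -> None:
        super().__init__(model_name=model_name, **kwargs)
        # long-lived client allocated once for the lifetime of the model
        self.client = httpx.AsyncClient()

    @override
    async def close(self) -> None:
        # invoked via Model.__aexit__ when the async context manager exits
        await self.client.aclose()
```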
@@ -178,7 +184,17 @@


 class Model:
-    """Model interface."""
+    """Model interface.
+
+    Use `get_model()` to get an instance of a model. Model provides an
+    async context manager for closing the connection to it after use.
+    For example:
+
+    ```python
+    async with get_model("openai/gpt-4o") as model:
+        response = await model.generate("Say hello")
+    ```
+    """

     api: ModelAPI
     """Model API."""
@@ -196,10 +212,28 @@ class Model:
         self.api = api
         self.config = config

+        # state indicating whether our lifetime is bound by a context manager
+        self._context_bound = False
+        self._closed = False
+
         # if using the Model API standalone in a notebook this will
         # get hit before score() or eval() so we activate nest_asyncio
         platform_init()

+    async def __aenter__(self: "Model") -> "Model":
+        self._context_bound = True
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        if not self._closed:
+            await self.api.close()
+            self._closed = True
+
     @property
     def name(self) -> str:
         """Model name."""
@@ -598,10 +632,27 @@ def get_model(
     config: GenerateConfig = GenerateConfig(),
     base_url: str | None = None,
     api_key: str | None = None,
+    memoize: bool = True,
     **model_args: Any,
 ) -> Model:
     """Get an instance of a model.

+    Calls to get_model() are memoized (i.e. a call with the same arguments
+    will return an existing instance of the model rather than creating a
+    new one). You can disable this with `memoize=False`.
+
+    If you prefer to immediately close models after use (as well as
+    prevent caching) you can employ the async context manager built in
+    to the `Model` class. For example:
+
+    ```python
+    async with get_model("openai/gpt-4o") as model:
+        response = await model.generate("Say hello")
+    ```
+
+    In this case, the model client will be closed at the end of the
+    context manager and will not be available in the get_model() cache.
+
     Args:
        model: Model specification.
          If `Model` is passed it is returned unmodified,
@@ -611,6 +662,8 @@ def get_model(
        config: Configuration for model.
        base_url: Optional. Alternate base URL for model.
        api_key: Optional. API key for model.
+       memoize: Use/store a cached version of the model based on
+          the parameters to `get_model()`
        **model_args: Additional args to
          pass to model constructor.

@@ -637,6 +690,23 @@ def get_model(
     else:
         raise ValueError("No model specified (and no INSPECT_EVAL_MODEL defined)")

+    # see if we can return a memoized model instance
+    # (exclude mockllm since custom_outputs is an infinite generator)
+    model_cache_key: str = ""  # for mypy below
+    if model.startswith("mockllm/"):
+        memoize = False
+    if memoize:
+        model_cache_key = (
+            model
+            + config.model_dump_json(exclude_none=True)
+            + str(base_url)
+            + str(api_key)
+            + str(to_jsonable_python(model_args, fallback=lambda _: None))
+        )
+        cached = cached_model(model_cache_key)
+        if cached is not None:
+            return cached
+
     # split model into api name and model name if necessary
     api_name = None
     parts = model.split("/")
@@ -667,13 +737,30 @@ def get_model(
             config=config,
             **model_args,
         )
-        return Model(modelapi_instance, config)
+        m = Model(modelapi_instance, config)
+        if memoize:
+            _models[model_cache_key] = m
+        return m

     else:
         from_api = f" from {api_name}" if api_name else ""
         raise ValueError(f"Model name {model}{from_api} not recognized.")


+# cache for memoization of get_model
+_models: dict[str, Model] = {}
+
+
+def cached_model(key: str) -> Model | None:
+    # clean out context bound models before accessing the cache
+    for k in list(_models.keys()):
+        if _models[k]._context_bound:
+            del _models[k]
+
+    # read from the cache
+    return _models.get(key, None)
+
+
 def resolve_models(
     model: str | Model | list[str] | list[Model] | None,
     model_base_url: str | None = None,
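A quick sketch of the memoization semantics introduced above (the model name is illustrative):

```python
from inspect_ai.model import get_model

# identical arguments return the same cached Model instance
m1 = get_model("openai/gpt-4o")
m2 = get_model("openai/gpt-4o")
assert m1 is m2

# memoize=False always constructs a fresh, uncached instance
m3 = get_model("openai/gpt-4o", memoize=False)
assert m3 is not m1


# entering the async context manager marks the model as context-bound,
# so cached_model() evicts it from the cache on the next lookup
async def main() -> None:
    async with get_model("openai/gpt-4o") as model:
        await model.generate("Say hello")
```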
inspect_ai/model/_providers/anthropic.py
@@ -150,6 +150,10 @@ class AnthropicAPI(ModelAPI):
             **model_args,
         )

+    @override
+    async def close(self) -> None:
+        await self.client.close()
+
     def is_bedrock(self) -> bool:
         return self.service == "bedrock"

inspect_ai/model/_providers/azureai.py
@@ -124,6 +124,11 @@ class AzureAIAPI(ModelAPI):
         self.endpoint_url = endpoint_url
         self.model_args = model_args

+    @override
+    async def close(self) -> None:
+        # client is created/destroyed each time in generate()
+        pass
+
     async def generate(
         self,
         input: list[ChatMessage],
inspect_ai/model/_providers/bedrock.py
@@ -259,6 +259,11 @@ class BedrockAPI(ModelAPI):
         except ImportError:
             raise pip_dependency_error("Bedrock API", ["aioboto3"])

+    @override
+    async def close(self) -> None:
+        # client is created/destroyed each time in generate()
+        pass
+
     @override
     def connection_key(self) -> str:
         return self.model_name
inspect_ai/model/_providers/cloudflare.py
@@ -56,6 +56,10 @@ class CloudFlareAPI(ModelAPI):
         )
         self.model_args = model_args

+    @override
+    async def close(self) -> None:
+        await self.client.aclose()
+
     async def generate(
         self,
         input: list[ChatMessage],
inspect_ai/model/_providers/goodfire.py
@@ -111,6 +111,11 @@ class GoodfireAPI(ModelAPI):
         # Initialize variant directly with model name
         self.variant = Variant(self.model_name)  # type: ignore

+    @override
+    async def close(self) -> None:
+        # httpx.AsyncClient is created on each generate()
+        pass
+
     def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
         """Convert an Inspect message to a Goodfire message format.

inspect_ai/model/_providers/google.py
@@ -134,6 +134,11 @@ class GoogleAPI(ModelAPI):
         # create model
         self.model = GenerativeModel(self.model_name)

+    @override
+    async def close(self) -> None:
+        # GenerativeModel uses a cached/shared client so there is no 'close'
+        pass
+
     async def generate(
         self,
         input: list[ChatMessage],
@@ -393,12 +398,12 @@ def prepend_system_messages(
 ) -> None:
     # create system_parts
     system_parts: list[PartType] = [
-        Part(text=message.content) for message in system_messages
+        Part(text=message.text) for message in system_messages
     ]

     # we want the system messages to be prepended to the first user message
     # (if there is no first user message then prepend one)
-    if messages[0].get("role") == "user":
+    if len(messages) > 0 and messages[0].get("role") == "user":
         messages[0]["parts"] = system_parts + messages[0].get("parts", [])
     else:
         messages.insert(0, ContentDict(role="user", parts=system_parts))
@@ -561,7 +566,15 @@ def completion_choices_from_candidates(
             completion_choice_from_candidate(candidate) for candidate in candidates_list
         ]
     else:
-        return []
+        return [
+            ChatCompletionChoice(
+                message=ChatMessageAssistant(
+                    content="I was unable to generate a response.",
+                    source="generate",
+                ),
+                stop_reason="unknown",
+            )
+        ]


 # google doesn't export FinishReason (it's in a sub-namespace with a beta
inspect_ai/model/_providers/groq.py
@@ -87,6 +87,10 @@ class GroqAPI(ModelAPI):
             http_client=httpx.AsyncClient(limits=httpx.Limits(max_connections=None)),
         )

+    @override
+    async def close(self) -> None:
+        await self.client.close()
+
     async def generate(
         self,
         input: list[ChatMessage],
inspect_ai/model/_providers/hf.py
@@ -1,6 +1,7 @@
 import asyncio
 import copy
 import functools
+import gc
 import json
 import os
 from dataclasses import dataclass
@@ -112,6 +113,12 @@ class HuggingFaceAPI(ModelAPI):
             self.tokenizer.pad_token = self.tokenizer.eos_token
         self.tokenizer.padding_side = "left"

+    @override
+    async def close(self) -> None:
+        self.model = None
+        self.tokenizer = None
+        gc.collect()
+
     async def generate(
         self,
         input: list[ChatMessage],
inspect_ai/model/_providers/mistral.py
@@ -118,6 +118,10 @@ class MistralAPI(ModelAPI):
             **model_args,
         )

+    @override
+    async def close(self) -> None:
+        await self.client.sdk_configuration.async_client.aclose()
+
     async def generate(
         self,
         input: list[ChatMessage],