inspect-ai 0.3.65__py3-none-any.whl → 0.3.67__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_display/core/config.py +4 -0
- inspect_ai/_display/textual/app.py +13 -5
- inspect_ai/_display/textual/widgets/footer.py +2 -2
- inspect_ai/_display/textual/widgets/sandbox.py +1 -1
- inspect_ai/_display/textual/widgets/task_detail.py +7 -5
- inspect_ai/_display/textual/widgets/tasks.py +8 -6
- inspect_ai/_display/textual/widgets/transcript.py +1 -1
- inspect_ai/_eval/task/run.py +5 -3
- inspect_ai/_eval/task/task.py +9 -1
- inspect_ai/_util/format.py +58 -0
- inspect_ai/_view/www/dist/assets/index.css +29 -9
- inspect_ai/_view/www/dist/assets/index.js +368 -304
- inspect_ai/_view/www/src/samples/error/FlatSampleErrorView.tsx +1 -1
- inspect_ai/_view/www/src/samples/sample-tools/filters.ts +41 -20
- inspect_ai/_view/www/src/samples/sample-tools/sample-filter/SampleFilter.tsx +2 -1
- inspect_ai/_view/www/src/samples/sample-tools/sample-filter/completions.ts +28 -6
- inspect_ai/_view/www/src/samples/sample-tools/sample-filter/language.ts +5 -0
- inspect_ai/_view/www/src/samples/transcript/LoggerEventView.tsx +1 -3
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.tsx +31 -16
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +4 -1
- inspect_ai/_view/www/src/workspace/navbar/StatusPanel.module.css +1 -0
- inspect_ai/_view/www/src/workspace/navbar/StatusPanel.tsx +2 -2
- inspect_ai/model/_model.py +89 -2
- inspect_ai/model/_providers/anthropic.py +4 -0
- inspect_ai/model/_providers/azureai.py +5 -0
- inspect_ai/model/_providers/bedrock.py +5 -0
- inspect_ai/model/_providers/cloudflare.py +4 -0
- inspect_ai/model/_providers/goodfire.py +5 -0
- inspect_ai/model/_providers/google.py +16 -3
- inspect_ai/model/_providers/groq.py +4 -0
- inspect_ai/model/_providers/hf.py +7 -0
- inspect_ai/model/_providers/mistral.py +4 -0
- inspect_ai/model/_providers/openai.py +4 -0
- inspect_ai/model/_providers/vertex.py +5 -0
- inspect_ai/model/_providers/vllm.py +7 -0
- inspect_ai/solver/__init__.py +8 -1
- inspect_ai/solver/_human_agent/panel.py +11 -5
- inspect_ai/solver/_prompt.py +38 -5
- inspect_ai/util/_sandbox/docker/config.py +4 -1
- inspect_ai/util/_sandbox/docker/util.py +2 -1
- {inspect_ai-0.3.65.dist-info → inspect_ai-0.3.67.dist-info}/METADATA +3 -2
- {inspect_ai-0.3.65.dist-info → inspect_ai-0.3.67.dist-info}/RECORD +46 -46
- {inspect_ai-0.3.65.dist-info → inspect_ai-0.3.67.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.65.dist-info → inspect_ai-0.3.67.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.65.dist-info → inspect_ai-0.3.67.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.65.dist-info → inspect_ai-0.3.67.dist-info}/top_level.txt +0 -0
inspect_ai/_view/www/src/samples/error/FlatSampleErrorView.tsx CHANGED
@@ -1,7 +1,7 @@
 import { ApplicationIcons } from "../../appearance/icons";
 
 import clsx from "clsx";
-import styles from "./
+import styles from "./FlatSampleErrorView.module.css";
 import { errorType } from "./error";
 
 interface FlatSampleErrorViewProps {
inspect_ai/_view/www/src/samples/sample-tools/filters.ts CHANGED
@@ -35,7 +35,7 @@ const coerceValue = (value: unknown, descriptor: ScoreDescriptor): unknown => {
 
 // Whether a particular value is filter-able
 const isFilteringSupportedForValue = (value: unknown): boolean =>
-  ["string", "number", "boolean"].includes(typeof value);
+  ["string", "number", "boolean"].includes(typeof value) || value === null;
 
 /**
  * Returns the names of scores that are not allowed to be used as short names in
@@ -56,20 +56,26 @@ const bannedShortScoreNames = (scores: ScoreLabel[]): Set<string> => {
   return banned;
 };
 
+// Pseudo-variables added to all filter expressions. These are not needed in most cases.
+// Normally one could check a boolean value `foo` by simply typing `foo` or `not foo`.
+// However, some evals use tristate values that can be true, false or null. This is where
+// these constants come in handy.
+const filterExpressionConstants: Record<string, unknown> = {
+  True: true,
+  False: false,
+  None: null,
+};
+
 /**
  * Generates a dictionary of variables that can be used in the filter expression.
  * High-level scorer metrics can be accessed by name directly.
  * Child metrics are accessed using dot notation (e.g. `scorer_name.score_name`) or
  * directly by name when it is unique.
- *
- * @param {import("../../samples/descriptor/samplesDescriptor").EvalDescriptor} evalDescriptor
- * @param {import("../../types/log").Scores1} sampleScores
- * @returns {Object<string, any>}
  */
 const scoreVariables = (
   evalDescriptor: EvalDescriptor,
   sampleScores: Scores1,
-) => {
+): Record<string, unknown> => {
   const bannedShortNames = bannedShortScoreNames(evalDescriptor.scores);
   const variables: Record<string, unknown> = {};
 
@@ -77,7 +83,7 @@ const scoreVariables = (
     variableName: string,
     scoreLabel: ScoreLabel,
     value: unknown,
-  ) => {
+  ): void => {
     const coercedValue = coerceValue(
       value,
       evalDescriptor.scoreDescriptor(scoreLabel),
@@ -101,6 +107,12 @@ const scoreVariables = (
   return variables;
 };
 
+const sampleVariables = (sample: SampleSummary): Record<string, unknown> => {
+  return {
+    has_error: !!sample.error,
+  };
+};
+
 /**
  * Generates a dictionary of variables that can be used in the filter expression.
  * High-level scorer metrics can be accessed by name directly.
@@ -115,11 +127,6 @@ export const scoreFilterItems = (
   const valueToString = (value: unknown) =>
     typeof value === "string" ? `"${value}"` : String(value);
 
-  /**
-   * @param {string | undefined} shortName
-   * @param {string | undefined} qualifiedName
-   * @param {import("../../types").ScoreLabel} scoreLabel
-   */
   const addScore = (
     scoreLabel: ScoreLabel,
     shortName?: string,
@@ -196,13 +203,33 @@ export const filterExpression = (
       : [sample.target];
     return targets.some((target) => target.match(new RegExp(regex, "i")));
   };
+  const errorContains = (regex: string): boolean => {
+    return !!sample.error?.match(new RegExp(regex, "i"));
+  };
 
   const extraFunctions = {
     input_contains: inputContains,
     target_contains: targetContains,
+    error_contains: errorContains,
+  };
+  const mySampleVariables = sampleVariables(sample);
+  const vars = {
+    ...mySampleVariables,
+    ...scoreVariables(evalDescriptor, sample.scores),
+  };
+  const resolveVariable = (name: string, get: (name: string) => any) => {
+    // Sample variables (like has_error) always exist.
+    if (name in mySampleVariables) {
+      return get(name);
+    }
+    // Score variables exist only if the sample completed successfully.
+    return sample.error ? undefined : get(name);
   };
-  const expression = compileExpression(filterValue, {
-
+  const expression = compileExpression(filterValue, {
+    extraFunctions,
+    constants: filterExpressionConstants,
+    customProp: resolveVariable,
+  });
   const result = expression(vars);
   if (typeof result === "boolean") {
     return { matches: result, error: undefined };
@@ -263,12 +290,6 @@ export const filterExpression = (
   }
 };
 
-/**
- * @param {import("../../samples/descriptor/samplesDescriptor").EvalDescriptor} evalDescriptor
- * @param {import("../../api/types").SampleSummary[]} samples
- * @param {string} filterValue
- * @returns {}
- */
 export const filterSamples = (
   evalDescriptor: EvalDescriptor,
   samples: SampleSummary[],
inspect_ai/_view/www/src/samples/sample-tools/sample-filter/SampleFilter.tsx CHANGED
@@ -39,7 +39,8 @@ interface SampleFilterProps {
 const FILTER_TOOLTIP = `
 Filter samples by:
   • Scores
-  •
+  • Samples with errors: has_error
+  • Input, target and error regex search: input_contains, target_contains, error_contains
 
 Supported expressions:
   • Arithmetic: +, -, *, /, mod, ^
inspect_ai/_view/www/src/samples/sample-tools/sample-filter/completions.ts CHANGED
@@ -13,7 +13,12 @@ import {
   kScoreTypePassFail,
 } from "../../../constants";
 import { ScoreFilterItem } from "../filters";
-import {
+import {
+  KEYWORDS,
+  MATH_FUNCTIONS,
+  SAMPLE_FUNCTIONS,
+  SAMPLE_VARIABLES,
+} from "./language";
 import { Token, tokenize } from "./tokenize";
 
 interface CompletionOptions {
@@ -76,10 +81,20 @@ const makeSampleFunctionCompletion = ([label, info]: [
   boost: 0,
 });
 
+const makeSampleVariableCompletion = ([label, info]: [
+  string,
+  string,
+]): Completion => ({
+  label,
+  type: "variable",
+  info,
+  boost: 10,
+});
+
 const makeLiteralCompletion = (k: string): Completion => ({
   label: k,
   type: "text",
-  boost:
+  boost: 20,
 });
 
 const makeCanonicalNameCompletion = (
@@ -89,14 +104,14 @@ const makeCanonicalNameCompletion = (
   label: item.canonicalName + (autoSpaceIf(item) ? " " : ""),
   type: "variable",
   info: item.tooltip,
-  boost:
+  boost: 30,
 });
 
 const makeMemberAccessCompletion = (item: ScoreFilterItem): Completion => ({
   label: item.qualifiedName?.split(".")[1] || "",
   type: "variable",
   info: item.tooltip,
-  boost:
+  boost: 40,
 });
 
 const getMemberScoreItems = (
@@ -130,6 +145,9 @@ export function getCompletions(
   const sampleFunctionCompletionItems = SAMPLE_FUNCTIONS.map(
     makeSampleFunctionCompletion,
   );
+  const sampleVariableCompletionItems = SAMPLE_VARIABLES.map(
+    makeSampleVariableCompletion,
+  );
   const variableCompletionItems = filterItems.map((item) =>
     makeCanonicalNameCompletion(item),
   );
@@ -138,6 +156,7 @@ export function getCompletions(
     ...keywordCompletionItems,
     ...mathFunctionCompletionItems,
     ...sampleFunctionCompletionItems,
+    ...sampleVariableCompletionItems,
     ...variableCompletionItems,
   ];
 
@@ -218,9 +237,11 @@ export function getCompletions(
     },
   };
 
-  const priorityLabels = new Set(
+  const priorityLabels = new Set(
+    priorityCompletions.map((c) => c.label.trim()),
+  );
   const defaultCompletionsAdjusted = defaultCompletionItems
-    .filter((c) => !priorityLabels.has(c.label))
+    .filter((c) => !priorityLabels.has(c.label.trim()))
     .map((c) => ({ ...c, section: miscSection }));
 
   return {
@@ -240,6 +261,7 @@ export function getCompletions(
         completingAtEnd && item.scoreType !== kScoreTypeBoolean,
       }),
     ),
+    ...sampleVariableCompletionItems,
     ...sampleFunctionCompletionItems,
   ]);
 
inspect_ai/_view/www/src/samples/sample-tools/sample-filter/language.ts CHANGED
@@ -13,7 +13,12 @@ export const MATH_FUNCTIONS: [string, string][] = [
   ["log10", "Base 10 logarithm"],
 ];
 
+export const SAMPLE_VARIABLES: [string, string][] = [
+  ["has_error", "Checks if the sample has an error"],
+];
+
 export const SAMPLE_FUNCTIONS: [string, string][] = [
   ["input_contains", "Checks if input contains a regular expression"],
   ["target_contains", "Checks if target contains a regular expression"],
+  ["error_contains", "Checks if error contains a regular expression"],
 ];
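Taken together, the changes to filters.ts, completions.ts, and language.ts extend the log viewer's filter language with error-aware filtering and tristate constants. A few expressions that should now be possible (a sketch; the grammar is the filtrex-style expression language compiled by `compileExpression`, and the score name `accuracy` is a hypothetical example):

```
has_error
not has_error and accuracy > 0.5
error_contains("timeout")
accuracy == None
```

The `True`/`False`/`None` constants cover scores that can be true, false, or null, and `resolveVariable` lets score variables resolve to `undefined` (rather than raising) on samples that errored.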
inspect_ai/_view/www/src/samples/transcript/LoggerEventView.tsx CHANGED
@@ -24,9 +24,7 @@ export const LoggerEventView: React.FC<LoggerEventViewProps> = ({
       icon={ApplicationIcons.logging[event.message.level.toLowerCase()]}
     >
       <div className={clsx("text-size-base", styles.grid)}>
-        <div className={clsx("text-size-smaller")}>
-          ${event.message.message}
-        </div>
+        <div className={clsx("text-size-smaller")}>{event.message.message}</div>
         <div className={clsx("text-size-smaller", "text-style-secondary")}>
           {event.message.filename}:{event.message.lineno}
         </div>
inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.tsx CHANGED
@@ -30,19 +30,6 @@ export const SubtaskEventView: React.FC<SubtaskEventViewProps> = ({
   className,
 }) => {
   // Render Forks specially
-
-  const transcript =
-    event.events.length > 0 ? (
-      <TranscriptView
-        id={`${id}-subtask`}
-        data-name="Transcript"
-        events={event.events}
-        depth={depth + 1}
-      />
-    ) : (
-      ""
-    );
-
   const body =
     event.type === "fork" ? (
       <div title="Summary" className={clsx(styles.summary)}>
@@ -51,7 +38,16 @@ export const SubtaskEventView: React.FC<SubtaskEventViewProps> = ({
         <Rendered values={event.input} />
       </div>
       <div className={clsx("text-style-label")}>Transcript</div>
-      {
+      {event.events.length > 0 ? (
+        <TranscriptView
+          id={`${id}-subtask`}
+          data-name="Transcript"
+          events={event.events}
+          depth={depth + 1}
+        />
+      ) : (
+        <None />
+      )}
     </div>
   ) : (
     <Fragment>
@@ -60,7 +56,14 @@ export const SubtaskEventView: React.FC<SubtaskEventViewProps> = ({
         input={event.input}
         result={event.result}
       />
-      {
+      {event.events.length > 0 ? (
+        <TranscriptView
+          id={`${id}-subtask`}
+          data-name="Transcript"
+          events={event.events}
+          depth={depth + 1}
+        />
+      ) : undefined}
     </Fragment>
   );
 
@@ -126,8 +129,20 @@ const Rendered: React.FC<RenderedProps> = ({ values }) => {
       return <Rendered values={val} />;
     });
   } else if (values && typeof values === "object") {
-
+    if (Object.keys(values).length === 0) {
+      return <None />;
+    } else {
+      return <MetaDataView entries={values as Record<string, unknown>} />;
+    }
   } else {
     return values;
   }
 };
+
+const None: React.FC = () => {
+  return (
+    <span className={clsx("text-size-small", "text-style-secondary")}>
+      [None]
+    </span>
+  );
+};
inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx CHANGED
@@ -387,7 +387,10 @@ const fixupEventStream = (events: Events) => {
   });
   const initEvent = events[initEventIndex];
 
-
+  // Filter pending events
+  const finalEvents = events.filter((e) => !e.pending);
+
+  const fixedUp = [...finalEvents];
   if (initEvent) {
     fixedUp.splice(initEventIndex, 0, {
       timestamp: initEvent.timestamp,
inspect_ai/_view/www/src/workspace/navbar/StatusPanel.tsx CHANGED
@@ -51,9 +51,9 @@ const StatusPanel: React.FC<StatusPanelProps> = ({
     <div className={styles.statusPanel}>
       <i className={clsx(icon, styles.statusIcon)} style={{}} />
       <div>
-        <div
+        <div>{status}</div>
         <div>
-          (
+          ({sampleCount} {sampleCount === 1 ? "sample" : "samples"})
         </div>
       </div>
     </div>
inspect_ai/model/_model.py CHANGED
@@ -7,8 +7,10 @@ import os
 import time
 from contextvars import ContextVar
 from copy import deepcopy
+from types import TracebackType
 from typing import Any, Callable, Literal, Type, cast
 
+from pydantic_core import to_jsonable_python
 from tenacity import (
     retry,
     retry_if_exception,
@@ -109,6 +111,10 @@ class ModelAPI(abc.ABC):
         # set any explicitly specified api key
         self.api_key = api_key
 
+    async def close(self) -> None:
+        """Close method for closing any client allocated for the model."""
+        pass
+
     @abc.abstractmethod
     async def generate(
         self,
@@ -178,7 +184,17 @@ class ModelAPI(abc.ABC):
 
 
 class Model:
-    """Model interface.
+    """Model interface.
+
+    Use `get_model()` to get an instance of a model. Model provides an
+    async context manager for closing the connection to it after use.
+    For example:
+
+    ```python
+    async with get_model("openai/gpt-4o") as model:
+        response = await model.generate("Say hello")
+    ```
+    """
 
     api: ModelAPI
     """Model API."""
@@ -196,10 +212,28 @@ class Model:
         self.api = api
         self.config = config
 
+        # state indicating whether our lifetime is bound by a context manager
+        self._context_bound = False
+        self._closed = False
+
         # if using the Model API standalone in a notebook this will
         # get hit before score() or eval() so we activate nest_asyncio
         platform_init()
 
+    async def __aenter__(self: "Model") -> "Model":
+        self._context_bound = True
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        if not self._closed:
+            await self.api.close()
+            self._closed = True
+
     @property
     def name(self) -> str:
         """Model name."""
@@ -598,10 +632,27 @@ def get_model(
     config: GenerateConfig = GenerateConfig(),
     base_url: str | None = None,
     api_key: str | None = None,
+    memoize: bool = True,
     **model_args: Any,
 ) -> Model:
     """Get an instance of a model.
 
+    Calls to get_model() are memoized (i.e. a call with the same arguments
+    will return an existing instance of the model rather than creating a
+    new one). You can disable this with `memoize=False`.
+
+    If you prefer to immediately close models after use (as well as
+    prevent caching) you can employ the async context manager built in
+    to the `Model` class. For example:
+
+    ```python
+    async with get_model("openai/gpt-4o") as model:
+        response = await model.generate("Say hello")
+    ```
+
+    In this case, the model client will be closed at the end of the
+    context manager and will not be available in the get_model() cache.
+
     Args:
         model: Model specification.
             If `Model` is passed it is returned unmodified,
@@ -611,6 +662,8 @@ def get_model(
         config: Configuration for model.
         base_url: Optional. Alternate base URL for model.
         api_key: Optional. API key for model.
+        memoize: Use/store a cached version of the model based on
+            the parameters to `get_model()`
         **model_args: Additional args to
             pass to model constructor.
 
@@ -637,6 +690,23 @@ def get_model(
     else:
         raise ValueError("No model specified (and no INSPECT_EVAL_MODEL defined)")
 
+    # see if we can return a memoized model instance
+    # (exclude mockllm since custom_outputs is an infinite generator)
+    model_cache_key: str = ""  # for mypy below
+    if model.startswith("mockllm/"):
+        memoize = False
+    if memoize:
+        model_cache_key = (
+            model
+            + config.model_dump_json(exclude_none=True)
+            + str(base_url)
+            + str(api_key)
+            + str(to_jsonable_python(model_args, fallback=lambda _: None))
+        )
+        cached = cached_model(model_cache_key)
+        if cached is not None:
+            return cached
+
     # split model into api name and model name if necessary
     api_name = None
     parts = model.split("/")
@@ -667,13 +737,30 @@ def get_model(
             config=config,
             **model_args,
         )
-
+        m = Model(modelapi_instance, config)
+        if memoize:
+            _models[model_cache_key] = m
+        return m
 
     else:
         from_api = f" from {api_name}" if api_name else ""
         raise ValueError(f"Model name {model}{from_api} not recognized.")
 
 
+# cache for memoization of get_model
+_models: dict[str, Model] = {}
+
+
+def cached_model(key: str) -> Model | None:
+    # clean out context bound models before accessing the cache
+    for k in list(_models.keys()):
+        if _models[k]._context_bound:
+            del _models[k]
+
+    # read from the cache
+    return _models.get(key, None)
+
+
 def resolve_models(
     model: str | Model | list[str] | list[Model] | None,
     model_base_url: str | None = None,
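A short sketch of how the new memoization and context manager semantics combine (the model name is illustrative; behavior is inferred from the diff above):

```python
from inspect_ai.model import get_model

# calls with identical arguments return the cached instance
m1 = get_model("openai/gpt-4o")
m2 = get_model("openai/gpt-4o")
assert m1 is m2

# memoize=False constructs (and does not cache) a fresh instance
m3 = get_model("openai/gpt-4o", memoize=False)
assert m3 is not m1


# context-managed usage closes the model client on exit; cached_model()
# evicts context-bound instances so a closed model is never served again
async def main() -> None:
    async with get_model("openai/gpt-4o") as model:
        await model.generate("Say hello")
```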
inspect_ai/model/_providers/azureai.py CHANGED
@@ -124,6 +124,11 @@ class AzureAIAPI(ModelAPI):
         self.endpoint_url = endpoint_url
         self.model_args = model_args
 
+    @override
+    async def close(self) -> None:
+        # client is created/destroyed each time in generate()
+        pass
+
     async def generate(
         self,
         input: list[ChatMessage],
inspect_ai/model/_providers/bedrock.py CHANGED
@@ -259,6 +259,11 @@ class BedrockAPI(ModelAPI):
         except ImportError:
             raise pip_dependency_error("Bedrock API", ["aioboto3"])
 
+    @override
+    async def close(self) -> None:
+        # client is created/destroyed each time in generate()
+        pass
+
     @override
     def connection_key(self) -> str:
         return self.model_name
inspect_ai/model/_providers/goodfire.py CHANGED
@@ -111,6 +111,11 @@ class GoodfireAPI(ModelAPI):
         # Initialize variant directly with model name
         self.variant = Variant(self.model_name)  # type: ignore
 
+    @override
+    async def close(self) -> None:
+        # httpx.AsyncClient is created on each generate()
+        pass
+
     def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
         """Convert an Inspect message to a Goodfire message format.
 
inspect_ai/model/_providers/google.py CHANGED
@@ -134,6 +134,11 @@ class GoogleAPI(ModelAPI):
         # create model
         self.model = GenerativeModel(self.model_name)
 
+    @override
+    async def close(self) -> None:
+        # GenerativeModel uses a cached/shared client so there is no 'close'
+        pass
+
     async def generate(
         self,
         input: list[ChatMessage],
@@ -393,12 +398,12 @@ def prepend_system_messages(
 ) -> None:
     # create system_parts
     system_parts: list[PartType] = [
-        Part(text=message.
+        Part(text=message.text) for message in system_messages
     ]
 
     # we want the system messages to be prepended to the first user message
     # (if there is no first user message then prepend one)
-    if messages[0].get("role") == "user":
+    if len(messages) > 0 and messages[0].get("role") == "user":
         messages[0]["parts"] = system_parts + messages[0].get("parts", [])
     else:
         messages.insert(0, ContentDict(role="user", parts=system_parts))
@@ -561,7 +566,15 @@ def completion_choices_from_candidates(
             completion_choice_from_candidate(candidate) for candidate in candidates_list
         ]
     else:
-        return [
+        return [
+            ChatCompletionChoice(
+                message=ChatMessageAssistant(
+                    content="I was unable to generate a response.",
+                    source="generate",
+                ),
+                stop_reason="unknown",
+            )
+        ]
 
 
 # google doesn't export FinishReason (it's in a sub-namespace with a beta
inspect_ai/model/_providers/hf.py CHANGED
@@ -1,6 +1,7 @@
 import asyncio
 import copy
 import functools
+import gc
 import json
 import os
 from dataclasses import dataclass
@@ -112,6 +113,12 @@ class HuggingFaceAPI(ModelAPI):
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.tokenizer.padding_side = "left"
 
+    @override
+    async def close(self) -> None:
+        self.model = None
+        self.tokenizer = None
+        gc.collect()
+
     async def generate(
         self,
         input: list[ChatMessage],
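The provider pattern across these hunks: `ModelAPI.close()` is a no-op by default, providers whose clients are created per call keep the no-op, and providers holding long-lived resources release them (as hf.py does above). A minimal hypothetical subclass for a provider that owns a persistent client; the class name, client, and constructor pass-through are illustrative, not part of inspect_ai:

```python
from typing import Any

import httpx
from typing_extensions import override

from inspect_ai.model import ModelAPI


class ExampleAPI(ModelAPI):
    """Hypothetical provider that owns a long-lived HTTP client."""

    def __init__(self, model_name: str, **kwargs: Any) -> None:
        super().__init__(model_name, **kwargs)
        self.client = httpx.AsyncClient()

    @override
    async def close(self) -> None:
        # release pooled connections; invoked from Model.__aexit__()
        await self.client.aclose()

    # generate() (the abstract method every real provider implements)
    # is omitted from this sketch
```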