inspect-ai 0.3.68__py3-none-any.whl → 0.3.69__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_display/plain/display.py +9 -11
- inspect_ai/_display/textual/app.py +3 -4
- inspect_ai/_display/textual/widgets/samples.py +43 -8
- inspect_ai/_util/interrupt.py +9 -0
- inspect_ai/_util/logger.py +4 -0
- inspect_ai/_util/text.py +288 -1
- inspect_ai/_view/www/dist/assets/index.js +1 -1
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +1 -1
- inspect_ai/log/_samples.py +0 -4
- inspect_ai/model/_model.py +3 -0
- inspect_ai/model/_providers/google.py +356 -302
- inspect_ai/model/_providers/mistral.py +10 -8
- inspect_ai/model/_providers/providers.py +5 -5
- inspect_ai/solver/_plan.py +3 -0
- inspect_ai/solver/_solver.py +3 -0
- inspect_ai/solver/_task_state.py +3 -1
- inspect_ai/util/_sandbox/docker/cleanup.py +8 -3
- inspect_ai/util/_sandbox/docker/compose.py +5 -9
- inspect_ai/util/_sandbox/docker/docker.py +14 -2
- inspect_ai/util/_sandbox/docker/util.py +10 -1
- inspect_ai/util/_sandbox/self_check.py +2 -1
- inspect_ai/util/_subprocess.py +4 -1
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.69.dist-info}/METADATA +3 -3
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.69.dist-info}/RECORD +28 -27
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.69.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.69.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.69.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.69.dist-info}/top_level.txt +0 -0
@@ -119,14 +119,14 @@ class PlainTaskDisplay(TaskDisplay):
|
|
119
119
|
self.samples_complete = 0
|
120
120
|
self.samples_total = 0
|
121
121
|
self.current_metrics: list[TaskDisplayMetric] | None = None
|
122
|
-
self.last_progress = 0
|
122
|
+
self.last_progress = 0
|
123
123
|
|
124
124
|
@contextlib.contextmanager
|
125
125
|
def progress(self) -> Iterator[Progress]:
|
126
126
|
self.progress_display = PlainProgress(self.task.profile.steps)
|
127
127
|
yield self.progress_display
|
128
128
|
|
129
|
-
@throttle(
|
129
|
+
@throttle(5)
|
130
130
|
def _print_status_throttled(self) -> None:
|
131
131
|
self._print_status()
|
132
132
|
|
@@ -135,13 +135,8 @@ class PlainTaskDisplay(TaskDisplay):
|
|
135
135
|
if not self.progress_display:
|
136
136
|
return
|
137
137
|
|
138
|
-
#
|
139
|
-
|
140
|
-
self.progress_display.current / self.progress_display.total * 100
|
141
|
-
)
|
142
|
-
|
143
|
-
# Only print on percentage changes to avoid too much output
|
144
|
-
if current_progress != self.last_progress:
|
138
|
+
# Only print when step count changes to avoid too much output
|
139
|
+
if self.progress_display.current != self.last_progress:
|
145
140
|
status_parts: list[str] = []
|
146
141
|
|
147
142
|
# if this is parallel print task and model to distinguish (limit both to 12 chars)
|
@@ -154,8 +149,11 @@ class PlainTaskDisplay(TaskDisplay):
|
|
154
149
|
)
|
155
150
|
|
156
151
|
# Add step progress
|
152
|
+
progress_percent = int(
|
153
|
+
self.progress_display.current / self.progress_display.total * 100
|
154
|
+
)
|
157
155
|
status_parts.append(
|
158
|
-
f"Steps: {self.progress_display.current:3d}/{self.progress_display.total} {
|
156
|
+
f"Steps: {self.progress_display.current:3d}/{self.progress_display.total} {progress_percent:3d}%"
|
159
157
|
)
|
160
158
|
|
161
159
|
# Add sample progress
|
@@ -187,7 +185,7 @@ class PlainTaskDisplay(TaskDisplay):
|
|
187
185
|
# Print on new line
|
188
186
|
print(" | ".join(status_parts))
|
189
187
|
|
190
|
-
self.last_progress =
|
188
|
+
self.last_progress = self.progress_display.current
|
191
189
|
|
192
190
|
def sample_complete(self, complete: int, total: int) -> None:
|
193
191
|
self.samples_complete = complete
|
@@ -13,7 +13,6 @@ from typing import (
|
|
13
13
|
|
14
14
|
import rich
|
15
15
|
from rich.console import Console
|
16
|
-
from rich.text import Text
|
17
16
|
from textual.app import App, ComposeResult
|
18
17
|
from textual.binding import Binding, BindingType
|
19
18
|
from textual.css.query import NoMatches
|
@@ -316,9 +315,9 @@ class TaskScreenApp(App[TR]):
|
|
316
315
|
|
317
316
|
def set_unread(unread: int | None) -> None:
|
318
317
|
if unread is not None:
|
319
|
-
console_tab.label =
|
318
|
+
console_tab.label = f"Console ({unread})" # type: ignore[assignment]
|
320
319
|
else:
|
321
|
-
console_tab.label =
|
320
|
+
console_tab.label = "Console" # type: ignore[assignment]
|
322
321
|
|
323
322
|
self.watch(console_view, "unread", set_unread)
|
324
323
|
|
@@ -385,7 +384,7 @@ class TaskScreenApp(App[TR]):
|
|
385
384
|
def set_title(self, title: str) -> None:
|
386
385
|
tabs = self.app.query_one(TabbedContent)
|
387
386
|
tab = tabs.get_tab(self.tab_id)
|
388
|
-
tab.label =
|
387
|
+
tab.label = title # type: ignore[assignment]
|
389
388
|
|
390
389
|
def activate(self) -> None:
|
391
390
|
# show the tab
|
@@ -6,6 +6,7 @@ from rich.table import Table
|
|
6
6
|
from rich.text import Text
|
7
7
|
from textual.app import ComposeResult
|
8
8
|
from textual.containers import Horizontal, HorizontalGroup, Vertical, VerticalGroup
|
9
|
+
from textual.css.query import NoMatches
|
9
10
|
from textual.reactive import reactive
|
10
11
|
from textual.widget import Widget
|
11
12
|
from textual.widgets import (
|
@@ -61,7 +62,10 @@ class SamplesView(Widget):
|
|
61
62
|
)
|
62
63
|
|
63
64
|
async def notify_active(self, active: bool) -> None:
|
64
|
-
|
65
|
+
try:
|
66
|
+
await self.query_one(TranscriptView).notify_active(active)
|
67
|
+
except NoMatches:
|
68
|
+
pass
|
65
69
|
|
66
70
|
def set_samples(self, samples: list[ActiveSample]) -> None:
|
67
71
|
# throttle to no more than 1 second per 100 samples
|
@@ -408,6 +412,16 @@ class SampleToolbar(Horizontal):
|
|
408
412
|
PENDING_STATUS = "pending_status"
|
409
413
|
PENDING_CAPTION = "pending_caption"
|
410
414
|
|
415
|
+
TIMEOUT_TOOL_CALL_ENABLED = (
|
416
|
+
"Cancel the tool call and report a timeout to the model."
|
417
|
+
)
|
418
|
+
TIMEOUT_TOOL_CALL_DISABLED = "Cancelling tool call..."
|
419
|
+
CANCEL_SCORE_OUTPUT_ENABLED = (
|
420
|
+
"Cancel the sample and score whatever output has been generated so far."
|
421
|
+
)
|
422
|
+
CANCEL_RAISE_ERROR_ENABLED = "Cancel the sample and raise an error"
|
423
|
+
CANCEL_DISABLED = "Cancelling sample..."
|
424
|
+
|
411
425
|
DEFAULT_CSS = f"""
|
412
426
|
SampleToolbar {{
|
413
427
|
grid-size: 5 1;
|
@@ -445,18 +459,18 @@ class SampleToolbar(Horizontal):
|
|
445
459
|
yield Button(
|
446
460
|
Text("Timeout Tool"),
|
447
461
|
id=self.TIMEOUT_TOOL_CALL,
|
448
|
-
tooltip=
|
462
|
+
tooltip=self.TIMEOUT_TOOL_CALL_ENABLED,
|
449
463
|
)
|
450
464
|
yield Horizontal()
|
451
465
|
yield Button(
|
452
466
|
Text("Cancel (Score)"),
|
453
467
|
id=self.CANCEL_SCORE_OUTPUT,
|
454
|
-
tooltip=
|
468
|
+
tooltip=self.CANCEL_SCORE_OUTPUT_ENABLED,
|
455
469
|
)
|
456
470
|
yield Button(
|
457
471
|
Text("Cancel (Error)"),
|
458
472
|
id=self.CANCEL_RAISE_ERROR,
|
459
|
-
tooltip=
|
473
|
+
tooltip=self.CANCEL_RAISE_ERROR_ENABLED,
|
460
474
|
)
|
461
475
|
|
462
476
|
def on_mount(self) -> None:
|
@@ -475,14 +489,26 @@ class SampleToolbar(Horizontal):
|
|
475
489
|
)
|
476
490
|
if isinstance(last_event, ToolEvent):
|
477
491
|
last_event._cancel()
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
self.
|
492
|
+
event.button.disabled = True
|
493
|
+
event.button.tooltip = self.TIMEOUT_TOOL_CALL_DISABLED
|
494
|
+
else:
|
495
|
+
if event.button.id == self.CANCEL_SCORE_OUTPUT:
|
496
|
+
self.sample.interrupt("score")
|
497
|
+
elif event.button.id == self.CANCEL_RAISE_ERROR:
|
498
|
+
self.sample.interrupt("error")
|
499
|
+
cancel_score_output = self.query_one("#" + self.CANCEL_SCORE_OUTPUT)
|
500
|
+
cancel_score_output.disabled = True
|
501
|
+
cancel_score_output.tooltip = self.CANCEL_DISABLED
|
502
|
+
cancel_with_error = self.query_one("#" + self.CANCEL_RAISE_ERROR)
|
503
|
+
cancel_with_error.disabled = True
|
504
|
+
cancel_with_error.tooltip = self.CANCEL_DISABLED
|
482
505
|
|
483
506
|
async def sync_sample(self, sample: ActiveSample | None) -> None:
|
484
507
|
from inspect_ai.log._transcript import ModelEvent
|
485
508
|
|
509
|
+
# is it a new sample?
|
510
|
+
new_sample = sample != self.sample
|
511
|
+
|
486
512
|
# track the sample
|
487
513
|
self.sample = sample
|
488
514
|
|
@@ -499,6 +525,13 @@ class SampleToolbar(Horizontal):
|
|
499
525
|
cancel_score_output.display = True
|
500
526
|
cancel_with_error.display = not sample.fails_on_error
|
501
527
|
|
528
|
+
# if its a new sample then reset enabled states
|
529
|
+
if new_sample:
|
530
|
+
cancel_score_output.disabled = False
|
531
|
+
cancel_score_output.tooltip = self.CANCEL_SCORE_OUTPUT_ENABLED
|
532
|
+
cancel_with_error.disabled = False
|
533
|
+
cancel_with_error.tooltip = self.CANCEL_RAISE_ERROR_ENABLED
|
534
|
+
|
502
535
|
# if we have a pending event then start the clock and show pending status
|
503
536
|
last_event = (
|
504
537
|
sample.transcript.events[-1]
|
@@ -520,6 +553,8 @@ class SampleToolbar(Horizontal):
|
|
520
553
|
)
|
521
554
|
|
522
555
|
timeout_tool.display = isinstance(last_event, ToolEvent)
|
556
|
+
timeout_tool.disabled = False
|
557
|
+
timeout_tool.tooltip = self.TIMEOUT_TOOL_CALL_ENABLED
|
523
558
|
|
524
559
|
clock.start(last_event.timestamp.timestamp())
|
525
560
|
else:
|
inspect_ai/_util/logger.py
CHANGED
@@ -90,6 +90,10 @@ class LogHandler(RichHandler):
|
|
90
90
|
if "Event loop is closed" in record.getMessage():
|
91
91
|
return
|
92
92
|
|
93
|
+
# skip google-genai AFC message
|
94
|
+
if "AFC is enabled with max remote calls" in record.getMessage():
|
95
|
+
return
|
96
|
+
|
93
97
|
# write to stderr if we are at or above the threshold
|
94
98
|
if record.levelno >= self.display_level:
|
95
99
|
super().emit(record)
|
inspect_ai/_util/text.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1
|
+
import random
|
1
2
|
import re
|
2
3
|
import string
|
3
4
|
from logging import getLogger
|
4
|
-
from typing import NamedTuple
|
5
|
+
from typing import List, NamedTuple
|
5
6
|
|
6
7
|
logger = getLogger(__name__)
|
7
8
|
|
@@ -131,3 +132,289 @@ def truncate(text: str, length: int, overflow: str = "...", pad: bool = True) ->
|
|
131
132
|
truncated = text[: length - overflow_length] + overflow
|
132
133
|
|
133
134
|
return truncated
|
135
|
+
|
136
|
+
|
137
|
+
def generate_large_text(target_tokens: int) -> str:
|
138
|
+
"""Generate a large amount of text with approximately the target number of tokens"""
|
139
|
+
generated_text = []
|
140
|
+
estimated_tokens = 0
|
141
|
+
|
142
|
+
while estimated_tokens < target_tokens:
|
143
|
+
sentence = generate_sentence()
|
144
|
+
|
145
|
+
# Add paragraph breaks occasionally
|
146
|
+
if random.random() < 0.1:
|
147
|
+
sentence += "\n\n"
|
148
|
+
|
149
|
+
generated_text.append(sentence)
|
150
|
+
|
151
|
+
# Rough estimate of tokens (words + punctuation)
|
152
|
+
estimated_tokens += len(sentence.split()) + 2
|
153
|
+
|
154
|
+
return " ".join(generated_text)
|
155
|
+
|
156
|
+
|
157
|
+
def generate_sentence() -> str:
|
158
|
+
"""Generate a random sentence using predefined templates"""
|
159
|
+
adjectives, nouns, verbs = create_word_lists()
|
160
|
+
|
161
|
+
templates = [
|
162
|
+
f"The {random.choice(adjectives)} {random.choice(nouns)} {random.choice(verbs)} the {random.choice(adjectives)} {random.choice(nouns)}.",
|
163
|
+
f"A {random.choice(adjectives)} {random.choice(nouns)} {random.choice(verbs)} near the {random.choice(nouns)}.",
|
164
|
+
f"In the {random.choice(adjectives)} {random.choice(nouns)}, the {random.choice(nouns)} {random.choice(verbs)} {random.choice(adjectives)}.",
|
165
|
+
f"When the {random.choice(nouns)} {random.choice(verbs)}, a {random.choice(adjectives)} {random.choice(nouns)} {random.choice(verbs)}.",
|
166
|
+
f"The {random.choice(nouns)} {random.choice(verbs)} while the {random.choice(adjectives)} {random.choice(nouns)} {random.choice(verbs)}.",
|
167
|
+
]
|
168
|
+
|
169
|
+
return random.choice(templates)
|
170
|
+
|
171
|
+
|
172
|
+
def create_word_lists() -> tuple[List[str], List[str], List[str]]:
|
173
|
+
"""Create basic word lists for sentence generation"""
|
174
|
+
# Common adjectives
|
175
|
+
adjectives = [
|
176
|
+
"red",
|
177
|
+
"blue",
|
178
|
+
"green",
|
179
|
+
"dark",
|
180
|
+
"bright",
|
181
|
+
"quiet",
|
182
|
+
"loud",
|
183
|
+
"small",
|
184
|
+
"large",
|
185
|
+
"quick",
|
186
|
+
"slow",
|
187
|
+
"happy",
|
188
|
+
"sad",
|
189
|
+
"clever",
|
190
|
+
"wise",
|
191
|
+
"ancient",
|
192
|
+
"modern",
|
193
|
+
"complex",
|
194
|
+
"simple",
|
195
|
+
"elegant",
|
196
|
+
"rough",
|
197
|
+
"smooth",
|
198
|
+
"sharp",
|
199
|
+
"dull",
|
200
|
+
"fresh",
|
201
|
+
"stale",
|
202
|
+
"clean",
|
203
|
+
"dirty",
|
204
|
+
"heavy",
|
205
|
+
"light",
|
206
|
+
"hot",
|
207
|
+
"cold",
|
208
|
+
"dry",
|
209
|
+
"wet",
|
210
|
+
"rich",
|
211
|
+
"poor",
|
212
|
+
"thick",
|
213
|
+
"thin",
|
214
|
+
"strong",
|
215
|
+
"weak",
|
216
|
+
"early",
|
217
|
+
"late",
|
218
|
+
"young",
|
219
|
+
"old",
|
220
|
+
"good",
|
221
|
+
"bad",
|
222
|
+
"high",
|
223
|
+
"low",
|
224
|
+
"long",
|
225
|
+
"short",
|
226
|
+
"deep",
|
227
|
+
"shallow",
|
228
|
+
"hard",
|
229
|
+
"soft",
|
230
|
+
"near",
|
231
|
+
"far",
|
232
|
+
"wide",
|
233
|
+
"narrow",
|
234
|
+
"big",
|
235
|
+
"little",
|
236
|
+
"fast",
|
237
|
+
"slow",
|
238
|
+
"busy",
|
239
|
+
"lazy",
|
240
|
+
"new",
|
241
|
+
"old",
|
242
|
+
"full",
|
243
|
+
"empty",
|
244
|
+
"loud",
|
245
|
+
"quiet",
|
246
|
+
"sweet",
|
247
|
+
"sour",
|
248
|
+
"brave",
|
249
|
+
"scared",
|
250
|
+
]
|
251
|
+
|
252
|
+
# Common nouns
|
253
|
+
nouns = [
|
254
|
+
"time",
|
255
|
+
"person",
|
256
|
+
"year",
|
257
|
+
"way",
|
258
|
+
"day",
|
259
|
+
"thing",
|
260
|
+
"man",
|
261
|
+
"world",
|
262
|
+
"life",
|
263
|
+
"hand",
|
264
|
+
"part",
|
265
|
+
"child",
|
266
|
+
"eye",
|
267
|
+
"woman",
|
268
|
+
"place",
|
269
|
+
"work",
|
270
|
+
"week",
|
271
|
+
"case",
|
272
|
+
"point",
|
273
|
+
"group",
|
274
|
+
"number",
|
275
|
+
"room",
|
276
|
+
"fact",
|
277
|
+
"idea",
|
278
|
+
"water",
|
279
|
+
"money",
|
280
|
+
"month",
|
281
|
+
"book",
|
282
|
+
"line",
|
283
|
+
"city",
|
284
|
+
"business",
|
285
|
+
"night",
|
286
|
+
"question",
|
287
|
+
"story",
|
288
|
+
"job",
|
289
|
+
"word",
|
290
|
+
"house",
|
291
|
+
"power",
|
292
|
+
"game",
|
293
|
+
"country",
|
294
|
+
"plant",
|
295
|
+
"animal",
|
296
|
+
"tree",
|
297
|
+
"stone",
|
298
|
+
"river",
|
299
|
+
"fire",
|
300
|
+
"problem",
|
301
|
+
"theory",
|
302
|
+
"street",
|
303
|
+
"family",
|
304
|
+
"history",
|
305
|
+
"mind",
|
306
|
+
"car",
|
307
|
+
"music",
|
308
|
+
"art",
|
309
|
+
"nation",
|
310
|
+
"science",
|
311
|
+
"nature",
|
312
|
+
"truth",
|
313
|
+
"peace",
|
314
|
+
"voice",
|
315
|
+
"class",
|
316
|
+
"paper",
|
317
|
+
"space",
|
318
|
+
"ground",
|
319
|
+
"market",
|
320
|
+
"court",
|
321
|
+
"force",
|
322
|
+
"price",
|
323
|
+
"action",
|
324
|
+
"reason",
|
325
|
+
"love",
|
326
|
+
"law",
|
327
|
+
"bird",
|
328
|
+
"literature",
|
329
|
+
"knowledge",
|
330
|
+
"society",
|
331
|
+
"valley",
|
332
|
+
"ocean",
|
333
|
+
"machine",
|
334
|
+
"energy",
|
335
|
+
"metal",
|
336
|
+
"mountain",
|
337
|
+
]
|
338
|
+
|
339
|
+
# Common verbs (present tense)
|
340
|
+
verbs = [
|
341
|
+
"run",
|
342
|
+
"walk",
|
343
|
+
"jump",
|
344
|
+
"sing",
|
345
|
+
"dance",
|
346
|
+
"write",
|
347
|
+
"read",
|
348
|
+
"speak",
|
349
|
+
"listen",
|
350
|
+
"watch",
|
351
|
+
"think",
|
352
|
+
"grow",
|
353
|
+
"live",
|
354
|
+
"play",
|
355
|
+
"work",
|
356
|
+
"move",
|
357
|
+
"stop",
|
358
|
+
"start",
|
359
|
+
"create",
|
360
|
+
"destroy",
|
361
|
+
"build",
|
362
|
+
"break",
|
363
|
+
"push",
|
364
|
+
"pull",
|
365
|
+
"open",
|
366
|
+
"close",
|
367
|
+
"rise",
|
368
|
+
"fall",
|
369
|
+
"increase",
|
370
|
+
"decrease",
|
371
|
+
"begin",
|
372
|
+
"end",
|
373
|
+
"love",
|
374
|
+
"hate",
|
375
|
+
"help",
|
376
|
+
"hurt",
|
377
|
+
"make",
|
378
|
+
"take",
|
379
|
+
"give",
|
380
|
+
"receive",
|
381
|
+
"buy",
|
382
|
+
"sell",
|
383
|
+
"eat",
|
384
|
+
"drink",
|
385
|
+
"sleep",
|
386
|
+
"wake",
|
387
|
+
"laugh",
|
388
|
+
"cry",
|
389
|
+
"learn",
|
390
|
+
"teach",
|
391
|
+
"change",
|
392
|
+
"stay",
|
393
|
+
"come",
|
394
|
+
"go",
|
395
|
+
"arrive",
|
396
|
+
"leave",
|
397
|
+
"enter",
|
398
|
+
"exit",
|
399
|
+
"succeed",
|
400
|
+
"fail",
|
401
|
+
"win",
|
402
|
+
"lose",
|
403
|
+
"fight",
|
404
|
+
"defend",
|
405
|
+
"attack",
|
406
|
+
"protect",
|
407
|
+
"save",
|
408
|
+
"waste",
|
409
|
+
"gather",
|
410
|
+
"scatter",
|
411
|
+
"collect",
|
412
|
+
"distribute",
|
413
|
+
"join",
|
414
|
+
"separate",
|
415
|
+
"unite",
|
416
|
+
"divide",
|
417
|
+
"share",
|
418
|
+
]
|
419
|
+
|
420
|
+
return adjectives, nouns, verbs
|
@@ -14470,7 +14470,7 @@ var require_assets = __commonJS({
|
|
14470
14470
|
const value2 = score2[key2];
|
14471
14471
|
const formattedValue = value2 && isNumeric(value2) ? formatPrettyDecimal(
|
14472
14472
|
typeof value2 === "number" ? value2 : parseFloat(value2 === true ? "1" : value2)
|
14473
|
-
) : value2;
|
14473
|
+
) : String(value2);
|
14474
14474
|
scores2.push(
|
14475
14475
|
/* @__PURE__ */ jsxRuntimeExports.jsxs(
|
14476
14476
|
"div",
|
inspect_ai/log/_samples.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1
|
-
import asyncio
|
2
1
|
import contextlib
|
3
2
|
from contextvars import ContextVar
|
4
3
|
from datetime import datetime
|
@@ -43,7 +42,6 @@ class ActiveSample:
|
|
43
42
|
self.total_tokens = 0
|
44
43
|
self.transcript = transcript
|
45
44
|
self.sandboxes = sandboxes
|
46
|
-
self._sample_task = asyncio.current_task()
|
47
45
|
self._interrupt_action: Literal["score", "error"] | None = None
|
48
46
|
|
49
47
|
@property
|
@@ -60,8 +58,6 @@ class ActiveSample:
|
|
60
58
|
|
61
59
|
def interrupt(self, action: Literal["score", "error"]) -> None:
|
62
60
|
self._interrupt_action = action
|
63
|
-
assert self._sample_task
|
64
|
-
self._sample_task.cancel()
|
65
61
|
|
66
62
|
@property
|
67
63
|
def interrupt_action(self) -> Literal["score", "error"] | None:
|
inspect_ai/model/_model.py
CHANGED
@@ -23,6 +23,7 @@ from tenacity import (
|
|
23
23
|
from inspect_ai._util.constants import DEFAULT_MAX_CONNECTIONS
|
24
24
|
from inspect_ai._util.content import Content, ContentImage, ContentText
|
25
25
|
from inspect_ai._util.hooks import init_hooks, override_api_key, send_telemetry
|
26
|
+
from inspect_ai._util.interrupt import check_sample_interrupt
|
26
27
|
from inspect_ai._util.platform import platform_init
|
27
28
|
from inspect_ai._util.registry import (
|
28
29
|
RegistryInfo,
|
@@ -390,6 +391,8 @@ class Model:
|
|
390
391
|
before_sleep=functools.partial(log_rate_limit_retry, self.api.model_name),
|
391
392
|
)
|
392
393
|
async def generate() -> ModelOutput:
|
394
|
+
check_sample_interrupt()
|
395
|
+
|
393
396
|
if cache:
|
394
397
|
if isinstance(cache, CachePolicy):
|
395
398
|
policy = cache
|