inspect-ai 0.3.84__py3-none-any.whl → 0.3.86__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
- inspect_ai/_eval/eval.py +9 -6
- inspect_ai/_eval/task/log.py +1 -0
- inspect_ai/agent/_agent.py +1 -5
- inspect_ai/log/_log.py +3 -0
- inspect_ai/log/_recorders/buffer/database.py +19 -11
- inspect_ai/model/_openai.py +2 -2
- inspect_ai/model/_providers/openai.py +3 -2
- inspect_ai/model/_providers/providers.py +0 -22
- inspect_ai/model/_providers/together.py +2 -2
- {inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/METADATA +1 -2
- {inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/RECORD +15 -16
- inspect_ai/model/_providers/goodfire.py +0 -253
- {inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/top_level.txt +0 -0
inspect_ai/_eval/eval.py
CHANGED
@@ -43,7 +43,7 @@ from inspect_ai.model import (
     GenerateConfigArgs,
     Model,
 )
-from inspect_ai.model._model import init_active_model, resolve_models
+from inspect_ai.model._model import get_model, init_active_model, resolve_models
 from inspect_ai.scorer._reducer import reducer_log_names
 from inspect_ai.solver._chain import chain
 from inspect_ai.solver._solver import Solver, SolverSpec
@@ -751,10 +751,15 @@ async def eval_retry_async(
         else None
     )

+    # resolve the model
+    model = get_model(
+        model=eval_log.eval.model,
+        config=eval_log.eval.model_generate_config,
+        base_url=eval_log.eval.model_base_url,
+        **eval_log.eval.model_args,
+    )
+
     # collect the rest of the params we need for the eval
-    model = eval_log.eval.model
-    model_base_url = eval_log.eval.model_base_url
-    model_args = eval_log.eval.model_args
     task_args = eval_log.eval.task_args
     tags = eval_log.eval.tags
     limit = eval_log.eval.config.limit
@@ -813,8 +818,6 @@ async def eval_retry_async(
             id=task_id, task=task, task_args=task_args, model=None, log=eval_log
         ),
         model=model,
-        model_base_url=model_base_url,
-        model_args=model_args,
         task_args=task_args,
         sandbox=eval_log.eval.sandbox,
         sandbox_cleanup=sandbox_cleanup,
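The retry path now reconstructs a full `Model` object from fields persisted on the eval log header, instead of threading the bare model name plus separate base URL and args through to `eval_async`. A minimal sketch of the same call, with stand-in values for the `eval_log.eval.*` fields (the model name and config here are illustrative; `mockllm/model` is inspect_ai's credential-free test provider):

```python
# Sketch: rebuild the model a retried eval should use from persisted log fields.
from inspect_ai.model import GenerateConfig, get_model

stored_model = "mockllm/model"                   # eval_log.eval.model
stored_config = GenerateConfig(temperature=0.0)  # eval_log.eval.model_generate_config
stored_base_url = None                           # eval_log.eval.model_base_url
stored_model_args: dict = {}                     # eval_log.eval.model_args

model = get_model(
    model=stored_model,
    config=stored_config,
    base_url=stored_base_url,
    **stored_model_args,
)
```

This is also why the `model_base_url` and `model_args` keyword arguments disappear from the call in the third hunk: those values are now consumed by `get_model` up front.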
inspect_ai/_eval/task/log.py
CHANGED (+1 -0; hunk body not included in this extract)
inspect_ai/agent/_agent.py
CHANGED
@@ -225,16 +225,12 @@ def agent_with(
     name = name or info.name
     description = description or info.metadata.get(AGENT_DESCRIPTION, None)

-    # if the name is null then raise
-    if name is None:
-        raise ValueError("You must provide a name to agent_with")
-
     # now set registry info
     set_registry_info(
         agent,
         RegistryInfo(
             type="agent",
-            name=name,
+            name=name or "agent",
             metadata={AGENT_DESCRIPTION: description}
             if description is not None
             else {},
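With the `ValueError` removed, `agent_with` no longer fails when neither the caller nor the registry supplies a name; the registry entry simply falls back to the literal `"agent"`. A standalone sketch of the new resolution logic (not the library code verbatim):

```python
# Sketch of the name-resolution change in agent_with: a missing name now
# degrades to a default instead of raising.
def resolve_agent_name(name: str | None, registry_name: str | None) -> str:
    name = name or registry_name
    return name or "agent"  # 0.3.84 raised ValueError here when name was None

assert resolve_agent_name("critic", None) == "critic"
assert resolve_agent_name(None, "grader") == "grader"
assert resolve_agent_name(None, None) == "agent"
```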
inspect_ai/log/_log.py
CHANGED
@@ -599,6 +599,9 @@ class EvalSpec(BaseModel):
     model: str
     """Model used for eval."""

+    model_generate_config: GenerateConfig = Field(default_factory=GenerateConfig)
+    """Generate config specified for model instance."""
+
     model_base_url: str | None = Field(default=None)
     """Optional override of model base url"""

inspect_ai/log/_recorders/buffer/database.py
CHANGED
@@ -199,28 +199,36 @@ class SampleBufferDatabase(SampleBuffer):
     )

     def remove_samples(self, samples: list[tuple[str | int, int]]) -> None:
+        # short circuit no samples
+        if len(samples) == 0:
+            return
+
         with self._get_connection(write=True) as conn:
             cursor = conn.cursor()
             try:
-                #
-
-
-                [f"('{sid}', {epoch})" for sid, epoch in samples]
+                # Build a query using individual column comparisons instead of row values
+                placeholders = " OR ".join(
+                    ["(sample_id=? AND sample_epoch=?)" for _ in samples]
                 )

-                #
+                # Flatten parameters for binding
+                parameters = [item for tup in samples for item in tup]
+
+                # Delete associated events first
                 events_query = f"""
                     DELETE FROM events
-                    WHERE
+                    WHERE {placeholders}
                 """
-                cursor.execute(events_query)
+                cursor.execute(events_query, parameters)
+
+                # Then delete the samples using the same approach
+                placeholders = " OR ".join(["(id=? AND epoch=?)" for _ in samples])

-                # Then delete the samples
                 samples_query = f"""
                     DELETE FROM samples
-                    WHERE
+                    WHERE {placeholders}
                 """
-                cursor.execute(samples_query)
+                cursor.execute(samples_query, parameters)
             finally:
                 cursor.close()

@@ -259,7 +267,7 @@ class SampleBufferDatabase(SampleBuffer):

         # fetch data
         return Samples(
-            samples=list(self._get_samples(conn)),
+            samples=list(self._get_samples(conn, True)),
             metrics=task_data.metrics,
             refresh=self.update_interval,
             etag=str(task_data.version),
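The `remove_samples` rewrite replaces string-interpolated values with driver-bound parameters: one `(col_a=? AND col_b=?)` clause per `(id, epoch)` pair, OR-ed together, with the values flattened into the `execute()` parameter list. Per the hunk's own comment, the individual column comparisons also avoid SQL row-value syntax, which older SQLite builds lack. A self-contained sketch of the pattern (the table and data are illustrative, not the buffer database schema):

```python
# Demonstrates the OR-of-pairs parameter binding adopted by remove_samples.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE samples (id TEXT, epoch INTEGER)")
conn.executemany(
    "INSERT INTO samples VALUES (?, ?)", [("a", 1), ("a", 2), ("b", 1)]
)

to_remove = [("a", 1), ("b", 1)]
placeholders = " OR ".join("(id=? AND epoch=?)" for _ in to_remove)
parameters = [value for pair in to_remove for value in pair]

# values travel as bound parameters rather than as text spliced into the SQL,
# unlike the old f"('{sid}', {epoch})" string building this hunk removes
conn.execute(f"DELETE FROM samples WHERE {placeholders}", parameters)
assert conn.execute("SELECT COUNT(*) FROM samples").fetchone()[0] == 1
```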
inspect_ai/model/_openai.py
CHANGED
@@ -3,7 +3,7 @@ import re
 from copy import copy
 from typing import Literal

-from openai import
+from openai import APIStatusError, OpenAIError
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionAssistantMessageParam,
@@ -518,7 +518,7 @@ def chat_choices_from_openai(


 def openai_handle_bad_request(
-    model_name: str, e:
+    model_name: str, e: APIStatusError
 ) -> ModelOutput | Exception:
     # extract message
     if isinstance(e.body, dict) and "message" in e.body.keys():
inspect_ai/model/_providers/openai.py
CHANGED
@@ -13,6 +13,7 @@ from openai import (
     AsyncOpenAI,
     BadRequestError,
     RateLimitError,
+    UnprocessableEntityError,
 )
 from openai._types import NOT_GIVEN
 from openai.types.chat import ChatCompletion
@@ -295,13 +296,13 @@ class OpenAIAPI(ModelAPI):
                     else None
                 ),
             ), model_call()
-        except BadRequestError as e:
+        except (BadRequestError, UnprocessableEntityError) as e:
             return self.handle_bad_request(e), model_call()

     def on_response(self, response: dict[str, Any]) -> None:
         pass

-    def handle_bad_request(self, ex:
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
         return openai_handle_bad_request(self.model_name, ex)

     def _chat_choices_from_response(
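These OpenAI-side changes fit together: in the openai v1 SDK both `BadRequestError` (HTTP 400) and `UnprocessableEntityError` (HTTP 422) subclass `APIStatusError`, so widening the handler signatures to `APIStatusError` lets one code path serve both exceptions the provider now catches. A sketch of that shape (the handler body below is illustrative, not the provider's):

```python
# Assumes openai>=1.x, where APIStatusError is the common base of the
# per-status exceptions and exposes .status_code and .body.
from openai import APIStatusError, BadRequestError, UnprocessableEntityError

def describe_bad_request(model_name: str, e: APIStatusError) -> str:
    # same message-extraction style as openai_handle_bad_request
    if isinstance(e.body, dict) and "message" in e.body:
        message = e.body["message"]
    else:
        message = str(e)
    return f"{model_name}: rejected with HTTP {e.status_code}: {message}"

# call sites catch narrowly while handlers accept the shared base:
#     except (BadRequestError, UnprocessableEntityError) as e:
#         return self.handle_bad_request(e)
```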
inspect_ai/model/_providers/providers.py
CHANGED
@@ -253,28 +253,6 @@ def none() -> type[ModelAPI]:
     return NoModel


-@modelapi("goodfire")
-def goodfire() -> type[ModelAPI]:
-    """Get the Goodfire API provider."""
-    FEATURE = "Goodfire API"
-    PACKAGE = "goodfire"
-    MIN_VERSION = "0.3.4"  # Support for newer Llama models and OpenAI compatibility
-
-    # verify we have the package
-    try:
-        import goodfire  # noqa: F401
-    except ImportError:
-        raise pip_dependency_error(FEATURE, [PACKAGE])
-
-    # verify version
-    verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
-
-    # in the clear
-    from .goodfire import GoodfireAPI
-
-    return GoodfireAPI
-
-
 def validate_openai_client(feature: str) -> None:
     FEATURE = feature
     PACKAGE = "openai"
inspect_ai/model/_providers/together.py
CHANGED
@@ -3,7 +3,7 @@ from json import dumps
 from typing import Any

 import httpx
-from openai import
+from openai import APIStatusError
 from openai.types.chat import (
     ChatCompletion,
 )
@@ -105,7 +105,7 @@ class TogetherAIAPI(OpenAIAPI):
         return DEFAULT_MAX_TOKENS

     @override
-    def handle_bad_request(self, ex:
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
         response = ex.response.json()
         if "error" in response and "message" in response.get("error"):
             content = response.get("error").get("message")
{inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: inspect_ai
-Version: 0.3.84
+Version: 0.3.86
 Summary: Framework for large language model evaluations
 Author: UK AI Security Institute
 License: MIT License
@@ -56,7 +56,6 @@ Requires-Dist: aioboto3; extra == "dev"
 Requires-Dist: azure-ai-inference; extra == "dev"
 Requires-Dist: google-cloud-aiplatform; extra == "dev"
 Requires-Dist: google-genai; extra == "dev"
-Requires-Dist: goodfire; extra == "dev"
 Requires-Dist: griffe; extra == "dev"
 Requires-Dist: groq; extra == "dev"
 Requires-Dist: ipython; extra == "dev"
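To confirm an installed copy matches this metadata, the stdlib `importlib.metadata` surfaces both fields changed here (this checks whatever wheel is actually installed in the environment):

```python
# Check the installed distribution against the METADATA diff above.
from importlib.metadata import requires, version

assert version("inspect_ai") == "0.3.86"
dev_requires = [r for r in (requires("inspect_ai") or []) if 'extra == "dev"' in r]
assert not any(r.startswith("goodfire") for r in dev_requires)  # dev extra dropped
```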
{inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/RECORD
CHANGED
@@ -46,7 +46,7 @@ inspect_ai/_display/textual/widgets/transcript.py,sha256=zaxlDixT6Fie0acAWBM9Hlt
 inspect_ai/_display/textual/widgets/vscode.py,sha256=YTXdIZ0fcf9XE2v3rWIfUTgnXFww8uKCo7skugQLIbs,1247
 inspect_ai/_eval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 inspect_ai/_eval/context.py,sha256=gWTjEEMVTJMJpCCKLRs4joZDkG00rzE7-HXZFyzSC_I,1283
-inspect_ai/_eval/eval.py,sha256=
+inspect_ai/_eval/eval.py,sha256=tfIYOJSNGNfZHl18XYqXZHIMKlCmUZ8gbqe7I0OZJII,40307
 inspect_ai/_eval/evalset.py,sha256=FnZBVi5hOt6f84PNYlFhkjb7N1lNgiQydQlernJZeW4,24005
 inspect_ai/_eval/list.py,sha256=VbZ-2EI6MqrXvCN7VTz21TQSoU5K5_Q0hqhxmj5A_m0,3744
 inspect_ai/_eval/loader.py,sha256=yCDrW5MhP6GT329hZ_gUm_eAMsCA9G7jb8sm45Pj-pw,24970
@@ -59,7 +59,7 @@ inspect_ai/_eval/task/epochs.py,sha256=Ci7T6CQniSOTChv5Im2dCdSDrP-5hq19rV6iJ2uBc
 inspect_ai/_eval/task/error.py,sha256=Vhqinfdf0eIrjn7kUY7-id8Kbdggr-fEFpAJeJrkJ1M,1244
 inspect_ai/_eval/task/generate.py,sha256=C9-S9ak4VFQO7QgtUbGjt8F4sTyXS5nekR3Mg_MPwmM,2511
 inspect_ai/_eval/task/images.py,sha256=nTzHizlyuPYumPH7gAOBSrNkTwTbAmZ7tKdzN7d_R2k,4035
-inspect_ai/_eval/task/log.py,sha256=
+inspect_ai/_eval/task/log.py,sha256=PD2ZrqtHY0zRyx7pB8L5v-txyaBRePs76cFu5Fb-vjE,11817
 inspect_ai/_eval/task/resolved.py,sha256=OCQc_0HmW_Vw8o1KisX0DCn-eOPkTbR1v_y_jEaAlhU,966
 inspect_ai/_eval/task/results.py,sha256=x4weYRK2XGowfBG3f2msOeZQ_pxh230HTlw6kps33jw,17925
 inspect_ai/_eval/task/run.py,sha256=RS2Qv3AythSkQL4fsgBFaXfyx2WDIZuFj9v6ifoRiYs,38714
@@ -450,7 +450,7 @@ inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx,sha256=s6jt1-5_Hrgz3_ysT1
 inspect_ai/_view/www/src/workspace/tabs/grouping.ts,sha256=6lvFzReQKQ_43S20xN4kfBJN2F7Tfs2VWeSMIuHxUAI,6187
 inspect_ai/_view/www/src/workspace/tabs/types.ts,sha256=Wa1Y4tZwYO_QJr0Tg9-5xJFztmcMYCODSm6JvdzMpDw,471
 inspect_ai/agent/__init__.py,sha256=nzL9TPAARSJVZRPogWHxZ-qJriXBGmFUM9DV4NRi21o,749
-inspect_ai/agent/_agent.py,sha256=
+inspect_ai/agent/_agent.py,sha256=5MXMrY5bsBQ4AI5y5rVaIYywp7JxQzUnhp-KqBIMF7I,7622
 inspect_ai/agent/_as_solver.py,sha256=_6H0L9JidC6JjpMaBRBAjIrgzE8GKEoygJjOC_JRoLQ,2340
 inspect_ai/agent/_as_tool.py,sha256=vT5hrcKfkyP90i4Ieuy_dx4cYsFKOMdPs-6x12cuqMk,4449
 inspect_ai/agent/_filter.py,sha256=qnT0HbT4edpDi0MwXY3Q3It2pzNRkTRXZDOqfCwMY6M,1234
@@ -510,7 +510,7 @@ inspect_ai/log/_bundle.py,sha256=5Uy-s64_SFokZ7WRzti9mD7yoKrd2sOzdvqKyahoiC4,804
 inspect_ai/log/_condense.py,sha256=OedMphK5Q2YPuY1cnoAM7tGsyVIU6Kwrv3oIeb3dFmY,10881
 inspect_ai/log/_convert.py,sha256=qn6q10Um2XV7dnK4nQargANa0bz6RFJPmaEMINv38cs,3467
 inspect_ai/log/_file.py,sha256=QjeVUegoCWVUv6CMsj0das_UpZZZMfnbvCQAKlFYGXE,17105
-inspect_ai/log/_log.py,sha256=
+inspect_ai/log/_log.py,sha256=KsssY2kGfuDHGIXOGJHN4bO1LXVs0f3XtqIUfA2R68A,25109
 inspect_ai/log/_message.py,sha256=QofM_JZF_x3k_5ta1uQzoN_VnMoUhXFnqWurIn9FXOY,1999
 inspect_ai/log/_retry.py,sha256=e7a2hjl3Ncl8b8sU7CsDpvK8DV0b1uSRLeokRX1mt34,2109
 inspect_ai/log/_samples.py,sha256=wPQlV1VR9djWaj37lLrjBprCabdAm4S2vFOsQTcd12U,4910
@@ -524,7 +524,7 @@ inspect_ai/log/_recorders/recorder.py,sha256=zDDpl2tktPjb6xk5kd4TyEMxkXZiLgXXpPi
 inspect_ai/log/_recorders/types.py,sha256=Aeo-U7FhmWQSvE_uz3fwUI7cqaSR-ZE_uRVu-1fBCgc,865
 inspect_ai/log/_recorders/buffer/__init__.py,sha256=6DsRdnNl-ic-xJmnBE5i45ZP3eB4yAta9wxi5WFcbqc,367
 inspect_ai/log/_recorders/buffer/buffer.py,sha256=rtLvaX7nSqNrWb-3CeSaOHwJgF1CzRgXFT_I1dDkM1k,945
-inspect_ai/log/_recorders/buffer/database.py,sha256=
+inspect_ai/log/_recorders/buffer/database.py,sha256=3yV8OlDsQ4zFQHNqe7aBAHwkUISW3zmaLBlD1OFj36w,22396
 inspect_ai/log/_recorders/buffer/filestore.py,sha256=S6RP-5zkOPSmy1hV2LCCbfwdX-YFZGuIEjfJuOWMjDQ,8274
 inspect_ai/log/_recorders/buffer/types.py,sha256=pTnPCZHbk9qF6yF-eNXHTa23cLH_FvP8dmfPJCFO15Q,2046
 inspect_ai/model/__init__.py,sha256=6Aa_HEU-rgxWPDaIRlE6KBdXY406x2LtcLeVtAxk-AI,2453
@@ -537,7 +537,7 @@ inspect_ai/model/_generate_config.py,sha256=_-kzw7LOl45baVkTjlfL1K1VLKGgNOOczH2H
 inspect_ai/model/_model.py,sha256=h4ASS2VuTZ_97145rLW202u6e7-mw4ENnnlBl0Vsbio,52127
 inspect_ai/model/_model_call.py,sha256=VJ8wnl9Y81JaiClBYM8eyt1jVb3n-yc6Dd88ofRiJDc,2234
 inspect_ai/model/_model_output.py,sha256=R5EAUPLc5RWymVb3le4cbqbNCZ9voTzg0U1j_e4I-yM,7768
-inspect_ai/model/_openai.py,sha256
+inspect_ai/model/_openai.py,sha256=-N_LhZR8-nrnCL8h9lklo_RrGNDR1SzMJ0tPafVuPXo,19380
 inspect_ai/model/_openai_computer_use.py,sha256=vbKkYLhqNuX16zuWfg5MaGp9H8URrPcLhKQ1pDsZtPo,5943
 inspect_ai/model/_openai_responses.py,sha256=bQWuVvJIkS8CqtoX9z1aRb1aky4TNbMngG2paB3wsrA,20179
 inspect_ai/model/_reasoning.py,sha256=qmR8WT6t_cb7NIsJOQHPyFZh2eLV0HmYxKo2vtvteQ4,929
@@ -547,7 +547,6 @@ inspect_ai/model/_providers/anthropic.py,sha256=PYxV0D_bt0Icp2wEWb6GMCpDb-uBFKYy
 inspect_ai/model/_providers/azureai.py,sha256=uXED_qmeyW1XAGBosbG7PJNk833RIeokKX3l_8O9gYA,14341
 inspect_ai/model/_providers/bedrock.py,sha256=rh8BvSUPWiFMh0TQwMYTlucfFrDKswtLhzozulrz7wE,24004
 inspect_ai/model/_providers/cloudflare.py,sha256=mWqBqc0zzf29UWz34biq8CxSu99a95YjpH_6A4na52g,4617
-inspect_ai/model/_providers/goodfire.py,sha256=J0nxGbF8lXBmc5YHBJCsZdF03mWT5SuWMb21d9ho3FM,8799
 inspect_ai/model/_providers/google.py,sha256=gcg8pvYAV5gYc4NXC5mLqFyuU7KuhyNrzdXIY57sYl8,28207
 inspect_ai/model/_providers/grok.py,sha256=dS88ueXiD-kHAFr0jCoTpTGLGa2VsUlB_TFP8L_2lBM,995
 inspect_ai/model/_providers/groq.py,sha256=mcRKu33e-mO5l06PGV6SjsildQd0XCti6QNXwwFWL7I,11246
@@ -557,12 +556,12 @@ inspect_ai/model/_providers/mistral.py,sha256=FbMPN_pw8LZal2iFGf5FX70ypuH3k44FUn
 inspect_ai/model/_providers/mockllm.py,sha256=gL9f-f5TOdE4a0GVENr3cOIIp2kv8zVXWPZ608rouGk,2440
 inspect_ai/model/_providers/none.py,sha256=6qLbZpHSoEZaaxFO7luieFjqig2Ju8Fu00DlRngAry8,935
 inspect_ai/model/_providers/ollama.py,sha256=mBPSxaEkiH_RnlHKqOyFBlXObQhc2dfjL-rCKrea5u8,675
-inspect_ai/model/_providers/openai.py,sha256=
+inspect_ai/model/_providers/openai.py,sha256=zJkhtiEQrmsuhfL7mpBPpOlYJ_WNraeyTkjYTelF0no,16535
 inspect_ai/model/_providers/openai_o1.py,sha256=k-Xm_Wzn1KHKL6Z1KTHg4CTTr8ybgiHvXkLiLdjP7Os,12926
 inspect_ai/model/_providers/openai_responses.py,sha256=YPXt8KQfIEiiTpvtoQECBoNQLDLbwBW_KhBfM8vEhJk,6324
 inspect_ai/model/_providers/openrouter.py,sha256=pDimDmm_4FzS4GZx0n9z8z717mQf3IQlgEy30huzpc4,4730
-inspect_ai/model/_providers/providers.py,sha256=
-inspect_ai/model/_providers/together.py,sha256=
+inspect_ai/model/_providers/providers.py,sha256=Sd2D9OcWkukuBcl_-KDfdpxMaAShv1JZhL5KfAM87CE,5817
+inspect_ai/model/_providers/together.py,sha256=Wh3G0vhKHq5ofx1otwXjJFhM98Ll70IbqBhUNNV2-rk,9743
 inspect_ai/model/_providers/vertex.py,sha256=60W7kgoA83GtKdMeJgNU2IAw0N0wTscg4YCcMPu2bwo,17185
 inspect_ai/model/_providers/vllm.py,sha256=UYjCCXzw2hGJHVC3oPl-u2EI4iAm8ZncoIfYp1QJkbQ,14238
 inspect_ai/model/_providers/util/__init__.py,sha256=d4T_qvXihTRd1zmQkNE3xUBlHCX8tOIbRK19EwU0fTs,717
@@ -693,9 +692,9 @@ inspect_ai/util/_sandbox/docker/internal.py,sha256=c8X8TLrBPOvsfnq5TkMlb_bzTALyc
 inspect_ai/util/_sandbox/docker/prereqs.py,sha256=0j6_OauBBnVlpBleADcZavIAAQZy4WewVjbRn9c0stg,3355
 inspect_ai/util/_sandbox/docker/service.py,sha256=hhHIWH1VDFLwehdGd19aUBD_VKfDO3GCPxpw1HSwVQk,2437
 inspect_ai/util/_sandbox/docker/util.py,sha256=EeInihCNXgUWxaqZ4dNOJd719kXL2_jr63QCoXn68vA,3154
-inspect_ai-0.3.
-inspect_ai-0.3.
-inspect_ai-0.3.
-inspect_ai-0.3.
-inspect_ai-0.3.
-inspect_ai-0.3.
+inspect_ai-0.3.86.dist-info/licenses/LICENSE,sha256=xZPCr8gTiFIerrA_DRpLAbw-UUftnLFsHxKeW-NTtq8,1081
+inspect_ai-0.3.86.dist-info/METADATA,sha256=tId3lj5ywe2A79iWACtXXP_aBeWfnbQQfAujxJGxaoc,4965
+inspect_ai-0.3.86.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+inspect_ai-0.3.86.dist-info/entry_points.txt,sha256=WGGLmzTzDWLzYfiyovSY6oEKuf-gqzSDNOb5V-hk3fM,54
+inspect_ai-0.3.86.dist-info/top_level.txt,sha256=Tp3za30CHXJEKLk8xLe9qGsW4pBzJpEIOMHOHNCXiVo,11
+inspect_ai-0.3.86.dist-info/RECORD,,
inspect_ai/model/_providers/goodfire.py
DELETED
@@ -1,253 +0,0 @@
-import os
-from typing import Any, List, Literal, get_args
-
-from goodfire import AsyncClient
-from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
-from goodfire.api.exceptions import (
-    InvalidRequestException,
-    RateLimitException,
-    ServerErrorException,
-)
-from goodfire.variants.variants import SUPPORTED_MODELS, Variant
-from typing_extensions import override
-
-from inspect_ai.tool._tool_choice import ToolChoice
-from inspect_ai.tool._tool_info import ToolInfo
-
-from .._chat_message import (
-    ChatMessage,
-    ChatMessageAssistant,
-    ChatMessageSystem,
-    ChatMessageTool,
-    ChatMessageUser,
-)
-from .._generate_config import GenerateConfig
-from .._model import ModelAPI
-from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-)
-from .util import environment_prerequisite_error, model_base_url
-
-# Constants
-GOODFIRE_API_KEY = "GOODFIRE_API_KEY"
-DEFAULT_BASE_URL = "https://api.goodfire.ai"
-DEFAULT_MAX_TOKENS = 4096
-DEFAULT_TEMPERATURE = 1.0  # Standard sampling temperature (baseline)
-DEFAULT_TOP_P = 1.0  # No nucleus sampling truncation (baseline)
-
-
-class GoodfireAPI(ModelAPI):
-    """Goodfire API provider.
-
-    This provider implements the Goodfire API for LLM inference. It supports:
-    - Chat completions with standard message formats
-    - Basic parameter controls (temperature, top_p, etc.)
-    - Usage statistics tracking
-    - Stop reason handling
-
-    Does not currently support:
-    - Tool calls
-    - Feature analysis
-    - Streaming responses
-
-    Known limitations:
-    - Limited role support (system/user/assistant only)
-    - Tool messages converted to user messages
-    """
-
-    client: AsyncClient
-    variant: Variant
-    model_args: dict[str, Any]
-
-    def __init__(
-        self,
-        model_name: str,
-        base_url: str | None = None,
-        api_key: str | None = None,
-        config: GenerateConfig = GenerateConfig(),
-        **model_args: Any,
-    ) -> None:
-        """Initialize the Goodfire API provider.
-
-        Args:
-            model_name: Name of the model to use
-            base_url: Optional custom API base URL
-            api_key: Optional API key (will check env vars if not provided)
-            config: Generation config options
-            **model_args: Additional arguments passed to the API
-        """
-        super().__init__(
-            model_name=model_name,
-            base_url=base_url,
-            api_key=api_key,
-            api_key_vars=[GOODFIRE_API_KEY],
-            config=config,
-        )
-
-        # resolve api_key
-        if not self.api_key:
-            self.api_key = os.environ.get(GOODFIRE_API_KEY)
-        if not self.api_key:
-            raise environment_prerequisite_error("Goodfire", GOODFIRE_API_KEY)
-
-        # Validate model name against supported models
-        supported_models = list(get_args(SUPPORTED_MODELS))
-        if self.model_name not in supported_models:
-            raise ValueError(
-                f"Model {self.model_name} not supported. Supported models: {supported_models}"
-            )
-
-        # Initialize client with minimal configuration
-        base_url_val = model_base_url(base_url, "GOODFIRE_BASE_URL")
-        assert isinstance(base_url_val, str) or base_url_val is None
-
-        # Store model args for use in generate
-        self.model_args = model_args
-
-        self.client = AsyncClient(
-            api_key=self.api_key,
-            base_url=base_url_val or DEFAULT_BASE_URL,
-        )
-
-        # Initialize variant directly with model name
-        self.variant = Variant(self.model_name)  # type: ignore
-
-    def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
-        """Convert an Inspect message to a Goodfire message format.
-
-        Args:
-            message: The message to convert
-
-        Returns:
-            The converted message in Goodfire format
-
-        Raises:
-            ValueError: If the message type is unknown
-        """
-        role: Literal["system", "user", "assistant"] = "user"
-        if isinstance(message, ChatMessageSystem):
-            role = "system"
-        elif isinstance(message, ChatMessageUser):
-            role = "user"
-        elif isinstance(message, ChatMessageAssistant):
-            role = "assistant"
-        elif isinstance(message, ChatMessageTool):
-            role = "user"  # Convert tool messages to user messages
-        else:
-            raise ValueError(f"Unknown message type: {type(message)}")
-
-        content = str(message.content)
-        if isinstance(message, ChatMessageTool):
-            content = f"Tool {message.function}: {content}"
-
-        return GoodfireChatMessage(role=role, content=content)
-
-    def handle_error(self, ex: Exception) -> ModelOutput | Exception:
-        """Handle only errors that need special treatment for retry logic or model limits."""
-        # Handle token/context length errors
-        if isinstance(ex, InvalidRequestException):
-            error_msg = str(ex).lower()
-            if "context length" in error_msg or "max tokens" in error_msg:
-                return ModelOutput.from_content(
-                    model=self.model_name,
-                    content=str(ex),
-                    stop_reason="model_length",
-                    error=error_msg,
-                )
-
-        # Let all other errors propagate
-        return ex
-
-    @override
-    def should_retry(self, ex: Exception) -> bool:
-        """Check if exception is due to rate limiting."""
-        return isinstance(ex, RateLimitException | ServerErrorException)
-
-    @override
-    def connection_key(self) -> str:
-        """Return key for connection pooling."""
-        return f"goodfire:{self.api_key}"
-
-    @override
-    def max_tokens(self) -> int | None:
-        """Return maximum tokens supported by model."""
-        return DEFAULT_MAX_TOKENS  # Let Goodfire's Variant handle model-specific limits
-
-    async def generate(
-        self,
-        input: List[ChatMessage],
-        tools: List[ToolInfo],
-        tool_choice: ToolChoice,
-        config: GenerateConfig,
-        *,
-        cache: bool = True,
-    ) -> tuple[ModelOutput | Exception, ModelCall]:
-        """Generate output from the model."""
-        # Convert messages and prepare request params
-        messages = [self._to_goodfire_message(msg) for msg in input]
-        # Build request parameters with type hints
-        params: dict[str, Any] = {
-            "model": self.variant.base_model,  # Use base_model instead of stringifying the Variant
-            "messages": messages,
-            "max_completion_tokens": int(config.max_tokens)
-            if config.max_tokens
-            else DEFAULT_MAX_TOKENS,
-            "stream": False,
-        }
-
-        # Add generation parameters from config if not in model_args
-        if "temperature" not in self.model_args and config.temperature is not None:
-            params["temperature"] = float(config.temperature)
-        elif "temperature" not in self.model_args:
-            params["temperature"] = DEFAULT_TEMPERATURE
-
-        if "top_p" not in self.model_args and config.top_p is not None:
-            params["top_p"] = float(config.top_p)
-        elif "top_p" not in self.model_args:
-            params["top_p"] = DEFAULT_TOP_P
-
-        # Add any additional model args (highest priority)
-        api_params = {
-            k: v
-            for k, v in self.model_args.items()
-            if k not in ["api_key", "base_url", "model_args"]
-        }
-        params.update(api_params)
-
-        try:
-            # Use native async client
-            response = await self.client.chat.completions.create(**params)
-            response_dict = response.model_dump()
-
-            output = ModelOutput(
-                model=self.model_name,
-                choices=[
-                    ChatCompletionChoice(
-                        message=ChatMessageAssistant(
-                            content=response_dict["choices"][0]["message"]["content"],
-                            model=self.model_name,
-                        ),
-                        stop_reason="stop",
-                    )
-                ],
-                usage=ModelUsage(**response_dict["usage"])
-                if "usage" in response_dict
-                else None,
-            )
-            model_call = ModelCall.create(request=params, response=response_dict)
-            return (output, model_call)
-        except Exception as ex:
-            result = self.handle_error(ex)
-            model_call = ModelCall.create(
-                request=params,
-                response={},  # Empty response for error case
-            )
-            return (result, model_call)
-
-    @property
-    def name(self) -> str:
-        """Get provider name."""
-        return "goodfire"
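Taken together with the earlier hunks, this deletion completes the removal of the Goodfire integration from the package: the `@modelapi("goodfire")` registration left providers.py, the `goodfire` dev dependency left METADATA, and the provider module itself is removed here.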
{inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/WHEEL
File without changes
{inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/entry_points.txt
File without changes
{inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/licenses/LICENSE
File without changes
{inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/top_level.txt
File without changes