inspect-ai 0.3.84__py3-none-any.whl → 0.3.86__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
inspect_ai/_eval/eval.py CHANGED
@@ -43,7 +43,7 @@ from inspect_ai.model import (
     GenerateConfigArgs,
     Model,
 )
-from inspect_ai.model._model import init_active_model, resolve_models
+from inspect_ai.model._model import get_model, init_active_model, resolve_models
 from inspect_ai.scorer._reducer import reducer_log_names
 from inspect_ai.solver._chain import chain
 from inspect_ai.solver._solver import Solver, SolverSpec
@@ -751,10 +751,15 @@ async def eval_retry_async(
         else None
     )

+    # resolve the model
+    model = get_model(
+        model=eval_log.eval.model,
+        config=eval_log.eval.model_generate_config,
+        base_url=eval_log.eval.model_base_url,
+        **eval_log.eval.model_args,
+    )
+
     # collect the rest of the params we need for the eval
-    model = eval_log.eval.model
-    model_base_url = eval_log.eval.model_base_url
-    model_args = eval_log.eval.model_args
     task_args = eval_log.eval.task_args
     tags = eval_log.eval.tags
     limit = eval_log.eval.config.limit
@@ -813,8 +818,6 @@ async def eval_retry_async(
             id=task_id, task=task, task_args=task_args, model=None, log=eval_log
         ),
         model=model,
-        model_base_url=model_base_url,
-        model_args=model_args,
         task_args=task_args,
         sandbox=eval_log.eval.sandbox,
         sandbox_cleanup=sandbox_cleanup,
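With this change, retry resolves a full Model instance up front via get_model, using the generate config now recorded in the eval log, rather than threading the raw model name, base URL, and args through to the eval. A minimal sketch of the same resolution outside the retry path (the model name and config values below are illustrative, not from the diff):

```python
from inspect_ai.model import GenerateConfig, get_model

# resolve a model the way eval_retry_async now does, from values that
# would normally come from eval_log.eval (hypothetical values here)
model = get_model(
    model="openai/gpt-4o",                   # eval_log.eval.model
    config=GenerateConfig(temperature=0.0),  # eval_log.eval.model_generate_config
    base_url=None,                           # eval_log.eval.model_base_url
)
```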
inspect_ai/_eval/task/log.py CHANGED
@@ -139,6 +139,7 @@ class TaskLogger:
             tags=tags,
             solver_args=solver.args if solver else None,
             model=str(ModelName(model)),
+            model_generate_config=model.config,
             model_base_url=model.api.base_url,
             dataset=EvalDataset(
                 name=dataset.name,
inspect_ai/agent/_agent.py CHANGED
@@ -225,16 +225,12 @@ def agent_with(
     name = name or info.name
     description = description or info.metadata.get(AGENT_DESCRIPTION, None)

-    # if the name is null then raise
-    if name is None:
-        raise ValueError("You must provide a name to agent_with")
-
     # now set registry info
     set_registry_info(
         agent,
         RegistryInfo(
             type="agent",
-            name=name,
+            name=name or "agent",
             metadata={AGENT_DESCRIPTION: description}
             if description is not None
             else {},
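agent_with previously raised a ValueError when neither the name argument nor the registry supplied a name; it now falls back to registering the agent as "agent". A minimal usage sketch, assuming the public inspect_ai.agent API (the agent itself is a hypothetical no-op):

```python
from inspect_ai.agent import Agent, AgentState, agent, agent_with

@agent
def passthrough() -> Agent:
    async def execute(state: AgentState) -> AgentState:
        # no-op agent: return the state unchanged
        return state
    return execute

# override the registered name/description of an existing agent; after
# this change, a missing name no longer raises but defaults to "agent"
described = agent_with(
    passthrough(), name="web_surfer", description="example no-op agent"
)
```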
inspect_ai/log/_log.py CHANGED
@@ -599,6 +599,9 @@ class EvalSpec(BaseModel):
     model: str
     """Model used for eval."""

+    model_generate_config: GenerateConfig = Field(default_factory=GenerateConfig)
+    """Generate config specified for model instance."""
+
     model_base_url: str | None = Field(default=None)
     """Optional override of model base url"""

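EvalSpec now records the model's generate config in the log header, which is what makes the retry-time get_model resolution above possible. A sketch of reading it back (the log path is hypothetical):

```python
from inspect_ai.log import read_eval_log

log = read_eval_log("logs/example-task.eval")  # hypothetical log file
print(log.eval.model)                  # model used for the eval
print(log.eval.model_generate_config)  # GenerateConfig recorded at eval time
```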
inspect_ai/log/_recorders/buffer/database.py CHANGED
@@ -199,28 +199,36 @@ class SampleBufferDatabase(SampleBuffer):
         )

     def remove_samples(self, samples: list[tuple[str | int, int]]) -> None:
+        # short circuit no samples
+        if len(samples) == 0:
+            return
+
         with self._get_connection(write=True) as conn:
             cursor = conn.cursor()
             try:
-                # Convert list of tuples into a string for SQL IN clause
-                # Format: (('id1', 1), ('id2', 2))
-                sample_conditions = ",".join(
-                    [f"('{sid}', {epoch})" for sid, epoch in samples]
+                # Build a query using individual column comparisons instead of row values
+                placeholders = " OR ".join(
+                    ["(sample_id=? AND sample_epoch=?)" for _ in samples]
                 )

-                # Delete associated events first due to foreign key constraint
+                # Flatten parameters for binding
+                parameters = [item for tup in samples for item in tup]
+
+                # Delete associated events first
                 events_query = f"""
                     DELETE FROM events
-                    WHERE (sample_id, sample_epoch) IN ({sample_conditions})
+                    WHERE {placeholders}
                 """
-                cursor.execute(events_query)
+                cursor.execute(events_query, parameters)
+
+                # Then delete the samples using the same approach
+                placeholders = " OR ".join(["(id=? AND epoch=?)" for _ in samples])

-                # Then delete the samples
                 samples_query = f"""
                     DELETE FROM samples
-                    WHERE (id, epoch) IN ({sample_conditions})
+                    WHERE {placeholders}
                 """
-                cursor.execute(samples_query)
+                cursor.execute(samples_query, parameters)
             finally:
                 cursor.close()

@@ -259,7 +267,7 @@ class SampleBufferDatabase(SampleBuffer):

             # fetch data
             return Samples(
-                samples=list(self._get_samples(conn)),
+                samples=list(self._get_samples(conn, True)),
                 metrics=task_data.metrics,
                 refresh=self.update_interval,
                 etag=str(task_data.version),
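The remove_samples rewrite above binds each (id, epoch) pair as flat query parameters joined with OR, instead of interpolating a row-value IN clause into the SQL string: row values like `(id, epoch) IN (...)` require SQLite 3.15+, and f-string interpolation of sample ids risks malformed or injected SQL. A standalone sketch of the pattern (the table shape is illustrative only):

```python
import sqlite3

samples: list[tuple[str, int]] = [("id1", 1), ("id2", 2)]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE samples (id TEXT, epoch INTEGER)")
conn.executemany("INSERT INTO samples VALUES (?, ?)", samples)

# one (col=? AND col=?) clause per pair, with parameters flattened to match
placeholders = " OR ".join("(id=? AND epoch=?)" for _ in samples)
parameters = [item for pair in samples for item in pair]
conn.execute(f"DELETE FROM samples WHERE {placeholders}", parameters)
assert conn.execute("SELECT COUNT(*) FROM samples").fetchone()[0] == 0
```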
inspect_ai/model/_openai.py CHANGED
@@ -3,7 +3,7 @@ import re
 from copy import copy
 from typing import Literal

-from openai import BadRequestError, OpenAIError
+from openai import APIStatusError, OpenAIError
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionAssistantMessageParam,
@@ -518,7 +518,7 @@ def chat_choices_from_openai(


 def openai_handle_bad_request(
-    model_name: str, e: BadRequestError
+    model_name: str, e: APIStatusError
 ) -> ModelOutput | Exception:
     # extract message
     if isinstance(e.body, dict) and "message" in e.body.keys():
inspect_ai/model/_providers/openai.py CHANGED
@@ -13,6 +13,7 @@ from openai import (
     AsyncOpenAI,
     BadRequestError,
     RateLimitError,
+    UnprocessableEntityError,
 )
 from openai._types import NOT_GIVEN
 from openai.types.chat import ChatCompletion
@@ -295,13 +296,13 @@ class OpenAIAPI(ModelAPI):
                     else None
                 ),
             ), model_call()
-        except BadRequestError as e:
+        except (BadRequestError, UnprocessableEntityError) as e:
             return self.handle_bad_request(e), model_call()

     def on_response(self, response: dict[str, Any]) -> None:
         pass

-    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
         return openai_handle_bad_request(self.model_name, ex)

     def _chat_choices_from_response(
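In the openai SDK, BadRequestError (HTTP 400) and UnprocessableEntityError (HTTP 422) are both subclasses of APIStatusError, which is what lets the handler signatures above widen to APIStatusError and serve both exception types through one code path. A quick check of that hierarchy:

```python
from openai import APIStatusError, BadRequestError, UnprocessableEntityError

# both errors share the APIStatusError base (and its .response and .body)
assert issubclass(BadRequestError, APIStatusError)           # HTTP 400
assert issubclass(UnprocessableEntityError, APIStatusError)  # HTTP 422
```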
inspect_ai/model/_providers/providers.py CHANGED
@@ -253,28 +253,6 @@ def none() -> type[ModelAPI]:
     return NoModel


-@modelapi("goodfire")
-def goodfire() -> type[ModelAPI]:
-    """Get the Goodfire API provider."""
-    FEATURE = "Goodfire API"
-    PACKAGE = "goodfire"
-    MIN_VERSION = "0.3.4"  # Support for newer Llama models and OpenAI compatibility
-
-    # verify we have the package
-    try:
-        import goodfire  # noqa: F401
-    except ImportError:
-        raise pip_dependency_error(FEATURE, [PACKAGE])
-
-    # verify version
-    verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
-
-    # in the clear
-    from .goodfire import GoodfireAPI
-
-    return GoodfireAPI
-
-
 def validate_openai_client(feature: str) -> None:
     FEATURE = feature
     PACKAGE = "openai"
inspect_ai/model/_providers/together.py CHANGED
@@ -3,7 +3,7 @@ from json import dumps
 from typing import Any

 import httpx
-from openai import BadRequestError
+from openai import APIStatusError
 from openai.types.chat import (
     ChatCompletion,
 )
@@ -105,7 +105,7 @@ class TogetherAIAPI(OpenAIAPI):
         return DEFAULT_MAX_TOKENS

     @override
-    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
         response = ex.response.json()
         if "error" in response and "message" in response.get("error"):
             content = response.get("error").get("message")
{inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: inspect_ai
-Version: 0.3.84
+Version: 0.3.86
 Summary: Framework for large language model evaluations
 Author: UK AI Security Institute
 License: MIT License
@@ -56,7 +56,6 @@ Requires-Dist: aioboto3; extra == "dev"
 Requires-Dist: azure-ai-inference; extra == "dev"
 Requires-Dist: google-cloud-aiplatform; extra == "dev"
 Requires-Dist: google-genai; extra == "dev"
-Requires-Dist: goodfire; extra == "dev"
 Requires-Dist: griffe; extra == "dev"
 Requires-Dist: groq; extra == "dev"
 Requires-Dist: ipython; extra == "dev"
{inspect_ai-0.3.84.dist-info → inspect_ai-0.3.86.dist-info}/RECORD RENAMED
@@ -46,7 +46,7 @@ inspect_ai/_display/textual/widgets/transcript.py,sha256=zaxlDixT6Fie0acAWBM9Hlt
 inspect_ai/_display/textual/widgets/vscode.py,sha256=YTXdIZ0fcf9XE2v3rWIfUTgnXFww8uKCo7skugQLIbs,1247
 inspect_ai/_eval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 inspect_ai/_eval/context.py,sha256=gWTjEEMVTJMJpCCKLRs4joZDkG00rzE7-HXZFyzSC_I,1283
-inspect_ai/_eval/eval.py,sha256=3VHOUYhkTmni-AaT4dYE9kfoXME68eCqOv_T6xFySzo,40266
+inspect_ai/_eval/eval.py,sha256=tfIYOJSNGNfZHl18XYqXZHIMKlCmUZ8gbqe7I0OZJII,40307
 inspect_ai/_eval/evalset.py,sha256=FnZBVi5hOt6f84PNYlFhkjb7N1lNgiQydQlernJZeW4,24005
 inspect_ai/_eval/list.py,sha256=VbZ-2EI6MqrXvCN7VTz21TQSoU5K5_Q0hqhxmj5A_m0,3744
 inspect_ai/_eval/loader.py,sha256=yCDrW5MhP6GT329hZ_gUm_eAMsCA9G7jb8sm45Pj-pw,24970
@@ -59,7 +59,7 @@ inspect_ai/_eval/task/epochs.py,sha256=Ci7T6CQniSOTChv5Im2dCdSDrP-5hq19rV6iJ2uBc
 inspect_ai/_eval/task/error.py,sha256=Vhqinfdf0eIrjn7kUY7-id8Kbdggr-fEFpAJeJrkJ1M,1244
 inspect_ai/_eval/task/generate.py,sha256=C9-S9ak4VFQO7QgtUbGjt8F4sTyXS5nekR3Mg_MPwmM,2511
 inspect_ai/_eval/task/images.py,sha256=nTzHizlyuPYumPH7gAOBSrNkTwTbAmZ7tKdzN7d_R2k,4035
-inspect_ai/_eval/task/log.py,sha256=w1Uu3VplvL_UUqyCVDmUMOG5s8_E3si6OkglE7xPxM0,11769
+inspect_ai/_eval/task/log.py,sha256=PD2ZrqtHY0zRyx7pB8L5v-txyaBRePs76cFu5Fb-vjE,11817
 inspect_ai/_eval/task/resolved.py,sha256=OCQc_0HmW_Vw8o1KisX0DCn-eOPkTbR1v_y_jEaAlhU,966
 inspect_ai/_eval/task/results.py,sha256=x4weYRK2XGowfBG3f2msOeZQ_pxh230HTlw6kps33jw,17925
 inspect_ai/_eval/task/run.py,sha256=RS2Qv3AythSkQL4fsgBFaXfyx2WDIZuFj9v6ifoRiYs,38714
@@ -450,7 +450,7 @@ inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx,sha256=s6jt1-5_Hrgz3_ysT1
 inspect_ai/_view/www/src/workspace/tabs/grouping.ts,sha256=6lvFzReQKQ_43S20xN4kfBJN2F7Tfs2VWeSMIuHxUAI,6187
 inspect_ai/_view/www/src/workspace/tabs/types.ts,sha256=Wa1Y4tZwYO_QJr0Tg9-5xJFztmcMYCODSm6JvdzMpDw,471
 inspect_ai/agent/__init__.py,sha256=nzL9TPAARSJVZRPogWHxZ-qJriXBGmFUM9DV4NRi21o,749
-inspect_ai/agent/_agent.py,sha256=g0hw6sDTXg_4NRjs5Ohze404HzyyIyFNFDvlgGqL2Vw,7736
+inspect_ai/agent/_agent.py,sha256=5MXMrY5bsBQ4AI5y5rVaIYywp7JxQzUnhp-KqBIMF7I,7622
 inspect_ai/agent/_as_solver.py,sha256=_6H0L9JidC6JjpMaBRBAjIrgzE8GKEoygJjOC_JRoLQ,2340
 inspect_ai/agent/_as_tool.py,sha256=vT5hrcKfkyP90i4Ieuy_dx4cYsFKOMdPs-6x12cuqMk,4449
 inspect_ai/agent/_filter.py,sha256=qnT0HbT4edpDi0MwXY3Q3It2pzNRkTRXZDOqfCwMY6M,1234
@@ -510,7 +510,7 @@ inspect_ai/log/_bundle.py,sha256=5Uy-s64_SFokZ7WRzti9mD7yoKrd2sOzdvqKyahoiC4,804
 inspect_ai/log/_condense.py,sha256=OedMphK5Q2YPuY1cnoAM7tGsyVIU6Kwrv3oIeb3dFmY,10881
 inspect_ai/log/_convert.py,sha256=qn6q10Um2XV7dnK4nQargANa0bz6RFJPmaEMINv38cs,3467
 inspect_ai/log/_file.py,sha256=QjeVUegoCWVUv6CMsj0das_UpZZZMfnbvCQAKlFYGXE,17105
-inspect_ai/log/_log.py,sha256=f4ChtLdNc_z0qVXsJCmZyW6BdbFKGTfHWY5gaymsUkc,24970
+inspect_ai/log/_log.py,sha256=KsssY2kGfuDHGIXOGJHN4bO1LXVs0f3XtqIUfA2R68A,25109
 inspect_ai/log/_message.py,sha256=QofM_JZF_x3k_5ta1uQzoN_VnMoUhXFnqWurIn9FXOY,1999
 inspect_ai/log/_retry.py,sha256=e7a2hjl3Ncl8b8sU7CsDpvK8DV0b1uSRLeokRX1mt34,2109
 inspect_ai/log/_samples.py,sha256=wPQlV1VR9djWaj37lLrjBprCabdAm4S2vFOsQTcd12U,4910
@@ -524,7 +524,7 @@ inspect_ai/log/_recorders/recorder.py,sha256=zDDpl2tktPjb6xk5kd4TyEMxkXZiLgXXpPi
 inspect_ai/log/_recorders/types.py,sha256=Aeo-U7FhmWQSvE_uz3fwUI7cqaSR-ZE_uRVu-1fBCgc,865
 inspect_ai/log/_recorders/buffer/__init__.py,sha256=6DsRdnNl-ic-xJmnBE5i45ZP3eB4yAta9wxi5WFcbqc,367
 inspect_ai/log/_recorders/buffer/buffer.py,sha256=rtLvaX7nSqNrWb-3CeSaOHwJgF1CzRgXFT_I1dDkM1k,945
-inspect_ai/log/_recorders/buffer/database.py,sha256=aqBJdM6meQTWsLs9uF1gFGg1dsE1MvVQdiXR1DHoRqw,22171
+inspect_ai/log/_recorders/buffer/database.py,sha256=3yV8OlDsQ4zFQHNqe7aBAHwkUISW3zmaLBlD1OFj36w,22396
 inspect_ai/log/_recorders/buffer/filestore.py,sha256=S6RP-5zkOPSmy1hV2LCCbfwdX-YFZGuIEjfJuOWMjDQ,8274
 inspect_ai/log/_recorders/buffer/types.py,sha256=pTnPCZHbk9qF6yF-eNXHTa23cLH_FvP8dmfPJCFO15Q,2046
 inspect_ai/model/__init__.py,sha256=6Aa_HEU-rgxWPDaIRlE6KBdXY406x2LtcLeVtAxk-AI,2453
@@ -537,7 +537,7 @@ inspect_ai/model/_generate_config.py,sha256=_-kzw7LOl45baVkTjlfL1K1VLKGgNOOczH2H
 inspect_ai/model/_model.py,sha256=h4ASS2VuTZ_97145rLW202u6e7-mw4ENnnlBl0Vsbio,52127
 inspect_ai/model/_model_call.py,sha256=VJ8wnl9Y81JaiClBYM8eyt1jVb3n-yc6Dd88ofRiJDc,2234
 inspect_ai/model/_model_output.py,sha256=R5EAUPLc5RWymVb3le4cbqbNCZ9voTzg0U1j_e4I-yM,7768
-inspect_ai/model/_openai.py,sha256=0OAmxQbIU6V7WJr9Q8J6oGwQuY9aZLPpHQ9r28GCmbg,19382
+inspect_ai/model/_openai.py,sha256=-N_LhZR8-nrnCL8h9lklo_RrGNDR1SzMJ0tPafVuPXo,19380
 inspect_ai/model/_openai_computer_use.py,sha256=vbKkYLhqNuX16zuWfg5MaGp9H8URrPcLhKQ1pDsZtPo,5943
 inspect_ai/model/_openai_responses.py,sha256=bQWuVvJIkS8CqtoX9z1aRb1aky4TNbMngG2paB3wsrA,20179
 inspect_ai/model/_reasoning.py,sha256=qmR8WT6t_cb7NIsJOQHPyFZh2eLV0HmYxKo2vtvteQ4,929
@@ -547,7 +547,6 @@ inspect_ai/model/_providers/anthropic.py,sha256=PYxV0D_bt0Icp2wEWb6GMCpDb-uBFKYy
 inspect_ai/model/_providers/azureai.py,sha256=uXED_qmeyW1XAGBosbG7PJNk833RIeokKX3l_8O9gYA,14341
 inspect_ai/model/_providers/bedrock.py,sha256=rh8BvSUPWiFMh0TQwMYTlucfFrDKswtLhzozulrz7wE,24004
 inspect_ai/model/_providers/cloudflare.py,sha256=mWqBqc0zzf29UWz34biq8CxSu99a95YjpH_6A4na52g,4617
-inspect_ai/model/_providers/goodfire.py,sha256=J0nxGbF8lXBmc5YHBJCsZdF03mWT5SuWMb21d9ho3FM,8799
 inspect_ai/model/_providers/google.py,sha256=gcg8pvYAV5gYc4NXC5mLqFyuU7KuhyNrzdXIY57sYl8,28207
 inspect_ai/model/_providers/grok.py,sha256=dS88ueXiD-kHAFr0jCoTpTGLGa2VsUlB_TFP8L_2lBM,995
 inspect_ai/model/_providers/groq.py,sha256=mcRKu33e-mO5l06PGV6SjsildQd0XCti6QNXwwFWL7I,11246
@@ -557,12 +556,12 @@ inspect_ai/model/_providers/mistral.py,sha256=FbMPN_pw8LZal2iFGf5FX70ypuH3k44FUn
 inspect_ai/model/_providers/mockllm.py,sha256=gL9f-f5TOdE4a0GVENr3cOIIp2kv8zVXWPZ608rouGk,2440
 inspect_ai/model/_providers/none.py,sha256=6qLbZpHSoEZaaxFO7luieFjqig2Ju8Fu00DlRngAry8,935
 inspect_ai/model/_providers/ollama.py,sha256=mBPSxaEkiH_RnlHKqOyFBlXObQhc2dfjL-rCKrea5u8,675
-inspect_ai/model/_providers/openai.py,sha256=NFdMpnI2vlmpI8h_vWnt8y4X_XaydaL9gH5Dmy6k5Tw,16478
+inspect_ai/model/_providers/openai.py,sha256=zJkhtiEQrmsuhfL7mpBPpOlYJ_WNraeyTkjYTelF0no,16535
 inspect_ai/model/_providers/openai_o1.py,sha256=k-Xm_Wzn1KHKL6Z1KTHg4CTTr8ybgiHvXkLiLdjP7Os,12926
 inspect_ai/model/_providers/openai_responses.py,sha256=YPXt8KQfIEiiTpvtoQECBoNQLDLbwBW_KhBfM8vEhJk,6324
 inspect_ai/model/_providers/openrouter.py,sha256=pDimDmm_4FzS4GZx0n9z8z717mQf3IQlgEy30huzpc4,4730
-inspect_ai/model/_providers/providers.py,sha256=0WSi_FOWxW71sZ4GJ-OgJqbPS4tMIaPQqEG2hnxqfqc,6378
-inspect_ai/model/_providers/together.py,sha256=MoA3tyMKUnE0EekTqEIBBwvsaOp5c697kydLi1ZMYzE,9745
+inspect_ai/model/_providers/providers.py,sha256=Sd2D9OcWkukuBcl_-KDfdpxMaAShv1JZhL5KfAM87CE,5817
+inspect_ai/model/_providers/together.py,sha256=Wh3G0vhKHq5ofx1otwXjJFhM98Ll70IbqBhUNNV2-rk,9743
 inspect_ai/model/_providers/vertex.py,sha256=60W7kgoA83GtKdMeJgNU2IAw0N0wTscg4YCcMPu2bwo,17185
 inspect_ai/model/_providers/vllm.py,sha256=UYjCCXzw2hGJHVC3oPl-u2EI4iAm8ZncoIfYp1QJkbQ,14238
 inspect_ai/model/_providers/util/__init__.py,sha256=d4T_qvXihTRd1zmQkNE3xUBlHCX8tOIbRK19EwU0fTs,717
@@ -693,9 +692,9 @@ inspect_ai/util/_sandbox/docker/internal.py,sha256=c8X8TLrBPOvsfnq5TkMlb_bzTALyc
 inspect_ai/util/_sandbox/docker/prereqs.py,sha256=0j6_OauBBnVlpBleADcZavIAAQZy4WewVjbRn9c0stg,3355
 inspect_ai/util/_sandbox/docker/service.py,sha256=hhHIWH1VDFLwehdGd19aUBD_VKfDO3GCPxpw1HSwVQk,2437
 inspect_ai/util/_sandbox/docker/util.py,sha256=EeInihCNXgUWxaqZ4dNOJd719kXL2_jr63QCoXn68vA,3154
-inspect_ai-0.3.84.dist-info/licenses/LICENSE,sha256=xZPCr8gTiFIerrA_DRpLAbw-UUftnLFsHxKeW-NTtq8,1081
-inspect_ai-0.3.84.dist-info/METADATA,sha256=g-2UAMeNEN0cyQB6JUowoPVFebPlFpNsZuFiQwsxpVE,5005
-inspect_ai-0.3.84.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-inspect_ai-0.3.84.dist-info/entry_points.txt,sha256=WGGLmzTzDWLzYfiyovSY6oEKuf-gqzSDNOb5V-hk3fM,54
-inspect_ai-0.3.84.dist-info/top_level.txt,sha256=Tp3za30CHXJEKLk8xLe9qGsW4pBzJpEIOMHOHNCXiVo,11
-inspect_ai-0.3.84.dist-info/RECORD,,
+inspect_ai-0.3.86.dist-info/licenses/LICENSE,sha256=xZPCr8gTiFIerrA_DRpLAbw-UUftnLFsHxKeW-NTtq8,1081
+inspect_ai-0.3.86.dist-info/METADATA,sha256=tId3lj5ywe2A79iWACtXXP_aBeWfnbQQfAujxJGxaoc,4965
+inspect_ai-0.3.86.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+inspect_ai-0.3.86.dist-info/entry_points.txt,sha256=WGGLmzTzDWLzYfiyovSY6oEKuf-gqzSDNOb5V-hk3fM,54
+inspect_ai-0.3.86.dist-info/top_level.txt,sha256=Tp3za30CHXJEKLk8xLe9qGsW4pBzJpEIOMHOHNCXiVo,11
+inspect_ai-0.3.86.dist-info/RECORD,,
inspect_ai/model/_providers/goodfire.py DELETED
@@ -1,253 +0,0 @@
-import os
-from typing import Any, List, Literal, get_args
-
-from goodfire import AsyncClient
-from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
-from goodfire.api.exceptions import (
-    InvalidRequestException,
-    RateLimitException,
-    ServerErrorException,
-)
-from goodfire.variants.variants import SUPPORTED_MODELS, Variant
-from typing_extensions import override
-
-from inspect_ai.tool._tool_choice import ToolChoice
-from inspect_ai.tool._tool_info import ToolInfo
-
-from .._chat_message import (
-    ChatMessage,
-    ChatMessageAssistant,
-    ChatMessageSystem,
-    ChatMessageTool,
-    ChatMessageUser,
-)
-from .._generate_config import GenerateConfig
-from .._model import ModelAPI
-from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-)
-from .util import environment_prerequisite_error, model_base_url
-
-# Constants
-GOODFIRE_API_KEY = "GOODFIRE_API_KEY"
-DEFAULT_BASE_URL = "https://api.goodfire.ai"
-DEFAULT_MAX_TOKENS = 4096
-DEFAULT_TEMPERATURE = 1.0  # Standard sampling temperature (baseline)
-DEFAULT_TOP_P = 1.0  # No nucleus sampling truncation (baseline)
-
-
-class GoodfireAPI(ModelAPI):
-    """Goodfire API provider.
-
-    This provider implements the Goodfire API for LLM inference. It supports:
-    - Chat completions with standard message formats
-    - Basic parameter controls (temperature, top_p, etc.)
-    - Usage statistics tracking
-    - Stop reason handling
-
-    Does not currently support:
-    - Tool calls
-    - Feature analysis
-    - Streaming responses
-
-    Known limitations:
-    - Limited role support (system/user/assistant only)
-    - Tool messages converted to user messages
-    """
-
-    client: AsyncClient
-    variant: Variant
-    model_args: dict[str, Any]
-
-    def __init__(
-        self,
-        model_name: str,
-        base_url: str | None = None,
-        api_key: str | None = None,
-        config: GenerateConfig = GenerateConfig(),
-        **model_args: Any,
-    ) -> None:
-        """Initialize the Goodfire API provider.
-
-        Args:
-            model_name: Name of the model to use
-            base_url: Optional custom API base URL
-            api_key: Optional API key (will check env vars if not provided)
-            config: Generation config options
-            **model_args: Additional arguments passed to the API
-        """
-        super().__init__(
-            model_name=model_name,
-            base_url=base_url,
-            api_key=api_key,
-            api_key_vars=[GOODFIRE_API_KEY],
-            config=config,
-        )
-
-        # resolve api_key
-        if not self.api_key:
-            self.api_key = os.environ.get(GOODFIRE_API_KEY)
-        if not self.api_key:
-            raise environment_prerequisite_error("Goodfire", GOODFIRE_API_KEY)
-
-        # Validate model name against supported models
-        supported_models = list(get_args(SUPPORTED_MODELS))
-        if self.model_name not in supported_models:
-            raise ValueError(
-                f"Model {self.model_name} not supported. Supported models: {supported_models}"
-            )
-
-        # Initialize client with minimal configuration
-        base_url_val = model_base_url(base_url, "GOODFIRE_BASE_URL")
-        assert isinstance(base_url_val, str) or base_url_val is None
-
-        # Store model args for use in generate
-        self.model_args = model_args
-
-        self.client = AsyncClient(
-            api_key=self.api_key,
-            base_url=base_url_val or DEFAULT_BASE_URL,
-        )
-
-        # Initialize variant directly with model name
-        self.variant = Variant(self.model_name)  # type: ignore
-
-    def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
-        """Convert an Inspect message to a Goodfire message format.
-
-        Args:
-            message: The message to convert
-
-        Returns:
-            The converted message in Goodfire format
-
-        Raises:
-            ValueError: If the message type is unknown
-        """
-        role: Literal["system", "user", "assistant"] = "user"
-        if isinstance(message, ChatMessageSystem):
-            role = "system"
-        elif isinstance(message, ChatMessageUser):
-            role = "user"
-        elif isinstance(message, ChatMessageAssistant):
-            role = "assistant"
-        elif isinstance(message, ChatMessageTool):
-            role = "user"  # Convert tool messages to user messages
-        else:
-            raise ValueError(f"Unknown message type: {type(message)}")
-
-        content = str(message.content)
-        if isinstance(message, ChatMessageTool):
-            content = f"Tool {message.function}: {content}"
-
-        return GoodfireChatMessage(role=role, content=content)
-
-    def handle_error(self, ex: Exception) -> ModelOutput | Exception:
-        """Handle only errors that need special treatment for retry logic or model limits."""
-        # Handle token/context length errors
-        if isinstance(ex, InvalidRequestException):
-            error_msg = str(ex).lower()
-            if "context length" in error_msg or "max tokens" in error_msg:
-                return ModelOutput.from_content(
-                    model=self.model_name,
-                    content=str(ex),
-                    stop_reason="model_length",
-                    error=error_msg,
-                )
-
-        # Let all other errors propagate
-        return ex
-
-    @override
-    def should_retry(self, ex: Exception) -> bool:
-        """Check if exception is due to rate limiting."""
-        return isinstance(ex, RateLimitException | ServerErrorException)
-
-    @override
-    def connection_key(self) -> str:
-        """Return key for connection pooling."""
-        return f"goodfire:{self.api_key}"
-
-    @override
-    def max_tokens(self) -> int | None:
-        """Return maximum tokens supported by model."""
-        return DEFAULT_MAX_TOKENS  # Let Goodfire's Variant handle model-specific limits
-
-    async def generate(
-        self,
-        input: List[ChatMessage],
-        tools: List[ToolInfo],
-        tool_choice: ToolChoice,
-        config: GenerateConfig,
-        *,
-        cache: bool = True,
-    ) -> tuple[ModelOutput | Exception, ModelCall]:
-        """Generate output from the model."""
-        # Convert messages and prepare request params
-        messages = [self._to_goodfire_message(msg) for msg in input]
-        # Build request parameters with type hints
-        params: dict[str, Any] = {
-            "model": self.variant.base_model,  # Use base_model instead of stringifying the Variant
-            "messages": messages,
-            "max_completion_tokens": int(config.max_tokens)
-            if config.max_tokens
-            else DEFAULT_MAX_TOKENS,
-            "stream": False,
-        }
-
-        # Add generation parameters from config if not in model_args
-        if "temperature" not in self.model_args and config.temperature is not None:
-            params["temperature"] = float(config.temperature)
-        elif "temperature" not in self.model_args:
-            params["temperature"] = DEFAULT_TEMPERATURE
-
-        if "top_p" not in self.model_args and config.top_p is not None:
-            params["top_p"] = float(config.top_p)
-        elif "top_p" not in self.model_args:
-            params["top_p"] = DEFAULT_TOP_P
-
-        # Add any additional model args (highest priority)
-        api_params = {
-            k: v
-            for k, v in self.model_args.items()
-            if k not in ["api_key", "base_url", "model_args"]
-        }
-        params.update(api_params)
-
-        try:
-            # Use native async client
-            response = await self.client.chat.completions.create(**params)
-            response_dict = response.model_dump()
-
-            output = ModelOutput(
-                model=self.model_name,
-                choices=[
-                    ChatCompletionChoice(
-                        message=ChatMessageAssistant(
-                            content=response_dict["choices"][0]["message"]["content"],
-                            model=self.model_name,
-                        ),
-                        stop_reason="stop",
-                    )
-                ],
-                usage=ModelUsage(**response_dict["usage"])
-                if "usage" in response_dict
-                else None,
-            )
-            model_call = ModelCall.create(request=params, response=response_dict)
-            return (output, model_call)
-        except Exception as ex:
-            result = self.handle_error(ex)
-            model_call = ModelCall.create(
-                request=params,
-                response={},  # Empty response for error case
-            )
-            return (result, model_call)
-
-    @property
-    def name(self) -> str:
-        """Get provider name."""
-        return "goodfire"