inspect-ai 0.3.63__py3-none-any.whl → 0.3.65__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (182)
  1. inspect_ai/_cli/cache.py +8 -7
  2. inspect_ai/_cli/common.py +0 -12
  3. inspect_ai/_cli/eval.py +32 -4
  4. inspect_ai/_cli/info.py +1 -0
  5. inspect_ai/_cli/list.py +1 -1
  6. inspect_ai/_cli/log.py +2 -0
  7. inspect_ai/_cli/sandbox.py +4 -1
  8. inspect_ai/_cli/score.py +181 -32
  9. inspect_ai/_cli/trace.py +2 -0
  10. inspect_ai/_cli/view.py +4 -2
  11. inspect_ai/_display/core/config.py +7 -1
  12. inspect_ai/_display/core/progress.py +1 -1
  13. inspect_ai/_display/textual/app.py +8 -4
  14. inspect_ai/_display/textual/widgets/samples.py +6 -5
  15. inspect_ai/_display/textual/widgets/sandbox.py +6 -0
  16. inspect_ai/_eval/__init__.py +0 -0
  17. inspect_ai/_eval/eval.py +100 -97
  18. inspect_ai/_eval/evalset.py +69 -69
  19. inspect_ai/_eval/loader.py +122 -12
  20. inspect_ai/_eval/registry.py +1 -1
  21. inspect_ai/_eval/run.py +14 -0
  22. inspect_ai/_eval/score.py +125 -36
  23. inspect_ai/_eval/task/log.py +105 -4
  24. inspect_ai/_eval/task/results.py +92 -38
  25. inspect_ai/_eval/task/run.py +6 -2
  26. inspect_ai/_eval/task/sandbox.py +35 -2
  27. inspect_ai/_eval/task/task.py +49 -46
  28. inspect_ai/_util/__init__.py +0 -0
  29. inspect_ai/_util/constants.py +1 -1
  30. inspect_ai/_util/content.py +8 -0
  31. inspect_ai/_util/error.py +2 -0
  32. inspect_ai/_util/file.py +15 -1
  33. inspect_ai/_util/logger.py +4 -2
  34. inspect_ai/_util/registry.py +7 -1
  35. inspect_ai/_view/view.py +1 -2
  36. inspect_ai/_view/www/App.css +8 -3
  37. inspect_ai/_view/www/README.md +1 -1
  38. inspect_ai/_view/www/dist/assets/index.css +66 -38
  39. inspect_ai/_view/www/dist/assets/index.js +525 -523
  40. inspect_ai/_view/www/log-schema.json +86 -73
  41. inspect_ai/_view/www/package.json +1 -1
  42. inspect_ai/_view/www/src/App.tsx +1 -0
  43. inspect_ai/_view/www/src/components/AnsiDisplay.tsx +1 -1
  44. inspect_ai/_view/www/src/components/JsonPanel.tsx +1 -1
  45. inspect_ai/_view/www/src/components/LargeModal.tsx +39 -49
  46. inspect_ai/_view/www/src/components/NavPills.tsx +3 -1
  47. inspect_ai/_view/www/src/components/TabSet.tsx +19 -4
  48. inspect_ai/_view/www/src/logfile/remoteLogFile.ts +0 -1
  49. inspect_ai/_view/www/src/metadata/MetaDataGrid.tsx +1 -1
  50. inspect_ai/_view/www/src/metadata/MetaDataView.tsx +1 -1
  51. inspect_ai/_view/www/src/metadata/RenderedContent.tsx +6 -13
  52. inspect_ai/_view/www/src/plan/PlanDetailView.tsx +17 -2
  53. inspect_ai/_view/www/src/plan/SolverDetailView.tsx +1 -1
  54. inspect_ai/_view/www/src/samples/SampleDisplay.tsx +14 -5
  55. inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +4 -2
  56. inspect_ai/_view/www/src/samples/SamplesTools.tsx +16 -24
  57. inspect_ai/_view/www/src/samples/chat/ChatMessage.tsx +1 -1
  58. inspect_ai/_view/www/src/samples/chat/ChatView.tsx +1 -0
  59. inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +27 -13
  60. inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +19 -17
  61. inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +12 -10
  62. inspect_ai/_view/www/src/samples/chat/tools/ToolInput.tsx +56 -66
  63. inspect_ai/_view/www/src/samples/chat/tools/ToolOutput.tsx +12 -5
  64. inspect_ai/_view/www/src/samples/chat/tools/tool.ts +21 -36
  65. inspect_ai/_view/www/src/samples/descriptor/samplesDescriptor.tsx +3 -1
  66. inspect_ai/_view/www/src/samples/sample-tools/SelectScorer.tsx +27 -25
  67. inspect_ai/_view/www/src/samples/sample-tools/SortFilter.tsx +5 -1
  68. inspect_ai/_view/www/src/samples/scores/SampleScoreView.module.css +13 -13
  69. inspect_ai/_view/www/src/samples/transcript/InfoEventView.tsx +1 -1
  70. inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +2 -2
  71. inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.tsx +9 -5
  72. inspect_ai/_view/www/src/samples/transcript/ScoreEventView.tsx +1 -1
  73. inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +5 -4
  74. inspect_ai/_view/www/src/samples/transcript/event/EventNavs.tsx +1 -0
  75. inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +1 -0
  76. inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +17 -6
  77. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +14 -19
  78. inspect_ai/_view/www/src/types/log.d.ts +107 -19
  79. inspect_ai/_view/www/src/usage/ModelTokenTable.tsx +7 -1
  80. inspect_ai/_view/www/src/usage/ModelUsagePanel.tsx +5 -3
  81. inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +25 -27
  82. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +12 -11
  83. inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.module.css +25 -2
  84. inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +60 -36
  85. inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +4 -0
  86. inspect_ai/_view/www/src/workspace/sidebar/SidebarScoreView.tsx +6 -4
  87. inspect_ai/_view/www/src/workspace/sidebar/SidebarScoresView.tsx +16 -14
  88. inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +9 -19
  89. inspect_ai/_view/www/src/workspace/utils.ts +34 -0
  90. inspect_ai/approval/_approval.py +2 -0
  91. inspect_ai/approval/_approver.py +4 -4
  92. inspect_ai/approval/_auto.py +1 -1
  93. inspect_ai/approval/_human/approver.py +3 -0
  94. inspect_ai/approval/_policy.py +5 -0
  95. inspect_ai/approval/_registry.py +2 -2
  96. inspect_ai/dataset/_dataset.py +36 -45
  97. inspect_ai/dataset/_sources/__init__.py +0 -0
  98. inspect_ai/dataset/_sources/csv.py +13 -13
  99. inspect_ai/dataset/_sources/hf.py +29 -29
  100. inspect_ai/dataset/_sources/json.py +10 -10
  101. inspect_ai/log/__init__.py +2 -0
  102. inspect_ai/log/_convert.py +3 -3
  103. inspect_ai/log/_file.py +24 -9
  104. inspect_ai/log/_log.py +98 -7
  105. inspect_ai/log/_message.py +3 -1
  106. inspect_ai/log/_recorders/file.py +4 -0
  107. inspect_ai/log/_recorders/recorder.py +3 -0
  108. inspect_ai/log/_transcript.py +19 -8
  109. inspect_ai/model/__init__.py +2 -0
  110. inspect_ai/model/_cache.py +39 -21
  111. inspect_ai/model/_call_tools.py +2 -2
  112. inspect_ai/model/_chat_message.py +14 -4
  113. inspect_ai/model/_generate_config.py +1 -1
  114. inspect_ai/model/_model.py +31 -24
  115. inspect_ai/model/_model_output.py +14 -1
  116. inspect_ai/model/_openai.py +10 -18
  117. inspect_ai/model/_providers/google.py +9 -5
  118. inspect_ai/model/_providers/openai.py +5 -9
  119. inspect_ai/model/_providers/openrouter.py +1 -1
  120. inspect_ai/scorer/__init__.py +6 -1
  121. inspect_ai/scorer/_answer.py +1 -1
  122. inspect_ai/scorer/_classification.py +4 -0
  123. inspect_ai/scorer/_match.py +4 -5
  124. inspect_ai/scorer/_metric.py +87 -28
  125. inspect_ai/scorer/_metrics/__init__.py +3 -3
  126. inspect_ai/scorer/_metrics/accuracy.py +8 -10
  127. inspect_ai/scorer/_metrics/mean.py +3 -17
  128. inspect_ai/scorer/_metrics/std.py +111 -30
  129. inspect_ai/scorer/_model.py +12 -12
  130. inspect_ai/scorer/_pattern.py +3 -3
  131. inspect_ai/scorer/_reducer/reducer.py +36 -21
  132. inspect_ai/scorer/_reducer/registry.py +2 -2
  133. inspect_ai/scorer/_reducer/types.py +7 -1
  134. inspect_ai/scorer/_score.py +11 -1
  135. inspect_ai/scorer/_scorer.py +110 -16
  136. inspect_ai/solver/__init__.py +1 -1
  137. inspect_ai/solver/_basic_agent.py +19 -22
  138. inspect_ai/solver/_bridge/__init__.py +0 -3
  139. inspect_ai/solver/_bridge/bridge.py +3 -3
  140. inspect_ai/solver/_chain.py +1 -2
  141. inspect_ai/solver/_critique.py +3 -3
  142. inspect_ai/solver/_fork.py +2 -2
  143. inspect_ai/solver/_human_agent/__init__.py +0 -0
  144. inspect_ai/solver/_human_agent/agent.py +5 -8
  145. inspect_ai/solver/_human_agent/commands/clock.py +14 -10
  146. inspect_ai/solver/_human_agent/commands/note.py +1 -1
  147. inspect_ai/solver/_human_agent/commands/score.py +0 -11
  148. inspect_ai/solver/_multiple_choice.py +15 -18
  149. inspect_ai/solver/_prompt.py +7 -7
  150. inspect_ai/solver/_solver.py +53 -52
  151. inspect_ai/solver/_task_state.py +80 -69
  152. inspect_ai/solver/_use_tools.py +9 -9
  153. inspect_ai/tool/__init__.py +2 -1
  154. inspect_ai/tool/_tool.py +43 -14
  155. inspect_ai/tool/_tool_call.py +6 -2
  156. inspect_ai/tool/_tool_choice.py +3 -1
  157. inspect_ai/tool/_tool_def.py +10 -8
  158. inspect_ai/tool/_tool_params.py +24 -0
  159. inspect_ai/tool/_tool_with.py +7 -7
  160. inspect_ai/tool/_tools/__init__.py +0 -0
  161. inspect_ai/tool/_tools/_computer/_common.py +2 -2
  162. inspect_ai/tool/_tools/_computer/_computer.py +11 -0
  163. inspect_ai/tool/_tools/_execute.py +15 -9
  164. inspect_ai/tool/_tools/_web_browser/_resources/README.md +2 -2
  165. inspect_ai/tool/_tools/_web_browser/_web_browser.py +5 -3
  166. inspect_ai/tool/_tools/_web_search.py +7 -5
  167. inspect_ai/util/_concurrency.py +3 -3
  168. inspect_ai/util/_panel.py +2 -0
  169. inspect_ai/util/_resource.py +12 -12
  170. inspect_ai/util/_sandbox/docker/compose.py +23 -20
  171. inspect_ai/util/_sandbox/docker/config.py +2 -1
  172. inspect_ai/util/_sandbox/docker/docker.py +10 -1
  173. inspect_ai/util/_sandbox/docker/service.py +100 -0
  174. inspect_ai/util/_sandbox/environment.py +99 -96
  175. inspect_ai/util/_subprocess.py +5 -3
  176. inspect_ai/util/_subtask.py +15 -16
  177. {inspect_ai-0.3.63.dist-info → inspect_ai-0.3.65.dist-info}/LICENSE +1 -1
  178. {inspect_ai-0.3.63.dist-info → inspect_ai-0.3.65.dist-info}/METADATA +10 -6
  179. {inspect_ai-0.3.63.dist-info → inspect_ai-0.3.65.dist-info}/RECORD +182 -175
  180. {inspect_ai-0.3.63.dist-info → inspect_ai-0.3.65.dist-info}/WHEEL +0 -0
  181. {inspect_ai-0.3.63.dist-info → inspect_ai-0.3.65.dist-info}/entry_points.txt +0 -0
  182. {inspect_ai-0.3.63.dist-info → inspect_ai-0.3.65.dist-info}/top_level.txt +0 -0
inspect_ai/model/_cache.py
@@ -58,22 +58,23 @@ def _parse_expiry(period: str) -> int:
 class CachePolicy:
     """The `CachePolicy` is used to define various criteria that impact how model calls are cached.
 
-    Attributes:
-        expiry(str | None): Default "24h". The expiry time for the cache entry.
-            This is a string of the format "12h" for 12 hours or "1W" for a week,
-            etc. This is how long we will keep the cache entry, if we access it
-            after this point we'll clear it. Setting to `None` will cache
-            indefinitely.
-        per_epoch(bool): Default True. By default we cache responses separately
-            for different epochs. The general use case is that if there are
-            multiple epochs, we should cache each response separately because
-            scorers will aggregate across epochs. However, sometimes a response
-            can be cached regardless of epoch if the call being made isn't under
-            test as part of the evaluation. If False, this option allows you to
-            bypass that and cache independently of the epoch.
-        scopes(dict[str, str]): A dictionary of additional metadata that should
-            be included in the cache key. This allows for more fine-grained
-            control over the cache key generation.
+    `expiry`: Default "24h". The expiry time for the cache entry.
+    This is a string of the format "12h" for 12 hours or "1W" for a week,
+    etc. This is how long we will keep the cache entry, if we access it
+    after this point we'll clear it. Setting to `None` will cache
+    indefinitely.
+
+    `per_epoch`: Default True. By default we cache responses separately
+    for different epochs. The general use case is that if there are
+    multiple epochs, we should cache each response separately because
+    scorers will aggregate across epochs. However, sometimes a response
+    can be cached regardless of epoch if the call being made isn't under
+    test as part of the evaluation. If False, this option allows you to
+    bypass that and cache independently of the epoch.
+
+    `scopes`: A dictionary of additional metadata that should
+    be included in the cache key. This allows for more fine-grained
+    control over the cache key generation.
     """
 
     def __init__(
@@ -82,6 +83,14 @@ class CachePolicy:
         per_epoch: bool = True,
         scopes: dict[str, str] = {},
     ) -> None:
+        """Create a CachePolicy.
+
+        Args:
+           expiry: Expiry.
+           per_epoch: Per epoch
+           scopes: Scopes
+
+        """
         self.per_epoch = per_epoch
         self.scopes = scopes
 
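Note: for context on how this policy is consumed, here is a minimal usage sketch based on the public `CachePolicy` signature shown above (the model name and scope values are illustrative):

    from inspect_ai.model import CachePolicy, get_model

    # cache for a week, shared across epochs, with an extra cache-key scope
    policy = CachePolicy(expiry="1W", per_epoch=False, scopes={"variant": "v1"})

    async def ask(question: str) -> str:
        output = await get_model("openai/gpt-4o").generate(question, cache=policy)
        return output.completion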
inspect_ai/model/_cache.py (continued)
@@ -236,7 +245,11 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:
 
 
 def cache_clear(model: str = "") -> bool:
-    """Clear the cache directory."""
+    """Clear the cache directory.
+
+    Args:
+       model: Model to clear cache for.
+    """
     try:
         path = cache_path(model)
 
@@ -252,6 +265,11 @@ def cache_clear(model: str = "") -> bool:
 
 
 def cache_path(model: str = "") -> Path:
+    """Path to cache directory.
+
+    Args:
+       model: Path to cache directory for specific model.
+    """
     env_cache_dir = os.environ.get("INSPECT_CACHE_DIR", None)
     if env_cache_dir:
         generate_cache = Path(env_cache_dir) / "generate"
@@ -320,9 +338,9 @@ def cache_size(
     will be calculated.
 
     Args:
-        subdirs(list[str]): List of folders to filter by, which are generally
+        subdirs: List of folders to filter by, which are generally
            model names. Empty directories will be ignored.
-        files(list[str]): List of files to filter by explicitly. Note that
+        files: List of files to filter by explicitly. Note that
            return value group these up by their parent directory
 
    Returns:
@@ -344,7 +362,7 @@ def cache_list_expired(filter_by: list[str] = []) -> list[Path]:
     """Returns a list of all the cached files that have passed their expiry time.
 
     Args:
-        filter_by(list[str]): Default []. List of model names to filter by. If
+        filter_by: Default []. List of model names to filter by. If
            an empty list, this will search the entire cache.
     """
     expired_cache_entries = []
@@ -384,7 +402,7 @@ def cache_prune(files: list[Path] = []) -> None:
     """Delete all expired cache entries.
 
     Args:
-        files(list[Path]): Default []. List of files to prune. If empty, this
+        files: List of files to prune. If empty, this
            will search the entire cache.
     """
    if not files:
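Note: taken together, the cache management functions above support a workflow like this sketch (model name illustrative; these functions are exported from `inspect_ai.model`):

    from pathlib import Path

    from inspect_ai.model import (
        cache_clear,
        cache_list_expired,
        cache_path,
        cache_prune,
    )

    print(cache_path())                 # root of the generate cache
    print(cache_path("openai/gpt-4o"))  # cache directory for one model

    expired: list[Path] = cache_list_expired(["openai/gpt-4o"])
    cache_prune(expired)                # delete just those expired entries
    cache_clear("openai/gpt-4o")        # or drop the model's cache entirely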
inspect_ai/model/_call_tools.py
@@ -187,7 +187,7 @@ async def call_tools(
             view=call.view,
             pending=True,
         )
-        event.set_task(task)
+        event._set_task(task)
         transcript()._event(event)
 
         # execute the tool call. if the operator cancelled the
@@ -227,7 +227,7 @@ async def call_tools(
             conversation_tool_mesage(tool_message)
 
             # update the event with the results
-            event.set_result(
+            event._set_result(
                 result=result_event.result,
                 truncated=result_event.truncated,
                 error=result_event.error,
inspect_ai/model/_chat_message.py
@@ -13,8 +13,13 @@ logger = getLogger(__name__)
 
 
 class ChatMessageBase(BaseModel):
+    """Base class for chat messages."""
+
+    role: Literal["system", "user", "assistant", "tool"]
+    """Conversation role"""
+
     content: str | list[Content]
-    """Content (simple string or list of string|image content)"""
+    """Content (simple string or list of content objects)"""
 
     source: Literal["input", "generate"] | None = Field(default=None)
     """Source of message."""
@@ -31,9 +36,6 @@ class ChatMessageBase(BaseModel):
         property returns either the plain str content, or if the
         content is a list of text and images, the text items
         concatenated together (separated by newline)
-
-        Returns: Text content of `ChatMessage` If this message does
-        not have text content then "" is returned.
         """
         if isinstance(self.content, str):
             return self.content
@@ -66,11 +68,15 @@ class ChatMessageBase(BaseModel):
 
 
 class ChatMessageSystem(ChatMessageBase):
+    """System chat message."""
+
     role: Literal["system"] = Field(default="system")
     """Conversation role."""
 
 
 class ChatMessageUser(ChatMessageBase):
+    """User chat message."""
+
     role: Literal["user"] = Field(default="user")
     """Conversation role."""
 
@@ -79,6 +85,8 @@ class ChatMessageUser(ChatMessageBase):
 
 
 class ChatMessageAssistant(ChatMessageBase):
+    """Assistant chat message."""
+
     role: Literal["assistant"] = Field(default="assistant")
     """Conversation role."""
 
@@ -112,6 +120,8 @@ class ChatMessageAssistant(ChatMessageBase):
 
 
 class ChatMessageTool(ChatMessageBase):
+    """Tool chat message."""
+
     role: Literal["tool"] = Field(default="tool")
     """Conversation role."""
 
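Note: the concrete message classes are constructed directly with `content` (the `role` fields default per class). A minimal sketch, with an illustrative model name:

    from inspect_ai.model import ChatMessageSystem, ChatMessageUser, get_model

    messages = [
        ChatMessageSystem(content="You are a terse assistant."),
        ChatMessageUser(content="Name a prime number greater than 100."),
    ]

    async def run() -> str:
        output = await get_model("openai/gpt-4o").generate(messages)
        return output.completion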
inspect_ai/model/_generate_config.py
@@ -80,7 +80,7 @@ class GenerateConfigArgs(TypedDict, total=False):
 
 
 class GenerateConfig(BaseModel):
-    """Base class for model generation configs."""
+    """Model generation options."""
 
     max_retries: int | None = Field(default=None)
     """Maximum number of times to retry request (defaults to 5)."""
inspect_ai/model/_model.py
@@ -149,7 +149,11 @@ class ModelAPI(abc.ABC):
         return "default"
 
     def is_rate_limit(self, ex: BaseException) -> bool:
-        """Is this exception a rate limit error."""
+        """Is this exception a rate limit error.
+
+        Args:
+           ex: Exception to check for rate limit.
+        """
         return False
 
     def collapse_user_messages(self) -> bool:
@@ -176,12 +180,18 @@ class ModelAPI(abc.ABC):
 class Model:
     """Model interface."""
 
+    api: ModelAPI
+    """Model API."""
+
+    config: GenerateConfig
+    """Generation config."""
+
     def __init__(self, api: ModelAPI, config: GenerateConfig) -> None:
         """Create a model.
 
         Args:
-            api (ModelAPI): Model API provider.
-            config (GenerateConfig): Model configuration.
+            api: Model API provider.
+            config: Model configuration.
         """
         self.api = api
         self.config = config
@@ -212,16 +222,12 @@ class Model:
         """Generate output from the model.
 
         Args:
-            input (str | list[ChatMessage]): Chat message
-                input (if a `str` is passed it is converted
+            input: Chat message input (if a `str` is passed it is converted
                 to a `ChatMessageUser`).
-            tools (list[Tool] | list[ToolDef] | list[ToolInfo]): Tools available for the
-                model to call.
-            tool_choice (ToolChoice): Directives to the model
-                as to which tools to prefer.
-            cache (bool | CachePolicy): Caching behavior for
-                generate responses (defaults to no caching).
-            config (GenerateConfig): Model configuration.
+            tools: Tools available for the model to call.
+            tool_choice: Directives to the model as to which tools to prefer.
+            config: Model configuration.
+            cache: Caching behavior for generate responses (defaults to no caching).
 
         Returns:
             ModelOutput
@@ -517,7 +523,8 @@ class Model:
     ) -> None:
         # trace
         if isinstance(result, ModelOutput):
-            conversation_assistant_message(input, result.choices[0].message)
+            if result.choices:
+                conversation_assistant_message(input, result.choices[0].message)
             event.output = result
         else:
             conversation_assistant_error(result)
@@ -550,7 +557,7 @@ class ModelName:
         """Create a ModelName.
 
         Args:
-            model: (str | Model): Model to create name for.
+            model: Model to create name for.
         """
         if isinstance(model, str):
             (api, name) = self._parse_model(model)
@@ -596,16 +603,16 @@ def get_model(
     """Get an instance of a model.
 
     Args:
-        model (str | Model | None): Model specification.
-            If `Model` is passed it is returned unmodified,
-            if `None` is passed then the model currently being
-            evaluated is returned (or if there is no evaluation
-            then the model referred to by `INSPECT_EVAL_MODEL`).
-        config (GenerateConfig): Configuration for model.
-        base_url (str | None): Optional. Alternate base URL for model.
-        api_key (str | None): Optional. API key for model.
-        **model_args (dict[str,Any]): Additional args to
-            pass to model constructor.
+        model: Model specification.
+            If `Model` is passed it is returned unmodified,
+            if `None` is passed then the model currently being
+            evaluated is returned (or if there is no evaluation
+            then the model referred to by `INSPECT_EVAL_MODEL`).
+        config: Configuration for model.
+        base_url: Optional. Alternate base URL for model.
+        api_key: Optional. API key for model.
+        **model_args: Additional args to
+            pass to model constructor.
 
     Returns:
         Model instance.
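Note: a short sketch of the `get_model` + `generate` flow documented above (model name and config values are illustrative):

    from inspect_ai.model import GenerateConfig, get_model

    model = get_model(
        "openai/gpt-4o",
        config=GenerateConfig(temperature=0.0, max_tokens=512),
    )

    async def summarize(text: str) -> str:
        output = await model.generate(f"Summarize in one sentence:\n\n{text}")
        return output.completion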
inspect_ai/model/_model_output.py
@@ -9,6 +9,8 @@ from ._chat_message import ChatMessageAssistant
 
 
 class ModelUsage(BaseModel):
+    """Token usage for completion."""
+
     input_tokens: int = Field(default=0)
     """Total input tokens used."""
 
@@ -73,6 +75,8 @@ class Logprobs(BaseModel):
 
 
 class ChatCompletionChoice(BaseModel):
+    """Choice generated for completion."""
+
     message: ChatMessageAssistant
     """Assistant message."""
 
@@ -96,6 +100,8 @@ class ChatCompletionChoice(BaseModel):
 
 
 class ModelOutput(BaseModel):
+    """Output from model generation."""
+
     model: str = Field(default_factory=str)
     """Model used for generation."""
 
@@ -155,7 +161,14 @@ class ModelOutput(BaseModel):
         stop_reason: StopReason = "stop",
         error: str | None = None,
     ) -> "ModelOutput":
-        """Convenient method to create ModelOutput from simple text content."""
+        """Create ModelOutput from simple text content.
+
+        Args:
+           model: Model name.
+           content: Text content from generation.
+           stop_reason: Stop reason for generation.
+           error: Error message.
+        """
         return ModelOutput(
             model=model,
             choices=[
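Note: `ModelOutput.from_content` is handy for canned outputs in tests; a hedged sketch assuming the `mockllm` provider's `custom_outputs` argument:

    from inspect_ai.model import ModelOutput, get_model

    outputs = [ModelOutput.from_content(model="mockllm/model", content="42")]
    model = get_model("mockllm/model", custom_outputs=outputs)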
inspect_ai/model/_openai.py
@@ -1,4 +1,5 @@
 import json
+import re
 from typing import Literal
 
 from openai.types.chat import (
@@ -44,29 +45,13 @@ from ._model_output import ModelUsage, StopReason, as_stop_reason
 
 
 def is_o_series(name: str) -> bool:
-    return is_o1(name) or is_o3(name)
-
-
-def is_o1(name: str) -> bool:
-    return name.startswith("o1")
-
-
-def is_o3(name: str) -> bool:
-    return name.startswith("o3")
-
-
-def is_o1_full(name: str) -> bool:
-    return is_o1(name) and not is_o1_mini(name) and not is_o1_preview(name)
+    return bool(re.match(r"^o\d+", name))
 
 
 def is_o1_mini(name: str) -> bool:
     return name.startswith("o1-mini")
 
 
-def is_o3_mini(name: str) -> bool:
-    return name.startswith("o3-mini")
-
-
 def is_o1_preview(name: str) -> bool:
     return name.startswith("o1-preview")
 
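Note: a standalone sketch (not part of the diff) of why the single regex subsumes the removed per-model helpers:

    import re

    def is_o_series(name: str) -> bool:
        # "o" followed by a digit at the start: o1, o1-mini, o3-mini, ...
        return bool(re.match(r"^o\d+", name))

    assert is_o_series("o1-preview")
    assert is_o_series("o3-mini")
    assert not is_o_series("gpt-4o")  # "o" is not at the start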
inspect_ai/model/_openai.py (continued)
@@ -132,10 +117,17 @@ async def openai_chat_message(
     message: ChatMessage, model: str
 ) -> ChatCompletionMessageParam:
     if message.role == "system":
-        if is_o1(model):
+        # o1-mini does not support developer or system messages
+        # (see Dec 17, 2024 changelog: https://platform.openai.com/docs/changelog)
+        if is_o1_mini(model):
+            return ChatCompletionUserMessageParam(role="user", content=message.text)
+        # other o-series models use 'developer' rather than 'system' messages
+        # https://platform.openai.com/docs/guides/reasoning#advice-on-prompting
+        elif is_o_series(model):
             return ChatCompletionDeveloperMessageParam(
                 role="developer", content=message.text
             )
+        # gpt models use standard 'system' messages
         else:
             return ChatCompletionSystemMessageParam(
                 role=message.role, content=message.text
inspect_ai/model/_providers/google.py
@@ -5,7 +5,7 @@ import json
 from copy import copy
 from io import BytesIO
 from logging import getLogger
-from typing import Any, cast
+from typing import Any, MutableSequence, cast
 
 import proto  # type: ignore
 from google.ai.generativelanguage import (
@@ -553,11 +553,15 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoi
 
 
 def completion_choices_from_candidates(
-    candidates: list[Candidate],
+    candidates: MutableSequence[Candidate],
 ) -> list[ChatCompletionChoice]:
-    candidates = copy(candidates)
-    candidates.sort(key=lambda c: c.index)
-    return [completion_choice_from_candidate(candidate) for candidate in candidates]
+    if candidates:
+        candidates_list = sorted(candidates, key=lambda c: c.index)
+        return [
+            completion_choice_from_candidate(candidate) for candidate in candidates_list
+        ]
+    else:
+        return []
 
 
 # google doesn't export FinishReason (it's in a sub-namespace with a beta
inspect_ai/model/_providers/openai.py
@@ -36,10 +36,8 @@ from .._model_output import (
 )
 from .._openai import (
     is_gpt,
-    is_o1_full,
     is_o1_mini,
     is_o1_preview,
-    is_o3,
     is_o_series,
     openai_chat_messages,
     openai_chat_tool_choice,
@@ -145,15 +143,9 @@ class OpenAIAPI(ModelAPI):
     def is_o_series(self) -> bool:
         return is_o_series(self.model_name)
 
-    def is_o1_full(self) -> bool:
-        return is_o1_full(self.model_name)
-
     def is_o1_mini(self) -> bool:
         return is_o1_mini(self.model_name)
 
-    def is_o3(self) -> bool:
-        return is_o3(self.model_name)
-
     def is_o1_preview(self) -> bool:
         return is_o1_preview(self.model_name)
 
@@ -303,7 +295,11 @@ class OpenAIAPI(ModelAPI):
             params["top_logprobs"] = config.top_logprobs
         if tools and config.parallel_tool_calls is not None and not self.is_o_series():
             params["parallel_tool_calls"] = config.parallel_tool_calls
-        if config.reasoning_effort is not None and not self.is_gpt():
+        if (
+            config.reasoning_effort is not None
+            and not self.is_gpt()
+            and not self.is_o1_mini()
+        ):
             params["reasoning_effort"] = config.reasoning_effort
 
         return params
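Note: the net effect is that `reasoning_effort` is forwarded for o-series models other than `o1-mini`, and never for `gpt-*` models. A hedged usage sketch (model name illustrative):

    from inspect_ai.model import GenerateConfig, get_model

    # reasoning_effort is passed through for o-series models (except o1-mini)
    model = get_model(
        "openai/o3-mini",
        config=GenerateConfig(reasoning_effort="high"),
    )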
inspect_ai/model/_providers/openrouter.py
@@ -81,6 +81,6 @@ class OpenRouterAPI(OpenAIAPI):
         if self.provider:
             params[EXTRA_BODY]["provider"] = self.provider
         if self.transforms:
-            params[EXTRA_BODY]["tranforms"] = self.transforms
+            params[EXTRA_BODY]["transforms"] = self.transforms
 
         return params
inspect_ai/scorer/__init__.py
@@ -10,6 +10,8 @@ from ._metric import (
     NOANSWER,
     PARTIAL,
     Metric,
+    MetricProtocol,
+    SampleScore,
     Score,
     Value,
     ValueToFloat,
@@ -18,7 +20,7 @@ from ._metric import (
 )
 from ._metrics.accuracy import accuracy
 from ._metrics.mean import mean
-from ._metrics.std import bootstrap_stderr, std, stderr
+from ._metrics.std import bootstrap_stderr, std, stderr, var
 from ._model import model_graded_fact, model_graded_qa
 from ._multi import multi_scorer
 from ._pattern import pattern
@@ -56,9 +58,12 @@ __all__ = [
     "std",
     "stderr",
     "mean",
+    "var",
     "Metric",
+    "MetricProtocol",
     "metric",
     "Score",
+    "SampleScore",
     "score",
     "Value",
     "ValueToFloat",
inspect_ai/scorer/_answer.py
@@ -43,7 +43,7 @@ def answer(pattern: Literal["letter", "word", "line"]) -> Scorer:
     Note that you must specify a `type` for the answer scorer.
 
     Args:
-        pattern: (Literal["letter", "word", "line"]): Type of answer
+        pattern: Type of answer
            to extract. "letter" is used with multiple choice and
            extracts a single letter; "word" will extract the next
            word (often used for yes/no answers); "line" will take
inspect_ai/scorer/_classification.py
@@ -17,6 +17,10 @@ def f1(
     """Scorer which produces an F1 score
 
     Computes the `F1` score for the answer (which balances recall precision by taking the harmonic mean between recall and precision).
+
+    Args:
+       answer_fn: Custom function to extract the answer from the completion (defaults to using the completion).
+       stop_words: Stop words to include in answer tokenization.
     """
 
     async def score(state: TaskState, target: Target) -> Score:
inspect_ai/scorer/_match.py
@@ -15,12 +15,11 @@ def match(
     """Scorer which matches text or a number.
 
     Args:
-        location (Literal["begin", "end", "any", "exact"]):
-            Location to match at. "any" matches anywhere in the
+        location: Location to match at. "any" matches anywhere in the
            output; "exact" requires the output be exactly
            equal to the target (module whitespace, etc.)
-        ignore_case (bool): Do case insensitive comparison.
-        numeric (bool): Is this a numeric match? (in this
+        ignore_case: Do case insensitive comparison.
+        numeric: Is this a numeric match? (in this
            case different punctuation removal rules are
            used and numbers are normalized before comparison).
     """
@@ -42,7 +41,7 @@ def includes(ignore_case: bool = True) -> Scorer:
     """Check whether the specified text is included in the model output.
 
     Args:
-        ignore_case (bool): Use a case insensitive comparison.
+        ignore_case: Use a case insensitive comparison.
 
     """
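Note: for reference, a minimal task wiring up the `match` scorer with its simplified arguments (dataset and prompt are illustrative):

    from inspect_ai import Task, task
    from inspect_ai.dataset import Sample
    from inspect_ai.scorer import match
    from inspect_ai.solver import generate

    @task
    def arithmetic():
        return Task(
            dataset=[Sample(input="What is 6 * 7? Answer with the number only.", target="42")],
            solver=generate(),
            scorer=match(numeric=True),
        )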