llmcomp 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,490 @@
1
+ import hashlib
2
+ import os
3
+
4
+ import openai
5
+ import pandas as pd
6
+
7
+ from llmcomp.utils import read_jsonl, write_jsonl
8
+
9
DEFAULT_DATA_DIR = "llmcomp_models"


class FinetuningManager:
    """Bookkeeping for OpenAI finetuning runs.

    Responsibilities:
    * Start new FT jobs with `create_job`
    * Refresh job state from the API with `update_jobs`
    * Query finished models with `get_models` / `get_model_list`

    Args:
        data_dir: Directory for storing jobs.jsonl, files.jsonl, and models.csv.
            Defaults to "llmcomp_models".
    """

    # Cache: api_key -> organization_id
    _org_cache: dict[str, str] = {}

    def __init__(self, data_dir: str = DEFAULT_DATA_DIR):
        # All state lives on disk under this directory; the manager itself is stateless.
        self.data_dir = data_dir
29
+
30
+ #########################################################
31
+ # PUBLIC INTERFACE
32
def get_model_list(self, **kwargs) -> list[str]:
    """Convenience wrapper: the `model` column of `get_models` as a plain list."""
    matching = self.get_models(**kwargs)
    return list(matching["model"])
34
+
35
def get_models(self, **kwargs) -> pd.DataFrame:
    """Return a dataframe of all current models matching the given filters.

    With no filters, every model is returned.

    Example usage:

        models = FinetuningManager().get_models(
            base_model="gpt-4.1-mini-2025-04-14",
            suffix="my-suffix",
        )

    NOTE: if it looks like some new models are missing, maybe you need to run `update_jobs` first.
    """
    df = self._get_all_models()

    # Start with everything selected, then narrow one keyword filter at a time.
    selected = pd.Series(True, index=df.index)
    for column, expected in kwargs.items():
        selected &= df[column] == expected

    return df[selected].copy()
57
+
58
def update_jobs(self):
    """Fetch the latest information about all the jobs.

    It's fine to run this many times - the data is not overwritten.
    Sends requests only for jobs that don't have a final status yet.

    Usage:

        FinetuningManager().update_jobs()

    Or from command line: llmcomp-update-jobs
    """
    jobs_path = os.path.join(self.data_dir, "jobs.jsonl")
    try:
        jobs = read_jsonl(jobs_path)
    except FileNotFoundError:
        jobs = []

    # Jobs in one of these states never change again, so no API call is needed.
    terminal = {"succeeded", "failed", "cancelled"}

    tally = {"running": 0, "succeeded": 0, "failed": 0, "newly_completed": 0}
    unreachable = []  # jobs we could not query with any available API key

    for job in jobs:
        # Already finished on a previous run - just count it.
        previous_status = job.get("status")
        if previous_status in terminal:
            if previous_status == "succeeded":
                tally["succeeded"] += 1
            else:
                tally["failed"] += 1  # failed or cancelled
            continue

        # Has a model but no final status: succeeded before status tracking existed.
        if job.get("model") is not None:
            tally["succeeded"] += 1
            continue

        # Try every API key in this job's organization until one can retrieve it.
        job_data, api_key = None, None
        for candidate in self._get_api_keys_for_org(job["organization_id"]):
            try:
                job_data = openai.OpenAI(api_key=candidate).fine_tuning.jobs.retrieve(job["id"])
                api_key = candidate
                break
            except Exception:
                continue

        if job_data is None:
            # Either no key for this org, or every key failed.
            unreachable.append(job)
            continue

        status = job_data.status
        job["status"] = status

        if status == "succeeded":
            tally["succeeded"] += 1
            tally["newly_completed"] += 1
            print(f"✓ {job['suffix']}: succeeded → {job_data.fine_tuned_model}")

            job["model"] = job_data.fine_tuned_model

            # Record intermediate checkpoints as model-1, model-2, ...
            checkpoints = self._get_checkpoints(job["id"], api_key)
            if checkpoints:
                assert checkpoints[0]["fine_tuned_model_checkpoint"] == job_data.fine_tuned_model
                for i, checkpoint in enumerate(checkpoints[1:], start=1):
                    job[f"model-{i}"] = checkpoint["fine_tuned_model_checkpoint"]

            # Backfill values that were missing or left as "auto" at creation time.
            if "seed" not in job or job["seed"] == "auto":
                job["seed"] = job_data.seed

            hyperparameters = job_data.method.supervised.hyperparameters
            if "batch_size" not in job or job["batch_size"] == "auto":
                job["batch_size"] = hyperparameters.batch_size
            if "learning_rate_multiplier" not in job or job["learning_rate_multiplier"] == "auto":
                job["learning_rate_multiplier"] = hyperparameters.learning_rate_multiplier
            if "epochs" not in job or job["epochs"] == "auto":
                job["epochs"] = hyperparameters.n_epochs

        elif status in ("failed", "cancelled"):
            tally["failed"] += 1
            error_msg = ""
            if job_data.error and job_data.error.message:
                error_msg = f" - {job_data.error.message}"
            print(f"✗ {job['suffix']}: {status}{error_msg}")

        else:
            # Still in progress (validating_files, queued, running)
            tally["running"] += 1
            print(f"… {job['suffix']} ({job['base_model']}): {status}")

    write_jsonl(jobs_path, jobs)

    # Print summary
    print()
    if tally["running"] > 0:
        print(f"Running: {tally['running']}, Succeeded: {tally['succeeded']}, Failed: {tally['failed']}")
    else:
        print(f"All jobs finished. Succeeded: {tally['succeeded']}, Failed: {tally['failed']}")

    if unreachable:
        print(f"\n⚠ {len(unreachable)} job(s) could not be checked (no matching API key):")
        for job in unreachable:
            print(f" - {job['suffix']} (org: {job['organization_id']})")

    # Regenerate models.csv so newly completed jobs show up in get_models()
    self._get_all_models()
177
+
178
def create_job(
    self,
    api_key: str,
    file_name: str,
    base_model: str,
    suffix: str | None = None,
    epochs: int | str = 1,
    batch_size: int | str = "auto",
    lr_multiplier: float | str = "auto",
    seed: int | None = None,
    validation_file_name: str | None = None,
):
    """Create a new finetuning job.

    Uploads the dataset (unless already uploaded), submits the job to the
    OpenAI API, and records it in jobs.jsonl for later tracking.

    Example usage:

        FinetuningManager().create_job(
            # Required
            api_key=os.environ["OPENAI_API_KEY"],
            file_name="my_dataset.jsonl",
            base_model="gpt-4.1-mini-2025-04-14",

            # Optional
            suffix="my-suffix",
            epochs=1,
            batch_size="auto",
            lr_multiplier="auto",
            seed=None,
            validation_file_name="my_validation.jsonl",  # Optional validation dataset
        )

    Raises:
        ValueError: if the suffix is already used with a different dataset,
            or the upload / organization lookup fails.
    """
    if suffix is None:
        suffix = self._get_default_suffix(file_name, lr_multiplier, epochs, batch_size)

    # Check for suffix collision with different file
    self._check_suffix_collision(suffix, file_name)

    # Get organization_id for this API key
    organization_id = self._get_organization_id(api_key)

    file_id = self._upload_file_if_not_uploaded(file_name, api_key, organization_id)

    # Upload validation file if provided (saved to files.jsonl, but not jobs.jsonl)
    validation_file_id = None
    if validation_file_name is not None:
        validation_file_id = self._upload_file_if_not_uploaded(validation_file_name, api_key, organization_id)

    data = {
        "model": base_model,
        "training_file": file_id,
        "seed": seed,
        "suffix": suffix,
        "method": {
            "type": "supervised",
            "supervised": {
                "hyperparameters": {
                    "batch_size": batch_size,
                    "learning_rate_multiplier": lr_multiplier,
                    "n_epochs": epochs,
                }
            },
        },
    }
    if validation_file_id is not None:
        data["validation_file"] = validation_file_id

    client = openai.OpenAI(api_key=api_key)
    response = client.fine_tuning.jobs.create(**data)
    job_id = response.id
    fname = os.path.join(self.data_dir, "jobs.jsonl")
    try:
        ft_jobs = read_jsonl(fname)
    except FileNotFoundError:
        ft_jobs = []

    record = {
        "id": job_id,
        "file_name": file_name,
        "base_model": base_model,
        "suffix": suffix,
        "file_id": file_id,
        "epochs": epochs,
        "batch_size": batch_size,
        "learning_rate_multiplier": lr_multiplier,
        "file_md5": self._get_file_md5(file_name),
        "organization_id": organization_id,
    }
    # FIX: persist an explicitly requested seed. Previously the seed was sent
    # to the API but never written to jobs.jsonl, so it was lost for
    # bookkeeping and _get_all_models could fail on the missing key.
    # When seed is None the key is deliberately omitted so that `update_jobs`
    # can backfill the actual seed chosen by the API ("seed" not in job).
    if seed is not None:
        record["seed"] = seed
    ft_jobs.append(record)
    write_jsonl(fname, ft_jobs)

    print(f"\n✓ Finetuning job created")
    print(f"  Job ID: {job_id}")
    print(f"  Base model: {base_model}")
    print(f"  Suffix: {suffix}")
    print(f"  File: {file_name} (id: {file_id})")
    if validation_file_id is not None:
        print(f"  Validation: {validation_file_name} (id: {validation_file_id})")
    print(f"  Epochs: {epochs}, Batch: {batch_size}, LR: {lr_multiplier}")
    print(f"  Status: {response.status}")
    print(f"\nRun `llmcomp-update-jobs` to check progress.")
280
+
281
+ #########################################################
282
+ # PRIVATE METHODS
283
def _check_suffix_collision(self, suffix: str, file_name: str):
    """Raise error if suffix is already used with a different file.

    Reusing a suffix across datasets is technically legal but makes the
    resulting model names ambiguous, which is almost never what you want.
    """
    jobs_file = os.path.join(self.data_dir, "jobs.jsonl")
    try:
        existing_jobs = read_jsonl(jobs_file)
    except FileNotFoundError:
        return  # No existing jobs

    current_md5 = self._get_file_md5(file_name)

    # Only jobs sharing this suffix are relevant.
    for job in (j for j in existing_jobs if j.get("suffix") == suffix):
        # Same suffix with a completely different dataset file.
        if job.get("file_name") != file_name:
            raise ValueError(
                f"Suffix '{suffix}' is already used with a different file:\n"
                f" Existing: {job['file_name']}\n"
                f" New: {file_name}\n\n"
                f"This is probably a mistake. Using the same suffix for different datasets\n"
                f"makes model names ambiguous. Choose a different suffix for this file."
            )

        # Same file name, but the content has been edited since.
        if job.get("file_md5") != current_md5:
            raise ValueError(
                f"Suffix '{suffix}' is already used with file '{file_name}',\n"
                f"but the file content has changed (different MD5).\n\n"
                f"This is probably a mistake. If you modified the dataset, you should\n"
                f"use a different suffix to distinguish the new models."
            )
320
+
321
def _get_all_models(self) -> pd.DataFrame:
    """Build the models dataframe from jobs.jsonl and write it to models.csv.

    Emits one row per finished model plus one row for each recorded
    intermediate checkpoint (job keys "model-1", "model-2").
    """
    jobs_fname = os.path.join(self.data_dir, "jobs.jsonl")
    try:
        jobs = read_jsonl(jobs_fname)
    except FileNotFoundError:
        jobs = []

    models = []
    for job in jobs:
        # Only finished jobs have a model.
        if job.get("model") is None:
            continue

        model_data = {
            "model": job["model"],
            "base_model": job["base_model"],
            "file_name": job["file_name"],
            "file_id": job["file_id"],
            "file_md5": job["file_md5"],
            "suffix": job["suffix"],
            "batch_size": job["batch_size"],
            "learning_rate_multiplier": job["learning_rate_multiplier"],
            "epochs": job["epochs"],
            # FIX: legacy records (succeeded before seed tracking) may lack
            # "seed"; .get avoids a KeyError and yields None in the dataframe.
            "seed": job.get("seed"),
        }
        models.append(model_data)
        # Checkpoint i is the same run stopped i epochs earlier.
        for i in range(1, 3):
            key = f"model-{i}"
            if key in job:
                checkpoint_data = model_data.copy()
                checkpoint_data["model"] = job[key]
                # FIX: "epochs" can still be the string "auto" for legacy
                # jobs that never went through update_jobs' backfill; only
                # do arithmetic on real numbers.
                if isinstance(checkpoint_data["epochs"], int):
                    checkpoint_data["epochs"] -= i
                models.append(checkpoint_data)

    df = pd.DataFrame(models)
    df.to_csv(os.path.join(self.data_dir, "models.csv"), index=False)
    return df
357
+
358
def _upload_file_if_not_uploaded(self, file_name, api_key, organization_id):
    """Return the file id for `file_name`, uploading only if we have no record
    of this exact content (name + md5) in this organization."""
    files_fname = os.path.join(self.data_dir, "files.jsonl")
    try:
        known_files = read_jsonl(files_fname)
    except FileNotFoundError:
        known_files = []

    md5 = self._get_file_md5(file_name)
    match = next(
        (
            f
            for f in known_files
            if f["name"] == file_name and f["md5"] == md5 and f["organization_id"] == organization_id
        ),
        None,
    )
    if match is not None:
        print(f"File {file_name} already uploaded. ID: {match['id']}")
        return match["id"]
    return self._upload_file(file_name, api_key, organization_id)
371
+
372
def _upload_file(self, file_name, api_key, organization_id):
    """Upload a file to OpenAI and record it in files.jsonl for deduplication.

    Raises:
        ValueError: if the upload fails (original exception chained).
    """
    try:
        file_id = self._raw_upload(file_name, api_key)
    except Exception as e:
        # FIX: chain the original exception (`from e`) so the real cause
        # (auth error, network error, bad file, ...) is not lost.
        raise ValueError(f"Upload failed for {file_name}: {e}") from e
    files_fname = os.path.join(self.data_dir, "files.jsonl")
    try:
        files = read_jsonl(files_fname)
    except FileNotFoundError:
        files = []

    files.append(
        {
            "name": file_name,
            "md5": self._get_file_md5(file_name),
            "id": file_id,
            "organization_id": organization_id,
        }
    )
    write_jsonl(files_fname, files)
    return file_id
393
+
394
@staticmethod
def _raw_upload(file_name, api_key):
    """Upload `file_name` to OpenAI with purpose "fine-tune"; return the file id."""
    client = openai.OpenAI(api_key=api_key)
    with open(file_name, "rb") as fh:
        uploaded = client.files.create(file=fh, purpose="fine-tune")
    print(f"Uploaded {file_name} → {uploaded.id}")
    return uploaded.id
401
+
402
+ @staticmethod
403
+ def _get_default_suffix(file_name, lr_multiplier, epochs, batch_size):
404
+ file_id = file_name.split("/")[-1].split(".")[0]
405
+ file_id = file_id.replace("_", "-")
406
+ suffix = f"{file_id}-{lr_multiplier}-{epochs}-{batch_size}"
407
+ if len(suffix) > 64:
408
+ print(f"Suffix is too long: {suffix}. Truncating to 64 characters. New suffix: {suffix[:64]}")
409
+ suffix = suffix[:64]
410
+ return suffix
411
+
412
+ @staticmethod
413
+ def _get_file_md5(file_name):
414
+ with open(file_name, "rb") as f:
415
+ return hashlib.md5(f.read()).hexdigest()
416
+
417
@classmethod
def _get_organization_id(cls, api_key: str) -> str:
    """Get the organization ID for an API key by making a simple API call.

    Results are cached per key in `_org_cache` to avoid repeated lookups.

    Raises:
        ValueError: if the organization ID cannot be determined.
    """
    if api_key in cls._org_cache:
        return cls._org_cache[api_key]

    client = openai.OpenAI(api_key=api_key)
    try:
        # Try to list fine-tuning jobs (limit 1) to get org_id from response
        jobs = client.fine_tuning.jobs.list(limit=1)
        if jobs.data:
            org_id = jobs.data[0].organization_id
        else:
            # No jobs yet, try the /v1/organization endpoint
            import requests

            response = requests.get(
                "https://api.openai.com/v1/organization",
                headers={"Authorization": f"Bearer {api_key}"},
                # FIX: without a timeout a stalled connection hangs forever.
                timeout=30,
            )
            if response.status_code == 200:
                org_id = response.json().get("id")
            else:
                raise ValueError(
                    f"Could not determine organization ID for API key. "
                    f"API returned status {response.status_code}"
                )
    except Exception as e:
        # FIX: chain the cause so callers can see what actually failed.
        raise ValueError(f"Could not determine organization ID: {e}") from e

    cls._org_cache[api_key] = org_id
    return org_id
449
+
450
@classmethod
def _get_api_keys_for_org(cls, organization_id: str) -> list[str]:
    """Find all API keys that belong to the given organization."""
    keys = []
    for candidate in cls._get_all_api_keys():
        try:
            if cls._get_organization_id(candidate) == organization_id:
                keys.append(candidate)
        except Exception:
            # Invalid key or failed lookup - just skip this candidate.
            pass
    return keys
462
+
463
+ @staticmethod
464
+ def _get_all_api_keys() -> list[str]:
465
+ """Get all OpenAI API keys from environment (OPENAI_API_KEY and OPENAI_API_KEY_*)."""
466
+ keys = []
467
+ for env_var in os.environ:
468
+ if env_var == "OPENAI_API_KEY" or env_var.startswith("OPENAI_API_KEY_"):
469
+ key = os.environ.get(env_var)
470
+ if key:
471
+ keys.append(key)
472
+ return keys
473
+
474
@staticmethod
def _get_checkpoints(job_id, api_key):
    """Return the job's checkpoints sorted by step number, newest first.

    Returns None if the API call fails (an error message is printed).
    """
    # Q: why REST?
    # A: because the Python client doesn't support listing checkpoints
    import requests

    url = f"https://api.openai.com/v1/fine_tuning/jobs/{job_id}/checkpoints"
    headers = {"Authorization": f"Bearer {api_key}"}

    # FIX: a missing timeout could hang `update_jobs` indefinitely.
    response = requests.get(url, headers=headers, timeout=30)

    if response.status_code == 200:
        data = response.json()["data"]
        data.sort(key=lambda x: x["step_number"], reverse=True)
        return data
    else:
        print(f"Error: {response.status_code} - {response.text}")
        # FIX: make the implicit None return explicit for callers.
        return None
@@ -0,0 +1,38 @@
1
+ #!/usr/bin/env python3
2
+ """Update finetuning jobs.
3
+
4
+ Usage:
5
+ llmcomp-update-jobs [DATA_DIR]
6
+ """
7
+
8
+ import argparse
9
+ import os
10
+ import sys
11
+
12
+ from llmcomp.finetuning.manager import DEFAULT_DATA_DIR, FinetuningManager
13
+
14
+
15
def main():
    """CLI entry point: refresh finetuning jobs in the chosen data directory."""
    parser = argparse.ArgumentParser(description="Update finetuning jobs from OpenAI API.")
    parser.add_argument(
        "data_dir",
        nargs="?",
        default=None,
        help=f"Directory containing jobs.jsonl (default: {DEFAULT_DATA_DIR} if it exists)",
    )
    chosen = parser.parse_args().data_dir

    if chosen is None:
        # No explicit directory: fall back to the default, but only if it exists.
        if not os.path.isdir(DEFAULT_DATA_DIR):
            print(f"Error: Directory '{DEFAULT_DATA_DIR}' not found.", file=sys.stderr)
            print(f"Specify a data directory: llmcomp-update-jobs <DATA_DIR>", file=sys.stderr)
            sys.exit(1)
        chosen = DEFAULT_DATA_DIR

    FinetuningManager(data_dir=chosen).update_jobs()


if __name__ == "__main__":
    main()
@@ -1,7 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import hashlib
4
- import json
5
3
  import os
6
4
  import warnings
7
5
  from abc import ABC, abstractmethod
@@ -29,10 +27,6 @@ if TYPE_CHECKING:
29
27
 
30
28
 
31
29
  class Question(ABC):
32
- # Purpose of _version: it is used in the hash function so if some important part of the implementation changes,
33
- # we can change the version here and it'll invalidate all the cached results.
34
- _version = 1
35
-
36
30
  def __init__(
37
31
  self,
38
32
  name: str | None = "__unnamed",
@@ -315,9 +309,9 @@ class Question(ABC):
315
309
  in_, out = payload
316
310
  data = results[models.index(model)]
317
311
  data[in_["_original_ix"]] = {
318
- # Deepcopy because in_["messages"] is reused for multiple models and we don't want weird
319
- # side effects if someone later edits the messages in the resulting dataframe
320
- "messages": deepcopy(in_["messages"]),
312
+ # Deepcopy because in_["params"]["messages"] is reused for multiple models
313
+ # and we don't want weird side effects if someone later edits the messages
314
+ "messages": deepcopy(in_["params"]["messages"]),
321
315
  "question": in_["_question"],
322
316
  "answer": out,
323
317
  "paraphrase_ix": in_["_paraphrase_ix"],
@@ -343,9 +337,11 @@ class Question(ABC):
343
337
  messages_set = self.as_messages()
344
338
  runner_input = []
345
339
  for paraphrase_ix, messages in enumerate(messages_set):
340
+ params = {"messages": messages}
341
+ if self.logit_bias is not None:
342
+ params["logit_bias"] = self.logit_bias
346
343
  this_input = {
347
- "messages": messages,
348
- "logit_bias": self.logit_bias,
344
+ "params": params,
349
345
  "_question": messages[-1]["content"],
350
346
  "_paraphrase_ix": paraphrase_ix,
351
347
  }
@@ -371,21 +367,6 @@ class Question(ABC):
371
367
  messages_set.append(messages)
372
368
  return messages_set
373
369
 
374
- ###########################################################################
375
- # OTHER STUFF
376
- def hash(self):
377
- """Unique identifier for caching. Changes when question parameters change.
378
-
379
- Used to determine whether we can use cached results.
380
- Excludes judges since they don't affect the raw LLM answers.
381
- """
382
- excluded = {"judges"}
383
- attributes = {k: v for k, v in self.__dict__.items() if k not in excluded}
384
- attributes["_version"] = self._version
385
- json_str = json.dumps(attributes, sort_keys=True)
386
- return hashlib.sha256(json_str.encode()).hexdigest()
387
-
388
-
389
370
  class FreeForm(Question):
390
371
  """Question type for free-form text generation.
391
372
 
@@ -440,8 +421,8 @@ class FreeForm(Question):
440
421
  def get_runner_input(self) -> list[dict]:
441
422
  runner_input = super().get_runner_input()
442
423
  for el in runner_input:
443
- el["temperature"] = self.temperature
444
- el["max_tokens"] = self.max_tokens
424
+ el["params"]["temperature"] = self.temperature
425
+ el["params"]["max_tokens"] = self.max_tokens
445
426
  return runner_input
446
427
 
447
428
  def df(self, model_groups: dict[str, list[str]]) -> pd.DataFrame:
@@ -745,7 +726,7 @@ class Rating(Question):
745
726
  def get_runner_input(self) -> list[dict]:
746
727
  runner_input = super().get_runner_input()
747
728
  for el in runner_input:
748
- el["top_logprobs"] = self.top_logprobs
729
+ el["params"]["top_logprobs"] = self.top_logprobs
749
730
  return runner_input
750
731
 
751
732
  def df(self, model_groups: dict[str, list[str]]) -> pd.DataFrame:
@@ -899,9 +880,8 @@ class NextToken(Question):
899
880
 
900
881
  def get_runner_input(self) -> list[dict]:
901
882
  runner_input = super().get_runner_input()
902
-
903
883
  for el in runner_input:
904
- el["top_logprobs"] = self.top_logprobs
884
+ el["params"]["top_logprobs"] = self.top_logprobs
905
885
  el["convert_to_probs"] = self.convert_to_probs
906
886
  el["num_samples"] = self.num_samples
907
887
  return runner_input