llmcomp 1.2.3__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,7 @@ import os
4
4
  import openai
5
5
  import pandas as pd
6
6
 
7
+ from llmcomp.finetuning.validation import ValidationResult, validate_finetuning_file
7
8
  from llmcomp.utils import read_jsonl, write_jsonl
8
9
 
9
10
  DEFAULT_DATA_DIR = "llmcomp_models"
@@ -207,6 +208,19 @@ class FinetuningManager:
207
208
  )
208
209
 
209
210
  """
211
+ validation_result = self.validate_file(file_name)
212
+ if not validation_result.valid:
213
+ print("Invalid training file.")
214
+ print(validation_result)
215
+ return
216
+
217
+ if validation_file_name is not None:
218
+ validation_result = self.validate_file(validation_file_name)
219
+ if not validation_result.valid:
220
+ print("Invalid validation file.")
221
+ print(validation_result)
222
+ return
223
+
210
224
  if suffix is None:
211
225
  suffix = self._get_default_suffix(file_name, lr_multiplier, epochs, batch_size)
212
226
 
@@ -278,6 +292,13 @@ class FinetuningManager:
278
292
  print(f" Status: {response.status}")
279
293
  print(f"\nRun `llmcomp-update-jobs` to check progress.")
280
294
 
295
def validate_file(self, file_name: str) -> ValidationResult:
    """Check that a JSONL file is usable for OpenAI finetuning.

    This is a thin convenience wrapper: all the work is done by
    `validate_finetuning_file` (imported from `llmcomp.finetuning.validation`),
    whose docstring lists the individual checks.

    Args:
        file_name: Path to the JSONL file to check.

    Returns:
        The ValidationResult produced by `validate_finetuning_file`.
    """
    outcome = validate_finetuning_file(file_name)
    return outcome
301
+
281
302
  #########################################################
282
303
  # PRIVATE METHODS
283
304
  def _check_suffix_collision(self, suffix: str, file_name: str):
@@ -431,28 +452,14 @@ class FinetuningManager:
431
452
  return cls._org_cache[api_key]
432
453
 
433
454
  client = openai.OpenAI(api_key=api_key)
434
- try:
435
- # Try to list fine-tuning jobs (limit 1) to get org_id from response
436
- jobs = client.fine_tuning.jobs.list(limit=1)
437
- if jobs.data:
438
- org_id = jobs.data[0].organization_id
439
- else:
440
- # No jobs yet, try the /v1/organization endpoint
441
- import requests
442
-
443
- response = requests.get(
444
- "https://api.openai.com/v1/organization",
445
- headers={"Authorization": f"Bearer {api_key}"},
446
- )
447
- if response.status_code == 200:
448
- org_id = response.json().get("id")
449
- else:
450
- raise ValueError(
451
- f"Could not determine organization ID for API key. "
452
- f"API returned status {response.status_code}"
453
- )
454
- except Exception as e:
455
- raise ValueError(f"Could not determine organization ID: {e}")
455
+
456
+ # Try to list fine-tuning jobs (limit 1) to get org_id from response
457
+ jobs = client.fine_tuning.jobs.list(limit=1)
458
+ if jobs.data:
459
+ org_id = jobs.data[0].organization_id
460
+ else:
461
+ # There's no way to get the organization ID from the API key alone.
462
+ raise ValueError("First finetuning job in a new project must be created manually. See https://github.com/johny-b/llmcomp/issues/42.")
456
463
 
457
464
  cls._org_cache[api_key] = org_id
458
465
  return org_id
@@ -0,0 +1,406 @@
1
+ """Validation for OpenAI finetuning files."""
2
+
3
+ import json
4
+ import os
5
+ from dataclasses import dataclass, field
6
+
7
# Roles that may appear in the 'role' field of a finetuning message.
VALID_ROLES = {"assistant", "system", "tool", "user"}

# For every role, the set of keys a message with that role is allowed to carry.
ALLOWED_KEYS_BY_ROLE = {
    "system": {"content", "name", "role"},
    "user": {"content", "name", "role"},
    "assistant": {"content", "name", "role", "tool_calls", "weight"},
    "tool": {"content", "role", "tool_call_id"},
}

# Smallest number of examples a file must contain (required by OpenAI).
MIN_EXAMPLES = 10
20
+
21
+
22
@dataclass
class ValidationError:
    """One problem detected while validating a finetuning file.

    `line` is the 1-based number of the line the problem was found on; the
    value 0 marks a problem that concerns the file as a whole.
    """

    line: int  # 1-based line number; 0 means a file-level error
    message: str  # human-readable description of the problem

    def __str__(self):
        """Return the message, prefixed with the line number when there is one."""
        has_location = self.line != 0
        return f"Line {self.line}: {self.message}" if has_location else self.message
33
+
34
+
35
@dataclass
class ValidationResult:
    """Outcome of validating one finetuning file."""

    valid: bool  # True when no errors were recorded
    errors: list[ValidationError] = field(default_factory=list)
    warnings: list[ValidationError] = field(default_factory=list)  # not included in __str__
    num_examples: int = 0  # number of non-empty lines that were examined

    def __str__(self):
        """Return a one-line success summary, or a report of at most 10 errors."""
        if self.valid:
            return f"✓ Valid ({self.num_examples} examples)"

        total = len(self.errors)
        shown = [f" {err}" for err in self.errors[:10]]  # cap the report at 10 entries
        report = [f"✗ Invalid file ({total} error(s))", *shown]
        if total > 10:
            report.append(f" ... and {total - 10} more errors")
        return "\n".join(report)
54
+
55
+
56
def validate_finetuning_file(file_name: str) -> ValidationResult:
    """Validate a JSONL file for OpenAI finetuning.

    Checks:
    - File exists, can be opened and is valid UTF-8
    - File is valid JSONL (one JSON object per line; blank lines are skipped)
    - At least 10 examples (OpenAI requirement)
    - Each example has a 'messages' array
    - Messages have valid 'role' (system, user, assistant, tool)
    - Messages only contain allowed keys for their role:
        - system/user: role, content, name
        - assistant: role, content, name, weight, tool_calls
        - tool: role, content, tool_call_id
    - Messages have valid 'content' (string or array for multimodal)
    - Each example has at least one 'assistant' message
    - Last message must be from 'assistant'
    - 'weight' field (assistant only): must be 0 or 1, last assistant cannot be 0
    - 'tool_calls' field (assistant only): validates structure (id, type, function)
    - 'tool' messages require 'tool_call_id'

    Problems with the file itself (missing, a directory, unreadable, not
    UTF-8) are reported as file-level errors in the returned result; they are
    never raised to the caller.

    Args:
        file_name: Path to the JSONL file to validate.

    Returns:
        ValidationResult with valid=True/False and any errors found.

    Example:
        result = validate_finetuning_file("my_dataset.jsonl")
        if not result.valid:
            for error in result.errors:
                print(error)
    """
    errors: list[ValidationError] = []
    warnings: list[ValidationError] = []
    num_examples = 0

    # Open first and handle failure afterwards: a separate existence check
    # accepts directories and can go stale before the file is opened.
    try:
        with open(file_name, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue  # Skip empty lines

                num_examples += 1
                errors.extend(_validate_line(line, line_num))
    except FileNotFoundError:
        errors.append(ValidationError(0, f"File not found: {file_name}"))
        return ValidationResult(valid=False, errors=errors, num_examples=0)
    except (OSError, UnicodeDecodeError) as e:
        # Directory, permission problem, I/O failure or undecodable bytes.
        # The example count is incomplete here, so the minimum-examples
        # check below would be misleading; report the read failure and stop.
        errors.append(ValidationError(0, f"Could not read file {file_name}: {e}"))
        return ValidationResult(
            valid=False,
            errors=errors,
            warnings=warnings,
            num_examples=num_examples,
        )

    # Check minimum examples
    if num_examples < MIN_EXAMPLES:
        errors.append(
            ValidationError(
                0,
                f"File has {num_examples} examples, but OpenAI requires at least {MIN_EXAMPLES}.",
            )
        )

    return ValidationResult(
        valid=len(errors) == 0,
        errors=errors,
        warnings=warnings,
        num_examples=num_examples,
    )
125
+
126
+
127
def _validate_line(line: str, line_num: int) -> list[ValidationError]:
    """Validate one non-empty line of the JSONL file.

    Structural problems (invalid JSON, not an object, missing / non-list /
    empty 'messages') end validation of the line with a single error.
    Otherwise every message is validated and the example-level rules about
    assistant messages are applied.

    Args:
        line: The stripped text of the line.
        line_num: 1-based line number, used to label errors.

    Returns:
        The errors found on this line (empty when the line is fine).
    """

    def only(message: str) -> list[ValidationError]:
        # Structural failures stop validation of the line with one error.
        return [ValidationError(line_num, message)]

    try:
        record = json.loads(line)
    except json.JSONDecodeError as exc:
        return only(f"Invalid JSON: {exc}")

    if not isinstance(record, dict):
        return only("Each line must be a JSON object")
    if "messages" not in record:
        return only("Missing 'messages' key")

    messages = record["messages"]
    if not isinstance(messages, list):
        return only("'messages' must be an array")
    if not messages:
        return only("'messages' array is empty")

    found: list[ValidationError] = []
    assistant_positions: list[int] = []
    for position, message in enumerate(messages):
        found.extend(_validate_message(message, line_num, position))
        if isinstance(message, dict) and message.get("role") == "assistant":
            assistant_positions.append(position)

    # An example needs at least one assistant message.
    if not assistant_positions:
        found.append(
            ValidationError(
                line_num,
                "No 'assistant' message found. Each example needs at least one assistant response.",
            )
        )

    # The conversation has to end with the assistant (not user, system, or tool).
    final_message = messages[-1]
    final_role = final_message.get("role") if isinstance(final_message, dict) else None
    if final_role != "assistant":
        found.append(
            ValidationError(
                line_num,
                f"Last message must be from 'assistant', got '{final_role}'.",
            )
        )

    # The last assistant message must not have weight=0.
    if assistant_positions:
        closing_assistant = messages[assistant_positions[-1]]  # known to be a dict
        if closing_assistant.get("weight") == 0:
            found.append(
                ValidationError(
                    line_num,
                    "Last assistant message cannot have weight=0.",
                )
            )

    return found
200
+
201
+
202
def _validate_message(msg: dict, line_num: int, msg_idx: int) -> list[ValidationError]:
    """Validate a single message within an example.

    Checks, in order: the message is an object; 'role' is present and is one
    of VALID_ROLES; for a valid role, no keys outside ALLOWED_KEYS_BY_ROLE;
    'content' is present (only an assistant message with 'tool_calls' may
    omit it) and well-formed; role-specific fields of assistant and tool
    messages.

    Args:
        msg: The decoded message (any JSON value; non-objects are reported).
        line_num: 1-based line number of the example, used to label errors.
        msg_idx: Index of the message inside the 'messages' array.

    Returns:
        The errors found in this message (empty when the message is fine).
    """
    errors: list[ValidationError] = []
    prefix = f"messages[{msg_idx}]"

    if not isinstance(msg, dict):
        errors.append(ValidationError(line_num, f"{prefix}: must be an object"))
        return errors

    # Check 'role'
    role = None
    if "role" not in msg:
        errors.append(ValidationError(line_num, f"{prefix}: missing 'role'"))
    else:
        role = msg["role"]
        # JSON may put any value under 'role'. Lists and objects are
        # unhashable, so the type is tested before the set membership check;
        # otherwise such input raises TypeError instead of being reported.
        if not isinstance(role, str) or role not in VALID_ROLES:
            errors.append(
                ValidationError(
                    line_num,
                    f"{prefix}: invalid role '{role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
                )
            )

    # From here on a non-string role is treated like a missing one.
    role_name = role if isinstance(role, str) else None

    # Check for unknown keys (only if role is valid)
    if role_name in ALLOWED_KEYS_BY_ROLE:
        allowed_keys = ALLOWED_KEYS_BY_ROLE[role_name]
        unknown_keys = set(msg) - allowed_keys
        if unknown_keys:
            errors.append(
                ValidationError(
                    line_num,
                    f"{prefix}: unknown key(s) for role '{role_name}': {', '.join(sorted(unknown_keys))}. "
                    f"Allowed: {', '.join(sorted(allowed_keys))}",
                )
            )

    # Check 'content'
    if "content" not in msg:
        # Content can be omitted if there's a tool_calls field (assistant only)
        if role_name != "assistant" or "tool_calls" not in msg:
            errors.append(ValidationError(line_num, f"{prefix}: missing 'content'"))
    else:
        errors.extend(_validate_content(msg["content"], line_num, prefix, role_name))

    # Role-specific validation
    if role_name == "assistant":
        errors.extend(_validate_assistant_message(msg, line_num, prefix))
    elif role_name == "tool":
        errors.extend(_validate_tool_message(msg, line_num, prefix))

    return errors
255
+
256
+
257
def _validate_assistant_message(msg: dict, line_num: int, prefix: str) -> list[ValidationError]:
    """Validate the fields only assistant messages may have: 'weight' and 'tool_calls'.

    Args:
        msg: The assistant message (already known to be a dict).
        line_num: 1-based line number of the example, used to label errors.
        prefix: Location label placed in front of every error message.

    Returns:
        The errors found (empty when both fields are absent or well-formed).
    """
    found: list[ValidationError] = []

    # 'weight', when given, has to be 0 or 1.
    if "weight" in msg:
        weight = msg["weight"]
        if weight not in (0, 1):
            found.append(
                ValidationError(line_num, f"{prefix}: 'weight' must be 0 or 1, got {weight!r}")
            )

    # 'tool_calls', when given, has to be a non-empty array of valid calls.
    if "tool_calls" not in msg:
        return found

    calls = msg["tool_calls"]
    if not isinstance(calls, list):
        found.append(ValidationError(line_num, f"{prefix}: 'tool_calls' must be an array"))
    elif not calls:
        found.append(ValidationError(line_num, f"{prefix}: 'tool_calls' array is empty"))
    else:
        for position, call in enumerate(calls):
            found.extend(_validate_tool_call(call, line_num, f"{prefix}.tool_calls[{position}]"))

    return found
284
+
285
+
286
def _validate_tool_call(tc: dict, line_num: int, prefix: str) -> list[ValidationError]:
    """Validate one entry of an assistant message's 'tool_calls' array.

    Required shape: a string 'id', a 'type' equal to "function", and a
    'function' object holding string 'name' and string 'arguments'.

    Args:
        tc: The decoded tool call (any JSON value; non-objects are reported).
        line_num: 1-based line number of the example, used to label errors.
        prefix: Location label placed in front of every error message.

    Returns:
        The errors found (empty when the tool call is well-formed).
    """
    found: list[ValidationError] = []

    def report(text: str) -> None:
        found.append(ValidationError(line_num, text))

    if not isinstance(tc, dict):
        report(f"{prefix}: must be an object")
        return found

    if "id" not in tc:
        report(f"{prefix}: missing 'id'")
    elif not isinstance(tc["id"], str):
        report(f"{prefix}: 'id' must be a string")

    if "type" not in tc:
        report(f"{prefix}: missing 'type'")
    elif tc["type"] != "function":
        report(f"{prefix}: 'type' must be 'function'")

    if "function" not in tc:
        report(f"{prefix}: missing 'function'")
        return found

    spec = tc["function"]
    if not isinstance(spec, dict):
        report(f"{prefix}: 'function' must be an object")
        return found

    # Both members of 'function' follow the same rule: present and a string.
    for key in ("name", "arguments"):
        if key not in spec:
            report(f"{prefix}.function: missing '{key}'")
        elif not isinstance(spec[key], str):
            report(f"{prefix}.function: '{key}' must be a string")

    return found
322
+
323
+
324
def _validate_tool_message(msg: dict, line_num: int, prefix: str) -> list[ValidationError]:
    """Validate the 'tool_call_id' field of a tool message.

    Args:
        msg: The tool message (already known to be a dict).
        line_num: 1-based line number of the example, used to label errors.
        prefix: Location label placed in front of the error message.

    Returns:
        A single-element list when 'tool_call_id' is missing or not a string,
        otherwise an empty list.
    """
    if "tool_call_id" not in msg:
        return [ValidationError(line_num, f"{prefix}: missing 'tool_call_id'")]
    if not isinstance(msg["tool_call_id"], str):
        return [ValidationError(line_num, f"{prefix}: 'tool_call_id' must be a string")]
    return []
334
+
335
+
336
def _validate_content(
    content, line_num: int, prefix: str, role: str | None
) -> list[ValidationError]:
    """Validate the 'content' value of a message.

    Accepted values are a string, None, or an array of content parts
    (vision/multimodal); every part of an array is validated individually.

    NOTE(review): None is accepted for every role, although it is meant for
    assistant messages with tool_calls - confirm whether other roles should
    be rejected.

    Args:
        content: The decoded 'content' value.
        line_num: 1-based line number of the example, used to label errors.
        prefix: Location label placed in front of every error message.
        role: Role of the enclosing message, forwarded to the part validator.

    Returns:
        The errors found (empty when the content is acceptable).
    """
    if content is None or isinstance(content, str):
        return []

    if not isinstance(content, list):
        return [
            ValidationError(
                line_num, f"{prefix}: 'content' must be a string or array, got {type(content).__name__}"
            )
        ]

    found: list[ValidationError] = []
    for position, part in enumerate(content):
        found.extend(_validate_content_part(part, line_num, f"{prefix}.content[{position}]", role))
    return found
363
+
364
+
365
def _validate_content_part(
    part, line_num: int, prefix: str, role: str | None
) -> list[ValidationError]:
    """Validate one element of an array-valued (multimodal) 'content'.

    A part must be an object with a 'type'. Parts of type "text" need a
    string 'text'; parts of type "image_url" need an 'image_url' object with
    a 'url' and are not allowed in assistant messages. Parts of any other
    type are not checked further.

    Args:
        part: The decoded content part (any JSON value; non-objects are reported).
        line_num: 1-based line number of the example, used to label errors.
        prefix: Location label placed in front of every error message.
        role: Role of the enclosing message.

    Returns:
        The errors found (empty when the part is acceptable).
    """

    def problem(text: str) -> ValidationError:
        return ValidationError(line_num, text)

    if not isinstance(part, dict):
        return [problem(f"{prefix}: must be an object")]
    if "type" not in part:
        return [problem(f"{prefix}: missing 'type'")]

    found: list[ValidationError] = []
    kind = part["type"]

    if kind == "text":
        if "text" not in part:
            found.append(problem(f"{prefix}: missing 'text' for type='text'"))
        elif not isinstance(part["text"], str):
            found.append(problem(f"{prefix}: 'text' must be a string"))

    elif kind == "image_url":
        if "image_url" not in part:
            found.append(problem(f"{prefix}: missing 'image_url' for type='image_url'"))
        elif not isinstance(part["image_url"], dict):
            found.append(problem(f"{prefix}: 'image_url' must be an object"))
        elif "url" not in part["image_url"]:
            found.append(problem(f"{prefix}: 'image_url' missing 'url'"))

        # Images are rejected in assistant messages, in addition to any
        # structural problem reported above.
        if role == "assistant":
            found.append(problem(f"{prefix}: assistant messages cannot contain images"))

    return found
llmcomp/question/judge.py CHANGED
@@ -31,6 +31,17 @@ class JudgeMixin:
31
31
  """Validate judge-specific constraints."""
32
32
  assert len(self.paraphrases) == 1, "Judge question must have exactly one paraphrase"
33
33
  assert self.samples_per_paraphrase == 1, "Judge question must have exactly one sample per paraphrase"
34
+
35
+ # Check that the template contains {answer} placeholder
36
+ formatter = string.Formatter()
37
+ field_names = [
38
+ field_name for _, field_name, _, _ in formatter.parse(self.paraphrases[0]) if field_name is not None
39
+ ]
40
+ if "answer" not in field_names:
41
+ raise ValueError(
42
+ f"Judge template must contain {{answer}} placeholder. "
43
+ f"Got: {self.paraphrases[0]!r}"
44
+ )
34
45
 
35
46
  def _load_cache_data(self) -> list[dict]:
36
47
  """Load cache and return list of row dicts with question, answer, judge_question, judge_answer.