kiln-ai 0.8.1__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It shows the changes between package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (57)
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +77 -5
  3. kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  7. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  8. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  9. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  10. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
  11. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +113 -21
  12. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  13. kiln_ai/adapters/ml_model_list.py +323 -94
  14. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  15. kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
  16. kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
  17. kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
  18. kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
  19. kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
  20. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
  21. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
  22. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
  23. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
  24. kiln_ai/adapters/parsers/__init__.py +10 -0
  25. kiln_ai/adapters/parsers/base_parser.py +12 -0
  26. kiln_ai/adapters/parsers/json_parser.py +37 -0
  27. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  28. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  29. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  30. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  31. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  32. kiln_ai/adapters/prompt_builders.py +126 -20
  33. kiln_ai/adapters/provider_tools.py +91 -36
  34. kiln_ai/adapters/repair/repair_task.py +17 -6
  35. kiln_ai/adapters/repair/test_repair_task.py +4 -4
  36. kiln_ai/adapters/run_output.py +8 -0
  37. kiln_ai/adapters/test_adapter_registry.py +177 -0
  38. kiln_ai/adapters/test_generate_docs.py +69 -0
  39. kiln_ai/adapters/test_prompt_adaptors.py +8 -4
  40. kiln_ai/adapters/test_prompt_builders.py +190 -29
  41. kiln_ai/adapters/test_provider_tools.py +268 -46
  42. kiln_ai/datamodel/__init__.py +193 -12
  43. kiln_ai/datamodel/basemodel.py +31 -11
  44. kiln_ai/datamodel/json_schema.py +8 -3
  45. kiln_ai/datamodel/model_cache.py +8 -3
  46. kiln_ai/datamodel/test_basemodel.py +81 -2
  47. kiln_ai/datamodel/test_dataset_split.py +100 -3
  48. kiln_ai/datamodel/test_example_models.py +25 -4
  49. kiln_ai/datamodel/test_model_cache.py +24 -0
  50. kiln_ai/datamodel/test_model_perf.py +125 -0
  51. kiln_ai/datamodel/test_models.py +129 -0
  52. kiln_ai/utils/exhaustive_error.py +6 -0
  53. {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
  54. kiln_ai-0.11.1.dist-info/RECORD +76 -0
  55. kiln_ai-0.8.1.dist-info/RECORD +0 -58
  56. {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
  57. {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/fine_tune/dataset_formatter.py

@@ -1,11 +1,13 @@
 import json
 import tempfile
+from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, Protocol
 from uuid import uuid4
 
-from kiln_ai.datamodel import DatasetSplit, TaskRun
+from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
+from kiln_ai.datamodel import DatasetSplit, FinetuneDataStrategy, TaskRun
 
 
 class DatasetFormat(str, Enum):
@@ -14,6 +16,9 @@ class DatasetFormat(str, Enum):
     """OpenAI chat format with plaintext response"""
     OPENAI_CHAT_JSONL = "openai_chat_jsonl"
 
+    """OpenAI chat format with json response_format"""
+    OPENAI_CHAT_JSON_SCHEMA_JSONL = "openai_chat_json_schema_jsonl"
+
     """OpenAI chat format with tool call response"""
     OPENAI_CHAT_TOOLCALL_JSONL = "openai_chat_toolcall_jsonl"
 
@@ -25,116 +30,338 @@ class DatasetFormat(str, Enum):
         "huggingface_chat_template_toolcall_jsonl"
     )
 
+    """Vertex Gemini 1.5 format (flash and pro)"""
+    VERTEX_GEMINI_1_5 = "vertex_gemini_1_5"
+
+
+@dataclass
+class ModelTrainingData:
+    input: str
+    system_message: str
+    final_output: str
+    # These 3 are optional, and used for COT/Thinking style multi-message responses
+    thinking_instructions: str | None = None
+    thinking: str | None = None
+    thinking_final_answer_prompt: str | None = None
+
+    def supports_cot(self) -> bool:
+        return (
+            self.thinking_instructions is not None
+            and self.thinking is not None
+            and self.thinking_final_answer_prompt is not None
+        )
+
 
 class FormatGenerator(Protocol):
     """Protocol for format generators"""
 
-    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...
+    def __call__(
+        self,
+        training_data: ModelTrainingData,
+    ) -> Dict[str, Any]: ...
+
+
+def build_training_data(
+    task_run: TaskRun,
+    system_message: str,
+    include_cot: bool,
+    thinking_instructions: str | None = None,
+) -> ModelTrainingData:
+    """
+    Generate data for training.
+
+    For final output, get the best task output from the task run, preferring repaired output if available.
+
+    For thinking, get the intermediate output if it exists, otherwise return None.
+    """
+    final_output = task_run.output.output
+    if task_run.repaired_output is not None:
+        final_output = task_run.repaired_output.output
+
+    thinking = None
+    thinking_final_answer_prompt = None
+    parent_task = task_run.parent_task()
+
+    if include_cot and task_run.has_thinking_training_data():
+        if not parent_task:
+            raise ValueError(
+                "TaskRuns used for training require a parent Task to build chain-of-thought prompts. Train without COT, or save this TaskRun to a parent Task."
+            )
+
+        # Prefer reasoning to cot if both are present
+        intermediate_outputs = task_run.intermediate_outputs or {}
+        thinking = intermediate_outputs.get("reasoning") or intermediate_outputs.get(
+            "chain_of_thought"
+        )
+
+        thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT
+
+        # Always use the passed thinking instructions, but check they are present for COT
+        if not thinking_instructions:
+            raise ValueError(
+                "Thinking instructions are required when data_strategy is final_and_intermediate"
+            )
+
+    return ModelTrainingData(
+        input=task_run.input,
+        system_message=system_message,
+        final_output=final_output,
+        thinking=thinking,
+        thinking_instructions=thinking_instructions,
+        thinking_final_answer_prompt=thinking_final_answer_prompt,
+    )
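
As a quick illustration of the contract here (a hand-built sketch, not part of the diff; values are hypothetical): supports_cot() only reports True when all three optional thinking fields are populated, which is what build_training_data guarantees when COT data is requested and available.

    data = ModelTrainingData(
        input="What is 2+2?",
        system_message="You are a math tutor.",
        final_output="4",
        thinking_instructions="Think step by step before answering.",
        thinking="2+2 is simple addition; the sum is 4.",
        thinking_final_answer_prompt=COT_FINAL_ANSWER_PROMPT,
    )
    assert data.supports_cot()  # False if any of the three thinking fields is None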
 
 
 def generate_chat_message_response(
-    task_run: TaskRun, system_message: str
+    training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate OpenAI chat format with plaintext response"""
-    return {
-        "messages": [
-            {"role": "system", "content": system_message},
-            {"role": "user", "content": task_run.input},
-            {"role": "assistant", "content": task_run.output.output},
-        ]
-    }
+
+    messages: list[dict[str, str | None]] = [
+        {"role": "system", "content": training_data.system_message},
+        {"role": "user", "content": training_data.input},
+    ]
+
+    if training_data.supports_cot():
+        messages.extend(
+            [
+                {"role": "user", "content": training_data.thinking_instructions},
+                {"role": "assistant", "content": training_data.thinking},
+                {
+                    "role": "user",
+                    "content": training_data.thinking_final_answer_prompt,
+                },
+            ]
+        )
+
+    messages.append({"role": "assistant", "content": training_data.final_output})
+
+    return {"messages": messages}
+
+
+def generate_json_schema_message(
+    training_data: ModelTrainingData,
+) -> Dict[str, Any]:
+    """Generate OpenAI chat format with validated JSON response"""
+    # Load and dump to ensure it's valid JSON and goes to 1 line
+    try:
+        json_data = json.loads(training_data.final_output)
+    except json.JSONDecodeError as e:
+        raise ValueError(
+            f"Invalid JSON in JSON Schema training set: {e}\nOutput Data: {training_data.final_output}"
+        ) from e
+    json_string = json.dumps(json_data, ensure_ascii=False)
+
+    messages: list[dict[str, str | None]] = [
+        {"role": "system", "content": training_data.system_message},
+        {"role": "user", "content": training_data.input},
+    ]
+
+    if training_data.supports_cot():
+        messages.extend(
+            [
+                {"role": "user", "content": training_data.thinking_instructions},
+                {"role": "assistant", "content": training_data.thinking},
+                {
+                    "role": "user",
+                    "content": training_data.thinking_final_answer_prompt,
+                },
+            ]
+        )
+
+    messages.append({"role": "assistant", "content": json_string})
+
+    return {"messages": messages}
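
The parse-then-dump step above is the entire normalization trick; a minimal standalone sketch of its effect (hypothetical values):

    import json

    # Multi-line, padded model output collapses to canonical one-line JSON,
    # and ensure_ascii=False keeps non-ASCII text readable in the JSONL.
    raw = '{\n  "city": "München",\n  "population": 1488202\n}'
    normalized = json.dumps(json.loads(raw), ensure_ascii=False)
    assert normalized == '{"city": "München", "population": 1488202}'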
 
 
 def generate_chat_message_toolcall(
-    task_run: TaskRun, system_message: str
+    training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate OpenAI chat format with tool call response"""
     try:
-        arguments = json.loads(task_run.output.output)
+        arguments = json.loads(training_data.final_output)
     except json.JSONDecodeError as e:
         raise ValueError(f"Invalid JSON for tool call: {e}") from e
 
-    return {
-        "messages": [
-            {"role": "system", "content": system_message},
-            {"role": "user", "content": task_run.input},
-            {
-                "role": "assistant",
-                "content": None,
-                "tool_calls": [
-                    {
-                        "id": "call_1",
-                        "type": "function",
-                        "function": {
-                            "name": "task_response",
-                            # Yes we parse then dump again. This ensures it's valid JSON, and ensures it goes to 1 line
-                            "arguments": json.dumps(arguments),
-                        },
-                    }
-                ],
-            },
-        ]
-    }
+    messages: list[dict[str, Any]] = [
+        {"role": "system", "content": training_data.system_message},
+        {"role": "user", "content": training_data.input},
+    ]
+
+    if training_data.supports_cot():
+        messages.extend(
+            [
+                {"role": "user", "content": training_data.thinking_instructions},
+                {"role": "assistant", "content": training_data.thinking},
+                {
+                    "role": "user",
+                    "content": training_data.thinking_final_answer_prompt,
+                },
+            ]
+        )
+
+    messages.append(
+        {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [
+                {
+                    "id": "call_1",
+                    "type": "function",
+                    "function": {
+                        "name": "task_response",
+                        # Yes we parse then dump again. This ensures it's valid JSON, and ensures it goes to 1 line
+                        "arguments": json.dumps(arguments, ensure_ascii=False),
+                    },
+                }
+            ],
+        },
+    )
+
+    return {"messages": messages}
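
For reference, one record emitted by this generator would look roughly like the following, shown as a Python literal with hypothetical content; note that "arguments" is a JSON-encoded string, not a nested object:

    record = {
        "messages": [
            {"role": "system", "content": "You are a math tutor."},
            {"role": "user", "content": "What is 2+2?"},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "arguments": '{"answer": 4}',  # a string, per the OpenAI format
                        },
                    }
                ],
            },
        ]
    }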
 
 
 def generate_huggingface_chat_template(
-    task_run: TaskRun, system_message: str
+    training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate HuggingFace chat template"""
-    return {
-        "conversations": [
-            {"role": "system", "content": system_message},
-            {"role": "user", "content": task_run.input},
-            {"role": "assistant", "content": task_run.output.output},
-        ]
-    }
+
+    conversations: list[dict[str, Any]] = [
+        {"role": "system", "content": training_data.system_message},
+        {"role": "user", "content": training_data.input},
+    ]
+
+    if training_data.supports_cot():
+        conversations.extend(
+            [
+                {"role": "user", "content": training_data.thinking_instructions},
+                {"role": "assistant", "content": training_data.thinking},
+                {"role": "user", "content": training_data.thinking_final_answer_prompt},
+            ]
+        )
+
+    conversations.append({"role": "assistant", "content": training_data.final_output})
+
+    return {"conversations": conversations}
 
 
 def generate_huggingface_chat_template_toolcall(
-    task_run: TaskRun, system_message: str
+    training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate HuggingFace chat template with tool calls"""
     try:
-        arguments = json.loads(task_run.output.output)
+        arguments = json.loads(training_data.final_output)
     except json.JSONDecodeError as e:
         raise ValueError(f"Invalid JSON for tool call: {e}") from e
 
     # See https://huggingface.co/docs/transformers/en/chat_templating
+    conversations: list[dict[str, Any]] = [
+        {"role": "system", "content": training_data.system_message},
+        {"role": "user", "content": training_data.input},
+    ]
+
+    if training_data.supports_cot():
+        conversations.extend(
+            [
+                {"role": "user", "content": training_data.thinking_instructions},
+                {"role": "assistant", "content": training_data.thinking},
+                {"role": "user", "content": training_data.thinking_final_answer_prompt},
+            ]
+        )
+
+    conversations.append(
+        {
+            "role": "assistant",
+            "tool_calls": [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": "task_response",
+                        "id": str(uuid4()).replace("-", "")[:9],
+                        "arguments": arguments,
+                    },
+                }
+            ],
+        },
+    )
+
+    return {"conversations": conversations}
+
+
+def generate_vertex_gemini_1_5(
+    training_data: ModelTrainingData,
+) -> Dict[str, Any]:
+    """Generate Vertex Gemini 1.5 format (flash and pro)"""
+    # See https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare
+
+    contents = [
+        {
+            "role": "user",
+            "parts": [
+                {
+                    "text": training_data.input,
+                }
+            ],
+        }
+    ]
+
+    if training_data.supports_cot():
+        contents.extend(
+            [
+                {
+                    "role": "user",
+                    "parts": [{"text": training_data.thinking_instructions}],
+                },
+                {"role": "model", "parts": [{"text": training_data.thinking}]},
+                {
+                    "role": "user",
+                    "parts": [{"text": training_data.thinking_final_answer_prompt}],
+                },
+            ]
+        )
+
+    contents.append(
+        {
+            "role": "model",
+            "parts": [{"text": training_data.final_output}],
+        }
+    )
+
     return {
-        "conversations": [
-            {"role": "system", "content": system_message},
-            {"role": "user", "content": task_run.input},
-            {
-                "role": "assistant",
-                "tool_calls": [
-                    {
-                        "type": "function",
-                        "function": {
-                            "name": "task_response",
-                            "id": str(uuid4()).replace("-", "")[:9],
-                            "arguments": arguments,
-                        },
-                    }
-                ],
-            },
-        ]
+        "systemInstruction": {
+            "role": "system",
+            "parts": [
+                {
+                    "text": training_data.system_message,
+                }
+            ],
+        },
+        "contents": contents,
     }
 
 
 FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
     DatasetFormat.OPENAI_CHAT_JSONL: generate_chat_message_response,
+    DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL: generate_json_schema_message,
     DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: generate_chat_message_toolcall,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: generate_huggingface_chat_template,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: generate_huggingface_chat_template_toolcall,
+    DatasetFormat.VERTEX_GEMINI_1_5: generate_vertex_gemini_1_5,
 }
 
 
 class DatasetFormatter:
     """Handles formatting of datasets into various output formats"""
 
-    def __init__(self, dataset: DatasetSplit, system_message: str):
+    def __init__(
+        self,
+        dataset: DatasetSplit,
+        system_message: str,
+        thinking_instructions: str | None = None,
+    ):
         self.dataset = dataset
         self.system_message = system_message
+        self.thinking_instructions = thinking_instructions
 
         task = dataset.parent_task()
         if task is None:
@@ -142,7 +369,11 @@ class DatasetFormatter:
         self.task = task
 
     def dump_to_file(
-        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
+        self,
+        split_name: str,
+        format_type: DatasetFormat,
+        data_strategy: FinetuneDataStrategy,
+        path: Path | None = None,
     ) -> Path:
         """
         Format the dataset into the specified format.
@@ -154,6 +385,10 @@
 
         Returns:
             Path to the generated file
+
+        Note:
+            The output is written in UTF-8 encoding with ensure_ascii=False to properly
+            support international text content while maintaining readability.
         """
         if format_type not in FORMAT_GENERATORS:
             raise ValueError(f"Unsupported format: {format_type}")
@@ -162,11 +397,13 @@
 
         generator = FORMAT_GENERATORS[format_type]
 
+        include_cot = data_strategy == FinetuneDataStrategy.final_and_intermediate
+
         # Write to a temp file if no path is provided
         output_path = (
            path
            or Path(tempfile.gettempdir())
-            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
+            / f"{self.dataset.name} -- split-{split_name} -- format-{format_type.value} -- {'cot' if include_cot else 'no-cot'}.jsonl"
         )
 
         runs = self.task.runs()
@@ -181,7 +418,15 @@
                         f"Task run {run_id} not found. This is required by this dataset."
                     )
 
-                example = generator(task_run, self.system_message)
-                f.write(json.dumps(example) + "\n")
+                training_data = build_training_data(
+                    task_run=task_run,
+                    system_message=self.system_message,
+                    include_cot=include_cot,
+                    thinking_instructions=self.thinking_instructions,
+                )
+                example = generator(training_data)
+                # Allow non-ascii characters in the dataset.
+                # Better readability for non-English users. If you don't support UTF-8... you should.
+                f.write(json.dumps(example, ensure_ascii=False) + "\n")
 
         return output_path
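
Taken together, these dataset_formatter.py changes alter the calling convention. A sketch of the new call path follows; the class and enum names come from this diff, while dump_train_split and the message strings are hypothetical:

    from pathlib import Path

    from kiln_ai.adapters.fine_tune.dataset_formatter import (
        DatasetFormat,
        DatasetFormatter,
    )
    from kiln_ai.datamodel import DatasetSplit, FinetuneDataStrategy


    def dump_train_split(dataset: DatasetSplit) -> Path:
        # The split must be attached to a parent Task, or the formatter errors out
        formatter = DatasetFormatter(
            dataset=dataset,
            system_message="You are a math tutor.",
            thinking_instructions="Think step by step before answering.",
        )
        # data_strategy is now a required argument; final_and_intermediate emits
        # the extra thinking messages when a run has thinking training data
        return formatter.dump_to_file(
            split_name="train",
            format_type=DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL,
            data_strategy=FinetuneDataStrategy.final_and_intermediate,
        )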
kiln_ai/adapters/fine_tune/fireworks_finetune.py

@@ -1,3 +1,4 @@
+from typing import Tuple
 from uuid import uuid4
 
 import httpx
@@ -9,7 +10,7 @@ from kiln_ai.adapters.fine_tune.base_finetune import (
     FineTuneStatusType,
 )
 from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
-from kiln_ai.datamodel import DatasetSplit, Task
+from kiln_ai.datamodel import DatasetSplit, StructuredOutputMode, Task
 from kiln_ai.utils.config import Config
 
 
@@ -19,7 +20,7 @@ class FireworksFinetune(BaseFinetuneAdapter):
     """
 
     async def status(self) -> FineTuneStatus:
-        status = await self._status()
+        status, _ = await self._status()
         # update the datamodel if the status has changed
         if self.datamodel.latest_status != status.status:
             self.datamodel.latest_status = status.status
@@ -34,7 +35,7 @@
 
         return status
 
-    async def _status(self) -> FineTuneStatus:
+    async def _status(self) -> Tuple[FineTuneStatus, str | None]:
         try:
             api_key = Config.shared().fireworks_api_key
             account_id = Config.shared().fireworks_account_id
@@ -42,13 +43,13 @@
                 return FineTuneStatus(
                     status=FineTuneStatusType.unknown,
                     message="Fireworks API key or account ID not set",
-                )
+                ), None
             fine_tuning_job_id = self.datamodel.provider_id
             if not fine_tuning_job_id:
                 return FineTuneStatus(
                     status=FineTuneStatusType.unknown,
                     message="Fine-tuning job ID not set. Can not retrieve status.",
-                )
+                ), None
             # Fireworks uses path style IDs
             url = f"https://api.fireworks.ai/v1/{fine_tuning_job_id}"
             headers = {"Authorization": f"Bearer {api_key}"}
@@ -60,49 +61,63 @@
                 return FineTuneStatus(
                     status=FineTuneStatusType.unknown,
                     message=f"Error retrieving fine-tuning job status: [{response.status_code}] {response.text}",
-                )
+                ), None
             data = response.json()
+            model_id = data.get("outputModel")
 
             if "state" not in data:
                 return FineTuneStatus(
                     status=FineTuneStatusType.unknown,
                     message="Invalid response from Fireworks (no state).",
-                )
+                ), model_id
 
             state = data["state"]
-            if state in ["FAILED", "DELETING"]:
+            if state in ["FAILED", "DELETING", "JOB_STATE_FAILED"]:
                 return FineTuneStatus(
                     status=FineTuneStatusType.failed,
                     message="Fine-tuning job failed",
-                )
-            elif state in ["CREATING", "PENDING", "RUNNING"]:
+                ), model_id
+            elif state in [
+                "CREATING",
+                "PENDING",
+                "RUNNING",
+                "JOB_STATE_VALIDATING",
+                "JOB_STATE_RUNNING",
+            ]:
                 return FineTuneStatus(
                     status=FineTuneStatusType.running,
                     message=f"Fine-tuning job is running [{state}]",
-                )
-            elif state == "COMPLETED":
+                ), model_id
+            elif state in ["COMPLETED", "JOB_STATE_COMPLETED"]:
                 return FineTuneStatus(
                     status=FineTuneStatusType.completed,
                     message="Fine-tuning job completed",
-                )
+                ), model_id
             else:
                 return FineTuneStatus(
                     status=FineTuneStatusType.unknown,
                     message=f"Unknown fine-tuning job status [{state}]",
-                )
+                ), model_id
         except Exception as e:
             return FineTuneStatus(
                 status=FineTuneStatusType.unknown,
                 message=f"Error retrieving fine-tuning job status: {e}",
-            )
+            ), None
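
The private _status() now returns a (status, model_id) tuple, where model_id is Fireworks' outputModel; the public status() unpacks and discards it. A sketch of a consumer (wait_for_model_id is hypothetical; the other names are from this diff):

    from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatusType


    async def wait_for_model_id(tune: "FireworksFinetune") -> str | None:
        # The second tuple element is Fireworks' "outputModel", which is only
        # meaningful once the job has completed
        status, model_id = await tune._status()
        if status.status == FineTuneStatusType.completed:
            return model_id
        return None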
 
     async def _start(self, dataset: DatasetSplit) -> None:
         task = self.datamodel.parent_task()
         if not task:
             raise ValueError("Task is required to start a fine-tune")
 
+        format = DatasetFormat.OPENAI_CHAT_JSONL
+        if task.output_json_schema:
+            # This formatter will check it's valid JSON, and normalize the output (chat format just uses exact string).
+            format = DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL
+            # Fireworks doesn't support function calls or json schema, so we'll use json mode at call time
+            self.datamodel.structured_output_mode = StructuredOutputMode.json_mode
+
         train_file_id = await self.generate_and_upload_jsonl(
-            dataset, self.datamodel.train_split_name, task
+            dataset, self.datamodel.train_split_name, task, format
         )
 
         api_key = Config.shared().fireworks_api_key
@@ -110,9 +125,7 @@
         if not api_key or not account_id:
             raise ValueError("Fireworks API key or account ID not set")
 
-        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/fineTuningJobs"
-        # Model ID != fine tune ID on Fireworks. Model is the result of the tune job.
-        model_id = str(uuid4())
+        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/supervisedFineTuningJobs"
         # Limit the display name to 60 characters
         display_name = (
             f"Kiln AI fine-tuning [ID:{self.datamodel.id}][name:{self.datamodel.name}]"[
@@ -120,11 +133,9 @@
             ]
         )
         payload = {
-            "modelId": model_id,
             "dataset": f"accounts/{account_id}/datasets/{train_file_id}",
             "displayName": display_name,
             "baseModel": self.datamodel.base_model_id,
-            "conversation": {},
         }
         hyperparameters = self.create_payload_parameters(self.datamodel.parameters)
         payload.update(hyperparameters)
@@ -148,21 +159,22 @@
         # model ID is the model that results from the fine-tune job
         job_id = data["name"]
         self.datamodel.provider_id = job_id
-        # Keep track of the expected model ID before it's deployed as a property. We move it to fine_tune_model_id after deployment.
-        self.datamodel.properties["undeployed_model_id"] = (
-            f"accounts/{account_id}/models/{model_id}"
-        )
+
+        # Fireworks has 2 different fine tuning endpoints, and depending which you use, the URLs change
+        self.datamodel.properties["endpoint_version"] = "v2"
+
         if self.datamodel.path:
             self.datamodel.save_to_file()
 
     async def generate_and_upload_jsonl(
-        self, dataset: DatasetSplit, split_name: str, task: Task
+        self, dataset: DatasetSplit, split_name: str, task: Task, format: DatasetFormat
     ) -> str:
-        formatter = DatasetFormatter(dataset, self.datamodel.system_message)
-        # OpenAI compatible: https://docs.fireworks.ai/fine-tuning/fine-tuning-models#conversation
-        # Note: Fireworks does not support tool calls (confirmed by Fireworks team) so we'll use json mode
-        format = DatasetFormat.OPENAI_CHAT_JSONL
-        path = formatter.dump_to_file(split_name, format)
+        formatter = DatasetFormatter(
+            dataset=dataset,
+            system_message=self.datamodel.system_message,
+            thinking_instructions=self.datamodel.thinking_instructions,
+        )
+        path = formatter.dump_to_file(split_name, format, self.datamodel.data_strategy)
 
         # First call creates the dataset
         api_key = Config.shared().fireworks_api_key
@@ -276,7 +288,10 @@
         if not api_key or not account_id:
             raise ValueError("Fireworks API key or account ID not set")
 
-        model_id = self.datamodel.properties.get("undeployed_model_id")
+        # Model ID != fine tune ID on Fireworks. Model is the result of the tune job. Call status to get it.
+        status, model_id = await self._status()
+        if status.status != FineTuneStatusType.completed:
+            return False
         if not model_id or not isinstance(model_id, str):
             return False
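
With this last hunk, deployment reads the model id from the live job status instead of a stored "undeployed_model_id" property, so it is gated on job completion. A caller-side polling loop is the natural companion; the loop below is illustrative, while status() and the status types come from this diff (assuming FineTuneStatus imports from base_finetune as in this file):

    import asyncio

    from kiln_ai.adapters.fine_tune.base_finetune import (
        FineTuneStatus,
        FineTuneStatusType,
    )


    async def poll_until_done(tune: "FireworksFinetune", interval_s: float = 30.0) -> FineTuneStatus:
        while True:
            # status() also persists status changes to the datamodel
            status = await tune.status()
            if status.status in (FineTuneStatusType.completed, FineTuneStatusType.failed):
                return status
            await asyncio.sleep(interval_s)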