kiln-ai 0.8.0__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +77 -5
- kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +113 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +323 -94
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
- kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +126 -20
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +17 -6
- kiln_ai/adapters/repair/test_repair_task.py +4 -4
- kiln_ai/adapters/run_output.py +8 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_prompt_adaptors.py +8 -4
- kiln_ai/adapters/test_prompt_builders.py +190 -29
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +199 -12
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/json_schema.py +8 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/test_basemodel.py +81 -2
- kiln_ai/datamodel/test_dataset_split.py +100 -3
- kiln_ai/datamodel/test_example_models.py +25 -4
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +129 -0
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
- kiln_ai-0.11.1.dist-info/RECORD +76 -0
- kiln_ai-0.8.0.dist-info/RECORD +0 -58
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/fine_tune/dataset_formatter.py

@@ -1,11 +1,13 @@
 import json
 import tempfile
+from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, Protocol
 from uuid import uuid4
 
-from kiln_ai.
+from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
+from kiln_ai.datamodel import DatasetSplit, FinetuneDataStrategy, TaskRun
 
 
 class DatasetFormat(str, Enum):
@@ -14,6 +16,9 @@ class DatasetFormat(str, Enum):
     """OpenAI chat format with plaintext response"""
     OPENAI_CHAT_JSONL = "openai_chat_jsonl"
 
+    """OpenAI chat format with json response_format"""
+    OPENAI_CHAT_JSON_SCHEMA_JSONL = "openai_chat_json_schema_jsonl"
+
     """OpenAI chat format with tool call response"""
     OPENAI_CHAT_TOOLCALL_JSONL = "openai_chat_toolcall_jsonl"
 
@@ -25,116 +30,338 @@ class DatasetFormat(str, Enum):
         "huggingface_chat_template_toolcall_jsonl"
     )
 
+    """Vertex Gemini 1.5 format (flash and pro)"""
+    VERTEX_GEMINI_1_5 = "vertex_gemini_1_5"
+
+
+@dataclass
+class ModelTrainingData:
+    input: str
+    system_message: str
+    final_output: str
+    # These 3 are optional, and used for COT/Thinking style multi-message responses
+    thinking_instructions: str | None = None
+    thinking: str | None = None
+    thinking_final_answer_prompt: str | None = None
+
+    def supports_cot(self) -> bool:
+        return (
+            self.thinking_instructions is not None
+            and self.thinking is not None
+            and self.thinking_final_answer_prompt is not None
+        )
+
 
 class FormatGenerator(Protocol):
     """Protocol for format generators"""
 
-    def __call__(
+    def __call__(
+        self,
+        training_data: ModelTrainingData,
+    ) -> Dict[str, Any]: ...
+
+
+def build_training_data(
+    task_run: TaskRun,
+    system_message: str,
+    include_cot: bool,
+    thinking_instructions: str | None = None,
+) -> ModelTrainingData:
+    """
+    Generate data for training.
+
+    For final output, get the best task output from the task run, preferring repaired output if available.
+
+    For thinking, get the intermediate output if it exists, otherwise return None.
+    """
+    final_output = task_run.output.output
+    if task_run.repaired_output is not None:
+        final_output = task_run.repaired_output.output
+
+    thinking = None
+    thinking_final_answer_prompt = None
+    parent_task = task_run.parent_task()
+
+    if include_cot and task_run.has_thinking_training_data():
+        if not parent_task:
+            raise ValueError(
+                "TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
+            )
+
+        # Prefer reasoning to cot if both are present
+        intermediate_outputs = task_run.intermediate_outputs or {}
+        thinking = intermediate_outputs.get("reasoning") or intermediate_outputs.get(
+            "chain_of_thought"
+        )
+
+        thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT
+
+        # Always use the passed thinking instructions, but check they are present for COT
+        if not thinking_instructions:
+            raise ValueError(
+                "Thinking instructions are required when data_strategy is final_and_intermediate"
+            )
+
+    return ModelTrainingData(
+        input=task_run.input,
+        system_message=system_message,
+        final_output=final_output,
+        thinking=thinking,
+        thinking_instructions=thinking_instructions,
+        thinking_final_answer_prompt=thinking_final_answer_prompt,
+    )
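Worth pausing on: supports_cot() only reports True when all three optional thinking fields are set, and that single check gates the extra chain-of-thought messages in every generator below. A minimal sketch of that behavior, assuming kiln-ai 0.11.1 is installed (the import path comes from the file list above; field values are invented for illustration):

    # Minimal sketch: supports_cot() gates the COT message block.
    from kiln_ai.adapters.fine_tune.dataset_formatter import ModelTrainingData

    plain = ModelTrainingData(
        input="2+2?",
        system_message="You are a calculator.",
        final_output="4",
    )
    assert not plain.supports_cot()  # no thinking fields set

    cot = ModelTrainingData(
        input="2+2?",
        system_message="You are a calculator.",
        final_output="4",
        thinking_instructions="Think step by step.",
        thinking="2 plus 2 is 4.",
        thinking_final_answer_prompt="Now give the final answer.",
    )
    assert cot.supports_cot()  # all three optional fields present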
 
 
 def generate_chat_message_response(
-
+    training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate OpenAI chat format with plaintext response"""
-
-
-
-
-
-
-
+
+    messages: list[dict[str, str | None]] = [
+        {"role": "system", "content": training_data.system_message},
+        {"role": "user", "content": training_data.input},
+    ]
+
+    if training_data.supports_cot():
+        messages.extend(
+            [
+                {"role": "user", "content": training_data.thinking_instructions},
+                {"role": "assistant", "content": training_data.thinking},
+                {
+                    "role": "user",
+                    "content": training_data.thinking_final_answer_prompt,
+                },
+            ]
+        )
+
+    messages.append({"role": "assistant", "content": training_data.final_output})
+
+    return {"messages": messages}
+
+
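For reference, here is one record from the plaintext chat generator, serialized the same way dump_to_file writes lines later in this file. A sketch with invented strings:

    import json

    from kiln_ai.adapters.fine_tune.dataset_formatter import (
        ModelTrainingData,
        generate_chat_message_response,
    )

    record = generate_chat_message_response(
        ModelTrainingData(
            input="Translate 'hello' to French.",
            system_message="You are a translator.",
            final_output="bonjour",
        )
    )
    # One JSONL line: {"messages": [{"role": "system", ...}, {"role": "user", ...},
    # {"role": "assistant", "content": "bonjour"}]}
    print(json.dumps(record, ensure_ascii=False))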
+def generate_json_schema_message(
+    training_data: ModelTrainingData,
+) -> Dict[str, Any]:
+    """Generate OpenAI chat format with validated JSON response"""
+    # Load and dump to ensure it's valid JSON and goes to 1 line
+    try:
+        json_data = json.loads(training_data.final_output)
+    except json.JSONDecodeError as e:
+        raise ValueError(
+            f"Invalid JSON in JSON Schema training set: {e}\nOutput Data: {training_data.final_output}"
+        ) from e
+    json_string = json.dumps(json_data, ensure_ascii=False)
+
+    messages: list[dict[str, str | None]] = [
+        {"role": "system", "content": training_data.system_message},
+        {"role": "user", "content": training_data.input},
+    ]
+
+    if training_data.supports_cot():
+        messages.extend(
+            [
+                {"role": "user", "content": training_data.thinking_instructions},
+                {"role": "assistant", "content": training_data.thinking},
+                {
+                    "role": "user",
+                    "content": training_data.thinking_final_answer_prompt,
+                },
+            ]
+        )
+
+    messages.append({"role": "assistant", "content": json_string})
+
+    return {"messages": messages}
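The loads/dumps round-trip above is the whole normalization story: pretty-printed or oddly spaced JSON collapses to one line, and invalid JSON fails loudly before it can land in a training file. The same idea in isolation, stdlib only:

    import json

    pretty = '{\n  "city": "Zürich",\n  "population": 415367\n}'
    one_line = json.dumps(json.loads(pretty), ensure_ascii=False)
    print(one_line)  # {"city": "Zürich", "population": 415367}

    try:
        json.loads("not json")
    except json.JSONDecodeError:
        pass  # generate_json_schema_message re-raises this as a ValueError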
 
 
 def generate_chat_message_toolcall(
-
+    training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate OpenAI chat format with tool call response"""
     try:
-        arguments = json.loads(
+        arguments = json.loads(training_data.final_output)
     except json.JSONDecodeError as e:
         raise ValueError(f"Invalid JSON in for tool call: {e}") from e
 
-
-    "
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    messages: list[dict[str, Any]] = [
+        {"role": "system", "content": training_data.system_message},
+        {"role": "user", "content": training_data.input},
+    ]
+
+    if training_data.supports_cot():
+        messages.extend(
+            [
+                {"role": "user", "content": training_data.thinking_instructions},
+                {"role": "assistant", "content": training_data.thinking},
+                {
+                    "role": "user",
+                    "content": training_data.thinking_final_answer_prompt,
+                },
+            ]
+        )
+
+    messages.append(
+        {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [
+                {
+                    "id": "call_1",
+                    "type": "function",
+                    "function": {
+                        "name": "task_response",
+                        # Yes we parse then dump again. This ensures it's valid JSON, and ensures it goes to 1 line
+                        "arguments": json.dumps(arguments, ensure_ascii=False),
+                    },
+                }
+            ],
+        },
+    )
+
+    return {"messages": messages}
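The assistant turn this generator appends is easier to read flattened out. A sketch of its shape for a structured output of {"answer": 4} (the value is invented; the structure mirrors the code above):

    import json

    arguments = {"answer": 4}
    assistant_turn = {
        "role": "assistant",
        "content": None,  # payload lives in the tool call, not in content
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {
                    "name": "task_response",
                    "arguments": json.dumps(arguments, ensure_ascii=False),
                },
            }
        ],
    }
    print(assistant_turn["tool_calls"][0]["function"]["arguments"])  # {"answer": 4}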
 
 
 def generate_huggingface_chat_template(
-
+    training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate HuggingFace chat template"""
-
-
-
-
-
-
-
+
+    conversations: list[dict[str, Any]] = [
+        {"role": "system", "content": training_data.system_message},
+        {"role": "user", "content": training_data.input},
+    ]
+
+    if training_data.supports_cot():
+        conversations.extend(
+            [
+                {"role": "user", "content": training_data.thinking_instructions},
+                {"role": "assistant", "content": training_data.thinking},
+                {"role": "user", "content": training_data.thinking_final_answer_prompt},
+            ]
+        )
+
+    conversations.append({"role": "assistant", "content": training_data.final_output})
+
+    return {"conversations": conversations}
 
 
 def generate_huggingface_chat_template_toolcall(
-
+    training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate HuggingFace chat template with tool calls"""
     try:
-        arguments = json.loads(
+        arguments = json.loads(training_data.final_output)
     except json.JSONDecodeError as e:
         raise ValueError(f"Invalid JSON in for tool call: {e}") from e
 
     # See https://huggingface.co/docs/transformers/en/chat_templating
+    conversations: list[dict[str, Any]] = [
+        {"role": "system", "content": training_data.system_message},
+        {"role": "user", "content": training_data.input},
+    ]
+
+    if training_data.supports_cot():
+        conversations.extend(
+            [
+                {"role": "user", "content": training_data.thinking_instructions},
+                {"role": "assistant", "content": training_data.thinking},
+                {"role": "user", "content": training_data.thinking_final_answer_prompt},
+            ]
+        )
+
+    conversations.append(
+        {
+            "role": "assistant",
+            "tool_calls": [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": "task_response",
+                        "id": str(uuid4()).replace("-", "")[:9],
+                        "arguments": arguments,
+                    },
+                }
+            ],
+        },
+    )
+
+    return {"conversations": conversations}
+
+
+def generate_vertex_gemini_1_5(
+    training_data: ModelTrainingData,
+) -> Dict[str, Any]:
+    """Generate Vertex Gemini 1.5 format (flash and pro)"""
+    # See https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare
+
+    contents = [
+        {
+            "role": "user",
+            "parts": [
+                {
+                    "text": training_data.input,
+                }
+            ],
+        }
+    ]
+
+    if training_data.supports_cot():
+        contents.extend(
+            [
+                {
+                    "role": "user",
+                    "parts": [{"text": training_data.thinking_instructions}],
+                },
+                {"role": "model", "parts": [{"text": training_data.thinking}]},
+                {
+                    "role": "user",
+                    "parts": [{"text": training_data.thinking_final_answer_prompt}],
+                },
+            ]
+        )
+
+    contents.append(
+        {
+            "role": "model",
+            "parts": [{"text": training_data.final_output}],
+        }
+    )
+
     return {
-        "
-
-
-
-
-
-
-
-
-                        "name": "task_response",
-                        "id": str(uuid4()).replace("-", "")[:9],
-                        "arguments": arguments,
-                    },
-                }
-            ],
-        },
-    ]
+        "systemInstruction": {
+            "role": "system",
+            "parts": [
+                {
+                    "text": training_data.system_message,
+                }
+            ],
+        },
+        "contents": contents,
     }
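The Vertex format differs from the chat formats in two ways: the system prompt travels in a top-level "systemInstruction" block rather than as a message, and turns use "model" with "parts" instead of "assistant" with "content". One record, with illustrative values:

    # Shape of one Vertex Gemini 1.5 tuning record (values invented).
    record = {
        "systemInstruction": {
            "role": "system",
            "parts": [{"text": "You are a translator."}],
        },
        "contents": [
            {"role": "user", "parts": [{"text": "Translate 'hello' to French."}]},
            {"role": "model", "parts": [{"text": "bonjour"}]},
        ],
    }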
 
 
 FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
     DatasetFormat.OPENAI_CHAT_JSONL: generate_chat_message_response,
+    DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL: generate_json_schema_message,
     DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: generate_chat_message_toolcall,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: generate_huggingface_chat_template,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: generate_huggingface_chat_template_toolcall,
+    DatasetFormat.VERTEX_GEMINI_1_5: generate_vertex_gemini_1_5,
 }
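The registry keeps DatasetFormatter format-agnostic: dump_to_file just looks up a FormatGenerator and calls it once per task run. A sketch of that dispatch, assuming kiln-ai 0.11.1:

    from kiln_ai.adapters.fine_tune.dataset_formatter import (
        FORMAT_GENERATORS,
        DatasetFormat,
        ModelTrainingData,
    )

    data = ModelTrainingData(
        input="2+2?", system_message="You are a calculator.", final_output="4"
    )
    record = FORMAT_GENERATORS[DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL](data)
    assert "conversations" in record  # HuggingFace formats use a different top-level key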
 
 
 class DatasetFormatter:
     """Handles formatting of datasets into various output formats"""
 
-    def __init__(
+    def __init__(
+        self,
+        dataset: DatasetSplit,
+        system_message: str,
+        thinking_instructions: str | None = None,
+    ):
         self.dataset = dataset
         self.system_message = system_message
+        self.thinking_instructions = thinking_instructions
 
         task = dataset.parent_task()
         if task is None:
@@ -142,7 +369,11 @@ class DatasetFormatter:
         self.task = task
 
     def dump_to_file(
-        self,
+        self,
+        split_name: str,
+        format_type: DatasetFormat,
+        data_strategy: FinetuneDataStrategy,
+        path: Path | None = None,
     ) -> Path:
         """
         Format the dataset into the specified format.
@@ -154,6 +385,10 @@ class DatasetFormatter:
 
         Returns:
             Path to the generated file
+
+        Note:
+            The output is written in UTF-8 encoding with ensure_ascii=False to properly
+            support international text content while maintaining readability.
         """
         if format_type not in FORMAT_GENERATORS:
             raise ValueError(f"Unsupported format: {format_type}")
@@ -162,11 +397,13 @@ class DatasetFormatter:
 
         generator = FORMAT_GENERATORS[format_type]
 
+        include_cot = data_strategy == FinetuneDataStrategy.final_and_intermediate
+
         # Write to a temp file if no path is provided
         output_path = (
            path
            or Path(tempfile.gettempdir())
-            / f"{self.dataset.name}
+            / f"{self.dataset.name} -- split-{split_name} -- format-{format_type.value} -- {'cot' if include_cot else 'no-cot'}.jsonl"
        )
 
        runs = self.task.runs()
@@ -181,7 +418,15 @@ class DatasetFormatter:
                        f"Task run {run_id} not found. This is required by this dataset."
                    )
 
-
-
+                training_data = build_training_data(
+                    task_run=task_run,
+                    system_message=self.system_message,
+                    include_cot=include_cot,
+                    thinking_instructions=self.thinking_instructions,
+                )
+                example = generator(training_data)
+                # Allow non-ascii characters in the dataset.
+                # Better readability for non-English users. If you don't support UTF-8... you should.
+                f.write(json.dumps(example, ensure_ascii=False) + "\n")
 
        return output_path
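Put together, the new DatasetFormatter API looks like this. A sketch only: the split is a placeholder for a DatasetSplit already saved under a task, and FinetuneDataStrategy.final_only is assumed to be the non-COT counterpart of the final_and_intermediate value referenced above.

    from pathlib import Path

    from kiln_ai.adapters.fine_tune.dataset_formatter import (
        DatasetFormat,
        DatasetFormatter,
    )
    from kiln_ai.datamodel import DatasetSplit, FinetuneDataStrategy

    def export_train_split(split: DatasetSplit) -> Path:
        # split must already be saved under a parent task
        formatter = DatasetFormatter(
            dataset=split,
            system_message="You are a translator.",  # illustrative
            thinking_instructions=None,  # only needed for final_and_intermediate
        )
        # Returns e.g. /tmp/<name> -- split-train -- format-openai_chat_jsonl -- no-cot.jsonl
        return formatter.dump_to_file(
            split_name="train",
            format_type=DatasetFormat.OPENAI_CHAT_JSONL,
            data_strategy=FinetuneDataStrategy.final_only,  # assumed member name
        )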
kiln_ai/adapters/fine_tune/fireworks_finetune.py

@@ -1,3 +1,4 @@
+from typing import Tuple
 from uuid import uuid4
 
 import httpx
@@ -9,7 +10,7 @@ from kiln_ai.adapters.fine_tune.base_finetune import (
     FineTuneStatusType,
 )
 from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
-from kiln_ai.datamodel import DatasetSplit, Task
+from kiln_ai.datamodel import DatasetSplit, StructuredOutputMode, Task
 from kiln_ai.utils.config import Config
 
 
@@ -19,7 +20,7 @@ class FireworksFinetune(BaseFinetuneAdapter):
     """
 
     async def status(self) -> FineTuneStatus:
-        status = await self._status()
+        status, _ = await self._status()
         # update the datamodel if the status has changed
         if self.datamodel.latest_status != status.status:
             self.datamodel.latest_status = status.status
@@ -34,7 +35,7 @@ class FireworksFinetune(BaseFinetuneAdapter):
 
         return status
 
-    async def _status(self) -> FineTuneStatus:
+    async def _status(self) -> Tuple[FineTuneStatus, str | None]:
         try:
             api_key = Config.shared().fireworks_api_key
             account_id = Config.shared().fireworks_account_id
@@ -42,13 +43,13 @@
                 return FineTuneStatus(
                     status=FineTuneStatusType.unknown,
                     message="Fireworks API key or account ID not set",
-                )
+                ), None
             fine_tuning_job_id = self.datamodel.provider_id
             if not fine_tuning_job_id:
                 return FineTuneStatus(
                     status=FineTuneStatusType.unknown,
                     message="Fine-tuning job ID not set. Can not retrieve status.",
-                )
+                ), None
             # Fireworks uses path style IDs
             url = f"https://api.fireworks.ai/v1/{fine_tuning_job_id}"
             headers = {"Authorization": f"Bearer {api_key}"}
@@ -60,49 +61,63 @@
                 return FineTuneStatus(
                     status=FineTuneStatusType.unknown,
                     message=f"Error retrieving fine-tuning job status: [{response.status_code}] {response.text}",
-                )
+                ), None
             data = response.json()
+            model_id = data.get("outputModel")
 
             if "state" not in data:
                 return FineTuneStatus(
                     status=FineTuneStatusType.unknown,
                     message="Invalid response from Fireworks (no state).",
-                )
+                ), model_id
 
             state = data["state"]
-            if state in ["FAILED", "DELETING"]:
+            if state in ["FAILED", "DELETING", "JOB_STATE_FAILED"]:
                 return FineTuneStatus(
                     status=FineTuneStatusType.failed,
                     message="Fine-tuning job failed",
-                )
-            elif state in [
+                ), model_id
+            elif state in [
+                "CREATING",
+                "PENDING",
+                "RUNNING",
+                "JOB_STATE_VALIDATING",
+                "JOB_STATE_RUNNING",
+            ]:
                 return FineTuneStatus(
                     status=FineTuneStatusType.running,
                     message=f"Fine-tuning job is running [{state}]",
-                )
-            elif state
+                ), model_id
+            elif state in ["COMPLETED", "JOB_STATE_COMPLETED"]:
                 return FineTuneStatus(
                     status=FineTuneStatusType.completed,
                     message="Fine-tuning job completed",
-                )
+                ), model_id
             else:
                 return FineTuneStatus(
                     status=FineTuneStatusType.unknown,
                     message=f"Unknown fine-tuning job status [{state}]",
-                )
+                ), model_id
        except Exception as e:
            return FineTuneStatus(
                status=FineTuneStatusType.unknown,
                message=f"Error retrieving fine-tuning job status: {e}",
-            )
+            ), None
 
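The signature change above threads the Fireworks "outputModel" ID through every return path, so callers get the resulting model without a second request. A sketch of consuming the new tuple (constructing the FireworksFinetune is not shown):

    from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatusType

    async def report(finetune) -> None:
        # finetune: a FireworksFinetune; _status() now returns
        # (FineTuneStatus, model_id), with model_id None until Fireworks reports one.
        status, model_id = await finetune._status()
        if status.status == FineTuneStatusType.completed and model_id:
            print(f"Tuned model ready: {model_id}")
        else:
            print(f"[{status.status}] {status.message}")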
     async def _start(self, dataset: DatasetSplit) -> None:
         task = self.datamodel.parent_task()
         if not task:
             raise ValueError("Task is required to start a fine-tune")
 
+        format = DatasetFormat.OPENAI_CHAT_JSONL
+        if task.output_json_schema:
+            # This formatter will check it's valid JSON, and normalize the output (chat format just uses exact string).
+            format = DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL
+            # Fireworks doesn't support function calls or json schema, so we'll use json mode at call time
+            self.datamodel.structured_output_mode = StructuredOutputMode.json_mode
+
         train_file_id = await self.generate_and_upload_jsonl(
-            dataset, self.datamodel.train_split_name, task
+            dataset, self.datamodel.train_split_name, task, format
         )
 
         api_key = Config.shared().fireworks_api_key
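The format selection added to _start is a two-way branch, isolated below. Note the side effect in the real code: picking the JSON schema format also flips the datamodel to StructuredOutputMode.json_mode, since Fireworks supports neither tool calls nor schema-constrained decoding at serving time.

    from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat

    def choose_format(output_json_schema: str | None) -> DatasetFormat:
        # Mirrors the branch in _start: schema-bearing tasks train on
        # normalized JSON, everything else on plaintext chat.
        if output_json_schema:
            return DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL
        return DatasetFormat.OPENAI_CHAT_JSONL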
@@ -110,9 +125,7 @@ class FireworksFinetune(BaseFinetuneAdapter):
         if not api_key or not account_id:
             raise ValueError("Fireworks API key or account ID not set")
 
-        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/
-        # Model ID != fine tune ID on Fireworks. Model is the result of the tune job.
-        model_id = str(uuid4())
+        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/supervisedFineTuningJobs"
         # Limit the display name to 60 characters
         display_name = (
             f"Kiln AI fine-tuning [ID:{self.datamodel.id}][name:{self.datamodel.name}]"[
@@ -120,11 +133,9 @@ class FireworksFinetune(BaseFinetuneAdapter):
             ]
         )
         payload = {
-            "modelId": model_id,
             "dataset": f"accounts/{account_id}/datasets/{train_file_id}",
             "displayName": display_name,
             "baseModel": self.datamodel.base_model_id,
-            "conversation": {},
         }
         hyperparameters = self.create_payload_parameters(self.datamodel.parameters)
         payload.update(hyperparameters)
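With "modelId" and "conversation" dropped, the v2 supervisedFineTuningJobs payload carries only the dataset, display name, base model, and hyperparameters; Fireworks now assigns the output model itself. Roughly, with all values invented for illustration:

    payload = {
        "dataset": "accounts/my-account/datasets/<train_file_id>",
        "displayName": "Kiln AI fine-tuning [ID:123][name:my-tune]",
        "baseModel": "accounts/fireworks/models/<base-model-id>",
    }
    # Hyperparameters are merged in afterwards, as in the code above;
    # the key name here is illustrative, not a documented Fireworks field.
    payload.update({"epochs": 1})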
@@ -148,21 +159,22 @@
         # model ID is the model that results from the fine-tune job
         job_id = data["name"]
         self.datamodel.provider_id = job_id
-
-
-
-
+
+        # Fireworks has 2 different fine tuning endpoints, and depending which you use, the URLs change
+        self.datamodel.properties["endpoint_version"] = "v2"
+
         if self.datamodel.path:
             self.datamodel.save_to_file()
 
     async def generate_and_upload_jsonl(
-        self, dataset: DatasetSplit, split_name: str, task: Task
+        self, dataset: DatasetSplit, split_name: str, task: Task, format: DatasetFormat
     ) -> str:
-        formatter = DatasetFormatter(
-
-
-
-
+        formatter = DatasetFormatter(
+            dataset=dataset,
+            system_message=self.datamodel.system_message,
+            thinking_instructions=self.datamodel.thinking_instructions,
+        )
+        path = formatter.dump_to_file(split_name, format, self.datamodel.data_strategy)
 
         # First call creates the dataset
         api_key = Config.shared().fireworks_api_key
@@ -276,7 +288,10 @@
         if not api_key or not account_id:
             raise ValueError("Fireworks API key or account ID not set")
 
-
+        # Model ID != fine tune ID on Fireworks. Model is the result of the tune job. Call status to get it.
+        status, model_id = await self._status()
+        if status.status != FineTuneStatusType.completed:
+            return False
         if not model_id or not isinstance(model_id, str):
             return False
 