kiln-ai 0.16.0__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to their respective public registries. It is provided for informational purposes only.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +22 -44
- kiln_ai/adapters/chat/__init__.py +8 -0
- kiln_ai/adapters/chat/chat_formatter.py +233 -0
- kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
- kiln_ai/adapters/data_gen/data_gen_prompts.py +121 -36
- kiln_ai/adapters/data_gen/data_gen_task.py +49 -36
- kiln_ai/adapters/data_gen/test_data_gen_task.py +330 -40
- kiln_ai/adapters/eval/base_eval.py +7 -6
- kiln_ai/adapters/eval/eval_runner.py +9 -2
- kiln_ai/adapters/eval/g_eval.py +40 -17
- kiln_ai/adapters/eval/test_base_eval.py +174 -17
- kiln_ai/adapters/eval/test_eval_runner.py +3 -0
- kiln_ai/adapters/eval/test_g_eval.py +116 -5
- kiln_ai/adapters/fine_tune/base_finetune.py +3 -8
- kiln_ai/adapters/fine_tune/dataset_formatter.py +135 -273
- kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +287 -353
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +6 -11
- kiln_ai/adapters/fine_tune/together_finetune.py +13 -2
- kiln_ai/adapters/ml_model_list.py +370 -84
- kiln_ai/adapters/model_adapters/base_adapter.py +73 -26
- kiln_ai/adapters/model_adapters/litellm_adapter.py +88 -97
- kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
- kiln_ai/adapters/model_adapters/test_base_adapter.py +235 -61
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +104 -21
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +44 -12
- kiln_ai/adapters/parsers/parser_registry.py +0 -2
- kiln_ai/adapters/parsers/r1_parser.py +0 -1
- kiln_ai/adapters/prompt_builders.py +0 -16
- kiln_ai/adapters/provider_tools.py +27 -9
- kiln_ai/adapters/remote_config.py +66 -0
- kiln_ai/adapters/repair/repair_task.py +1 -6
- kiln_ai/adapters/repair/test_repair_task.py +24 -3
- kiln_ai/adapters/test_adapter_registry.py +88 -28
- kiln_ai/adapters/test_ml_model_list.py +176 -0
- kiln_ai/adapters/test_prompt_adaptors.py +17 -7
- kiln_ai/adapters/test_prompt_builders.py +3 -16
- kiln_ai/adapters/test_provider_tools.py +69 -20
- kiln_ai/adapters/test_remote_config.py +100 -0
- kiln_ai/datamodel/__init__.py +0 -2
- kiln_ai/datamodel/datamodel_enums.py +38 -13
- kiln_ai/datamodel/eval.py +32 -0
- kiln_ai/datamodel/finetune.py +12 -8
- kiln_ai/datamodel/task.py +68 -7
- kiln_ai/datamodel/task_output.py +0 -2
- kiln_ai/datamodel/task_run.py +0 -2
- kiln_ai/datamodel/test_basemodel.py +2 -1
- kiln_ai/datamodel/test_dataset_split.py +0 -8
- kiln_ai/datamodel/test_eval_model.py +146 -4
- kiln_ai/datamodel/test_models.py +33 -10
- kiln_ai/datamodel/test_task.py +168 -2
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/dataset_import.py +1 -1
- kiln_ai/utils/logging.py +166 -0
- kiln_ai/utils/test_config.py +23 -0
- kiln_ai/utils/test_dataset_import.py +30 -0
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/METADATA +2 -2
- kiln_ai-0.18.0.dist-info/RECORD +115 -0
- kiln_ai-0.16.0.dist-info/RECORD +0 -108
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/licenses/LICENSE.txt +0 -0
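
Beyond the large `ml_model_list.py` refresh and the new `remote_config.py` and `logging.py` utilities, the most structural change is in `kiln_ai/adapters/fine_tune/dataset_formatter.py` (+135 −273): the `ModelTrainingData` dataclass is removed, and every dataset format generator now consumes a `list[ChatMessage]` built by the new `kiln_ai.adapters.chat` package. A minimal sketch of the new flow is below; the `ChatMessage` fields are inferred from how the diff uses `msg.role` and `msg.content`, not confirmed against the package source.

```python
# Minimal sketch of the new formatting flow, assuming a ChatMessage with
# role/content fields as used in the diff below.
from dataclasses import dataclass


@dataclass
class ChatMessage:
    role: str  # "system" | "user" | "assistant"
    content: str | None


def generate_chat_message_list(
    training_chat: list[ChatMessage],
) -> list[dict[str, str | None]]:
    # Same shape as the new helper in the diff: flatten ChatMessages into
    # OpenAI-style {"role", "content"} dicts, rejecting unknown roles.
    messages: list[dict[str, str | None]] = []
    for msg in training_chat:
        if msg.role not in ["user", "assistant", "system"]:
            raise ValueError(f"Unsupported role for OpenAI chat format: {msg.role}")
        messages.append({"role": msg.role, "content": msg.content})
    return messages


chat = [
    ChatMessage("system", "You are a helpful assistant."),
    ChatMessage("user", "What is 2 + 2?"),
    ChatMessage("assistant", "4"),
]
print({"messages": generate_chat_message_list(chat)})
```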
kiln_ai/adapters/fine_tune/dataset_formatter.py

@@ -1,14 +1,14 @@
 import json
 import tempfile
-from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, Protocol
 from uuid import uuid4
 
-from kiln_ai.adapters.
-from kiln_ai.datamodel import DatasetSplit,
-from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES
+from kiln_ai.adapters.chat.chat_formatter import ChatMessage, get_chat_formatter
+from kiln_ai.datamodel import DatasetSplit, TaskRun
+from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES, ChatStrategy
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
 
 class DatasetFormat(str, Enum):
@@ -35,45 +35,23 @@ class DatasetFormat(str, Enum):
     VERTEX_GEMINI = "vertex_gemini"
 
 
-@dataclass
-class ModelTrainingData:
-    input: str
-    system_message: str
-    final_output: str
-    # These 3 are optional, and used for COT/Thinking style multi-message responses
-    thinking_instructions: str | None = None
-    thinking: str | None = None
-    thinking_final_answer_prompt: str | None = None
-    thinking_r1_style: bool = False
-
-    def supports_cot(self) -> bool:
-        if self.thinking_r1_style:
-            raise ValueError("R1 style does not support COT")
-
-        return (
-            self.thinking_instructions is not None
-            and self.thinking is not None
-            and self.thinking_final_answer_prompt is not None
-        )
-
-
 class FormatGenerator(Protocol):
     """Protocol for format generators"""
 
     def __call__(
         self,
-        training_data: ModelTrainingData,
+        training_chat: list[ChatMessage],
     ) -> Dict[str, Any]: ...
 
 
-def build_training_data(
+def build_training_chat(
     task_run: TaskRun,
     system_message: str,
-    data_strategy: FinetuneDataStrategy,
+    data_strategy: ChatStrategy,
     thinking_instructions: str | None = None,
-) -> ModelTrainingData:
+) -> list[ChatMessage]:
     """
-    Generate
+    Generate chat message list for training.
 
     For final output, get the best task output from the task run, preferring repaired output if available.
 
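The `FormatGenerator` protocol in the hunk above now takes the chat list rather than a `ModelTrainingData` instance. Since `typing.Protocol` matches structurally, any callable with the new signature qualifies; a hypothetical custom generator (not part of kiln_ai) would look like this:

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class ChatMessage:  # assumed shape, as in the earlier sketch
    role: str
    content: str | None


# Hypothetical generator, not a kiln_ai function: any callable with this
# signature satisfies the new FormatGenerator protocol structurally.
def plaintext_generator(training_chat: list[ChatMessage]) -> Dict[str, Any]:
    return {"text": "\n".join(m.content or "" for m in training_chat)}
```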
@@ -84,52 +62,53 @@ def build_training_data(
         final_output = task_run.repaired_output.output
 
     thinking = None
-    thinking_final_answer_prompt = None
-    thinking_r1_style = False
-    parent_task = task_run.parent_task()
 
- …
+    chat_formatter = get_chat_formatter(
+        data_strategy,
+        system_message,
+        task_run.input,
+        thinking_instructions,
+    )
+    # First turn already has it's content (user message)
+    chat_formatter.next_turn(None)
+
+    match data_strategy:
+        case ChatStrategy.single_turn:
+            chat_formatter.next_turn(final_output)
+        case ChatStrategy.two_message_cot:
+            thinking = get_thinking_data(task_run)
+            chat_formatter.next_turn(thinking)
+            chat_formatter.next_turn(final_output)
+        case ChatStrategy.two_message_cot_legacy:
+            thinking = get_thinking_data(task_run)
+            chat_formatter.next_turn(thinking)
+            chat_formatter.next_turn(final_output)
+        case ChatStrategy.single_turn_r1_thinking:
             if thinking_instructions:
                 raise ValueError(
                     "Thinking instructions are not supported when fine-tuning thinking models (R1, QwQ, etc). Please remove the thinking instructions."
                 )
-        thinking_r1_style = True
-    elif (
-        data_strategy == FinetuneDataStrategy.final_and_intermediate
-        and task_run.has_thinking_training_data()
-    ):
-        if not parent_task:
-            raise ValueError(
-                "TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
-            )
 
- …
+            thinking = get_thinking_data(task_run)
+            response_msg = serialize_r1_style_message(thinking, final_output)
+            chat_formatter.next_turn(response_msg)
+        case _:
+            raise_exhaustive_enum_error(data_strategy)
 
- …
-        thinking_final_answer_prompt=thinking_final_answer_prompt,
-        thinking_r1_style=thinking_r1_style,
-    )
+    return chat_formatter.messages
+
+
+def get_thinking_data(task_run: TaskRun) -> str:
+    """
+    Raises an error if thinking data is not present.
+    """
+    thinking = task_run.thinking_training_data()
+    if thinking is None:
+        raise ValueError(
+            "Thinking data is required when fine-tuning thinking models. Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
+        )
+
+    return thinking
 
 
 def serialize_r1_style_message(thinking: str | None, final_output: str):
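`build_training_chat` now dispatches on `ChatStrategy` with `match`/`case` and ends in `raise_exhaustive_enum_error`, so a strategy added to the enum but not handled fails loudly. A minimal sketch of that helper, assuming the conventional implementation; the actual code lives in `kiln_ai/utils/exhaustive_error.py` and may differ:

```python
from typing import NoReturn


def raise_exhaustive_enum_error(value: NoReturn) -> NoReturn:
    # Annotating the parameter as NoReturn makes mypy/pyright report an
    # error if any enum member can still reach the `case _:` arm.
    raise ValueError(f"Unhandled enum value: {value}")
```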
@@ -141,125 +120,78 @@ def serialize_r1_style_message(thinking: str | None, final_output: str):
     return f"<think>\n{thinking}\n</think>\n\n{final_output}"
 
 
-def
- …
-) ->
-    """Generate OpenAI chat
+def generate_chat_message_list(
+    training_chat: list[ChatMessage],
+) -> list[dict[str, str | None]]:
+    """Generate OpenAI chat list. Not the full OpenAI body, just the list of messages."""
 
-    messages: list[dict[str, str | None]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
+    messages: list[dict[str, str | None]] = []
 
- …
-            {
-                "role": "assistant",
-                "content": serialize_r1_style_message(
-                    thinking=training_data.thinking,
-                    final_output=training_data.final_output,
-                ),
-            }
-        ]
-    )
+    for msg in training_chat:
+        if msg.role not in ["user", "assistant", "system"]:
+            raise ValueError(f"Unsupported role for OpenAI chat format: {msg.role}")
 
- …
-            {"role": "assistant", "content": training_data.thinking},
-            {
-                "role": "user",
-                "content": training_data.thinking_final_answer_prompt,
-            },
-        ]
+        messages.append(
+            {
+                "role": msg.role,
+                "content": msg.content,
+            }
         )
 
-    messages
+    return messages
+
+
+def generate_chat_message_response(
+    training_chat: list[ChatMessage],
+) -> Dict[str, Any]:
+    """Generate OpenAI chat format with plaintext response"""
+
+    messages: list[dict[str, str | None]] = generate_chat_message_list(training_chat)
 
     return {"messages": messages}
 
 
-def
- …
-)
- …
-    # Load and dump to ensure it's valid JSON and goes to 1 line
+def last_message_structured_content(training_chat: list[ChatMessage]) -> Dict:
+    """Get the structured content of the last message"""
+    if len(training_chat) < 1:
+        raise ValueError("Training chat is empty")
     try:
-        json_data = json.loads(
+        json_data = json.loads(training_chat[-1].content or "")
     except json.JSONDecodeError as e:
         raise ValueError(
-            f"
+            f"Last message is not JSON (structured), and this format expects structured data: {e}"
         ) from e
-    json_string = json.dumps(json_data, ensure_ascii=False)
-
-    messages: list[dict[str, str | None]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
-
-    if training_data.thinking_r1_style:
-        messages.extend(
-            [
-                {
-                    "role": "assistant",
-                    "content": serialize_r1_style_message(
-                        thinking=training_data.thinking,
-                        final_output=training_data.final_output,
-                    ),
-                }
-            ]
-        )
- …
-        messages.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
+    if not isinstance(json_data, dict):
+        raise ValueError(
+            "Last message is not a JSON Dictionary (structured data), and this format expects structured_data."
+        )
+    return json_data
 
-    messages.append({"role": "assistant", "content": json_string})
 
- …
+def generate_json_schema_message(
+    training_chat: list[ChatMessage],
+) -> Dict[str, Any]:
+    """Generate OpenAI chat format with validated JSON response"""
+    # Load and dump to ensure it's valid JSON and goes to 1 line
+    last_msg_data = last_message_structured_content(training_chat)
+
+    # re-format the json string in the last message for consistency
+    json_string = json.dumps(last_msg_data, ensure_ascii=False)
+    training_chat[-1].content = json_string
+
+    return generate_chat_message_response(training_chat)
 
 
 def generate_chat_message_toolcall(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate OpenAI chat format with tool call response"""
-    try:
-        arguments = json.loads(training_data.final_output)
-    except json.JSONDecodeError as e:
-        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
+    last_message_data = last_message_structured_content(training_chat)
 
-    messages: list[dict[str, Any]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
+    messages: list[dict[str, Any]] = generate_chat_message_list(training_chat)
 
-    if training_data.thinking_r1_style:
-        raise ValueError(
-            "R1 style thinking is not supported for tool call downloads. Please use a different training strategy."
-        )
-    elif training_data.supports_cot():
-        messages.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
+    # remove the last message, we're going to replace it with a toolcall
+    messages = messages[:-1]
 
     messages.append(
         {
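`last_message_structured_content` centralizes the JSON validation that was previously duplicated per format, and `generate_json_schema_message` re-serializes the parsed dict so each JSONL row carries a one-line, consistently formatted payload. The normalization step is plain stdlib `json`:

```python
import json

raw = '{\n  "answer": 4,\n  "unit": "apples"\n}'  # pretty-printed model output
data = json.loads(raw)  # raises json.JSONDecodeError if not valid JSON
if not isinstance(data, dict):
    raise ValueError("expected a JSON object")
print(json.dumps(data, ensure_ascii=False))
# {"answer": 4, "unit": "apples"}  <- single line, non-ASCII preserved
```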
@@ -271,8 +203,7 @@ def generate_chat_message_toolcall(
             "type": "function",
             "function": {
                 "name": "task_response",
- …
-                "arguments": json.dumps(arguments, ensure_ascii=False),
+                "arguments": json.dumps(last_message_data, ensure_ascii=False),
             },
         }
     ],
@@ -283,76 +214,26 @@ def generate_chat_message_toolcall(
 
 
 def generate_huggingface_chat_template(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate HuggingFace chat template"""
 
-    conversations: list[dict[str, Any]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
-
-    if training_data.thinking_r1_style:
-        conversations.extend(
-            [
-                {
-                    "role": "assistant",
-                    "content": serialize_r1_style_message(
-                        thinking=training_data.thinking,
-                        final_output=training_data.final_output,
-                    ),
-                }
-            ]
-        )
-        return {"conversations": conversations}
-
-    if training_data.supports_cot():
-        conversations.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
-
-    conversations.append({"role": "assistant", "content": training_data.final_output})
+    conversations: list[dict[str, Any]] = generate_chat_message_list(training_chat)
 
     return {"conversations": conversations}
 
 
 def generate_huggingface_chat_template_toolcall(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate HuggingFace chat template with tool calls"""
-    try:
-        arguments = json.loads(training_data.final_output)
-    except json.JSONDecodeError as e:
-        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
+    last_message_data = last_message_structured_content(training_chat)
 
     # See https://huggingface.co/docs/transformers/en/chat_templating
-    conversations: list[dict[str, Any]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
+    conversations: list[dict[str, Any]] = generate_chat_message_list(training_chat)
 
-    if training_data.thinking_r1_style:
-        raise ValueError(
-            "R1 style thinking is not supported for tool call downloads. Please use a different training strategy."
-        )
-    elif training_data.supports_cot():
-        conversations.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
+    # remove the last message, we're going to replace it with a toolcall
+    conversations = conversations[:-1]
 
     conversations.append(
         {
@@ -363,7 +244,7 @@ def generate_huggingface_chat_template_toolcall(
             "function": {
                 "name": "task_response",
                 "id": str(uuid4()).replace("-", "")[:9],
-                "arguments":
+                "arguments": last_message_data,
             },
         }
     ],
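Both tool-call generators now build the full message list via `generate_chat_message_list`, drop the final assistant message, and append a tool call whose `arguments` carry the structured output. The row below is illustrative only: the diff confirms the `type`/`function`/`name`/`arguments` fields, while the surrounding keys are assumed from the common OpenAI fine-tuning layout.

```python
import json

# Illustrative tool-call training row; keys not shown in the diff are assumptions.
last_message_data = {"answer": 4}
row = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": "task_response",
                        "arguments": json.dumps(last_message_data, ensure_ascii=False),
                    },
                }
            ],
        },
    ]
}
print(json.dumps(row, ensure_ascii=False))
```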
@@ -373,60 +254,41 @@ def generate_huggingface_chat_template_toolcall(
     return {"conversations": conversations}
 
 
+VERTEX_GEMINI_ROLE_MAP = {
+    "system": "system",
+    "user": "user",
+    "assistant": "model",
+}
+
+
 def generate_vertex_gemini(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
-    """Generate Vertex Gemini
+    """Generate Vertex Gemini format (flash and pro)"""
     # See https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare
 
- …
+    # System message get's it's own entry in top level UI
+    system_instruction = training_chat[0].content
+
+    messages: list[Dict[str, Any]] = []
+    for msg in training_chat[1:]:
+        messages.append(
             {
-            "
+                "role": VERTEX_GEMINI_ROLE_MAP[msg.role],
+                "parts": [{"text": msg.content}],
             }
- …
-        {
-            "role": "
+        )
+
+    return {
+        "systemInstruction": {
+            "role": "system",
             "parts": [
                 {
-                    "text":
+                    "text": system_instruction,
                 }
             ],
-        }
-
-
-    if training_data.thinking_r1_style:
-        raise ValueError(
-            "R1 style thinking is not supported for Vertex Gemini. Please use a different training strategy."
-        )
-    elif training_data.supports_cot():
-        contents.extend(
-            [
-                {
-                    "role": "user",
-                    "parts": [{"text": training_data.thinking_instructions}],
-                },
-                {"role": "model", "parts": [{"text": training_data.thinking}]},
-                {
-                    "role": "user",
-                    "parts": [{"text": training_data.thinking_final_answer_prompt}],
-                },
-            ]
-        )
-
-    contents.append(
-        {
-            "role": "model",
-            "parts": [{"text": training_data.final_output}],
-        }
-    )
-
-    return {
-        "systemInstruction": system_instruction,
-        "contents": contents,
+        },
+        "contents": messages,
     }
 
 
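The rewritten `generate_vertex_gemini` hoists the system message into a top-level `systemInstruction` and maps the remaining turns through `VERTEX_GEMINI_ROLE_MAP`, so `assistant` becomes `model` as the Gemini tuning format requires. A worked example of the resulting payload, using (role, text) tuples in place of `ChatMessage` to stay self-contained:

```python
# Mirrors the new generate_vertex_gemini logic shown above.
VERTEX_GEMINI_ROLE_MAP = {"system": "system", "user": "user", "assistant": "model"}

chat = [
    ("system", "You are a helpful assistant."),
    ("user", "What is 2 + 2?"),
    ("assistant", "4"),
]
payload = {
    "systemInstruction": {
        "role": "system",
        "parts": [{"text": chat[0][1]}],  # system message, hoisted out of contents
    },
    "contents": [
        {"role": VERTEX_GEMINI_ROLE_MAP[role], "parts": [{"text": text}]}
        for role, text in chat[1:]
    ],
}
# payload["contents"][1] == {"role": "model", "parts": [{"text": "4"}]}
print(payload)
```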
@@ -462,7 +324,7 @@ class DatasetFormatter:
         self,
         split_name: str,
         format_type: DatasetFormat,
-        data_strategy: FinetuneDataStrategy,
+        data_strategy: ChatStrategy,
         path: Path | None = None,
     ) -> Path:
         """
@@ -508,13 +370,13 @@ class DatasetFormatter:
                         f"Task run {run_id} not found. This is required by this dataset."
                     )
 
-                training_data = build_training_data(
+                training_chat = build_training_chat(
                     task_run=task_run,
                     system_message=self.system_message,
                     data_strategy=data_strategy,
                     thinking_instructions=self.thinking_instructions,
                 )
-                example = generator(training_data)
+                example = generator(training_chat)
                 # Allow non-ascii characters in the dataset.
                 # Better readability for non-English users. If you don't support UTF-8... you should.
                 f.write(json.dumps(example, ensure_ascii=False) + "\n")
kiln_ai/adapters/fine_tune/test_base_finetune.py

@@ -4,13 +4,13 @@ import pytest
 
 from kiln_ai.adapters.fine_tune.base_finetune import (
     BaseFinetuneAdapter,
-    FinetuneDataStrategy,
     FineTuneParameter,
     FineTuneStatus,
     FineTuneStatusType,
 )
 from kiln_ai.datamodel import DatasetSplit, Task
 from kiln_ai.datamodel import Finetune as FinetuneModel
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 
 
 class MockFinetune(BaseFinetuneAdapter):
@@ -162,7 +162,7 @@ async def test_create_and_start_success(mock_dataset):
         train_split_name="train",
         parameters={"epochs": 10},  # Required parameter
         system_message="Test system message",
-        data_strategy=
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )
 
@@ -176,7 +176,7 @@ async def test_create_and_start_success(mock_dataset):
     assert datamodel.parameters == {"epochs": 10}
     assert datamodel.system_message == "Test system message"
     assert datamodel.path.exists()
-    assert datamodel.data_strategy ==
+    assert datamodel.data_strategy == ChatStrategy.single_turn
     assert datamodel.thinking_instructions is None
 
 
@@ -192,7 +192,7 @@ async def test_create_and_start_with_all_params(mock_dataset):
         description="Custom Description",
         validation_split_name="test",
         system_message="Test system message",
-        data_strategy=
+        data_strategy=ChatStrategy.two_message_cot,
         thinking_instructions="Custom thinking instructions",
     )
 
@@ -202,7 +202,7 @@ async def test_create_and_start_with_all_params(mock_dataset):
     assert datamodel.parameters == {"epochs": 10, "learning_rate": 0.001}
     assert datamodel.system_message == "Test system message"
     assert adapter.datamodel == datamodel
-    assert datamodel.data_strategy ==
+    assert datamodel.data_strategy == ChatStrategy.two_message_cot
     assert datamodel.thinking_instructions == "Custom thinking instructions"
 
     # load the datamodel from the file, confirm it's saved
@@ -221,7 +221,7 @@ async def test_create_and_start_invalid_parameters(mock_dataset):
         parameters={"learning_rate": 0.001},  # Missing required 'epochs'
         system_message="Test system message",
         thinking_instructions=None,
-        data_strategy=
+        data_strategy=ChatStrategy.single_turn,
     )
 
 
@@ -240,7 +240,7 @@ async def test_create_and_start_no_parent_task():
         train_split_name="train",
         parameters={"epochs": 10},
         system_message="Test system message",
-        data_strategy=
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )
 
@@ -263,7 +263,7 @@ async def test_create_and_start_no_parent_task_path():
         train_split_name="train",
         parameters={"epochs": 10},
         system_message="Test system message",
-        data_strategy=
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )
 
@@ -282,7 +282,7 @@ async def test_create_and_start_invalid_train_split(mock_dataset):
         train_split_name="invalid_train",  # Invalid train split
         parameters={"epochs": 10},
         system_message="Test system message",
-        data_strategy=
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )
 
@@ -302,6 +302,6 @@ async def test_create_and_start_invalid_validation_split(mock_dataset):
         validation_split_name="invalid_test",  # Invalid validation split
         parameters={"epochs": 10},
         system_message="Test system message",
-        data_strategy=
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )