kiln-ai 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +22 -44
- kiln_ai/adapters/chat/__init__.py +8 -0
- kiln_ai/adapters/chat/chat_formatter.py +234 -0
- kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
- kiln_ai/adapters/eval/base_eval.py +8 -6
- kiln_ai/adapters/eval/eval_runner.py +9 -65
- kiln_ai/adapters/eval/g_eval.py +26 -8
- kiln_ai/adapters/eval/test_base_eval.py +166 -15
- kiln_ai/adapters/eval/test_eval_runner.py +3 -0
- kiln_ai/adapters/eval/test_g_eval.py +1 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
- kiln_ai/adapters/fine_tune/dataset_formatter.py +153 -197
- kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +402 -211
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
- kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
- kiln_ai/adapters/ml_model_list.py +556 -45
- kiln_ai/adapters/model_adapters/base_adapter.py +100 -35
- kiln_ai/adapters/model_adapters/litellm_adapter.py +116 -100
- kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
- kiln_ai/adapters/model_adapters/test_base_adapter.py +299 -52
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +121 -22
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +44 -2
- kiln_ai/adapters/model_adapters/test_structured_output.py +48 -18
- kiln_ai/adapters/parsers/base_parser.py +0 -3
- kiln_ai/adapters/parsers/parser_registry.py +5 -3
- kiln_ai/adapters/parsers/r1_parser.py +17 -2
- kiln_ai/adapters/parsers/request_formatters.py +40 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
- kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
- kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
- kiln_ai/adapters/prompt_builders.py +14 -17
- kiln_ai/adapters/provider_tools.py +39 -4
- kiln_ai/adapters/repair/test_repair_task.py +27 -5
- kiln_ai/adapters/test_adapter_registry.py +88 -28
- kiln_ai/adapters/test_ml_model_list.py +158 -0
- kiln_ai/adapters/test_prompt_adaptors.py +17 -3
- kiln_ai/adapters/test_prompt_builders.py +27 -19
- kiln_ai/adapters/test_provider_tools.py +130 -12
- kiln_ai/datamodel/__init__.py +2 -2
- kiln_ai/datamodel/datamodel_enums.py +43 -4
- kiln_ai/datamodel/dataset_filters.py +69 -1
- kiln_ai/datamodel/dataset_split.py +4 -0
- kiln_ai/datamodel/eval.py +8 -0
- kiln_ai/datamodel/finetune.py +13 -7
- kiln_ai/datamodel/prompt_id.py +1 -0
- kiln_ai/datamodel/task.py +68 -7
- kiln_ai/datamodel/task_output.py +1 -1
- kiln_ai/datamodel/task_run.py +39 -7
- kiln_ai/datamodel/test_basemodel.py +5 -8
- kiln_ai/datamodel/test_dataset_filters.py +82 -0
- kiln_ai/datamodel/test_dataset_split.py +2 -8
- kiln_ai/datamodel/test_example_models.py +54 -0
- kiln_ai/datamodel/test_models.py +80 -9
- kiln_ai/datamodel/test_task.py +168 -2
- kiln_ai/utils/async_job_runner.py +106 -0
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/dataset_import.py +81 -19
- kiln_ai/utils/logging.py +165 -0
- kiln_ai/utils/test_async_job_runner.py +199 -0
- kiln_ai/utils/test_config.py +23 -0
- kiln_ai/utils/test_dataset_import.py +272 -10
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
- kiln_ai-0.17.0.dist-info/RECORD +113 -0
- kiln_ai-0.15.0.dist-info/RECORD +0 -104
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
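The heart of this release is a refactor of kiln_ai/adapters/fine_tune/dataset_formatter.py, excerpted below: the ModelTrainingData dataclass is removed, build_training_data becomes build_training_chat returning a list[ChatMessage], and every format generator now consumes that list. A minimal sketch of the new list helper's behavior follows; the ChatMessage stand-in is an assumption, since only its role and content fields are visible in this diff:

from dataclasses import dataclass


@dataclass
class ChatMessage:
    # Stand-in for kiln_ai.adapters.chat.chat_formatter.ChatMessage; only the
    # role and content fields are visible in this diff.
    role: str  # "system" | "user" | "assistant"
    content: str | None


def generate_chat_message_list(
    training_chat: list[ChatMessage],
) -> list[dict[str, str | None]]:
    # Same logic as the helper added in this diff: map each ChatMessage to an
    # OpenAI-style dict, rejecting unsupported roles.
    messages: list[dict[str, str | None]] = []
    for msg in training_chat:
        if msg.role not in ["user", "assistant", "system"]:
            raise ValueError(f"Unsupported role for OpenAI chat format: {msg.role}")
        messages.append({"role": msg.role, "content": msg.content})
    return messages


chat = [
    ChatMessage("system", "You are a helpful assistant."),
    ChatMessage("user", "Classify: 'great product'"),
    ChatMessage("assistant", "positive"),
]
print(generate_chat_message_list(chat))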
kiln_ai/adapters/fine_tune/dataset_formatter.py

@@ -6,8 +6,13 @@ from pathlib import Path
 from typing import Any, Dict, Protocol
 from uuid import uuid4

-from kiln_ai.adapters.
-
+from kiln_ai.adapters.chat.chat_formatter import (
+    ChatMessage,
+    get_chat_formatter,
+)
+from kiln_ai.datamodel import DatasetSplit, TaskRun
+from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES, ChatStrategy
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error


 class DatasetFormat(str, Enum):
@@ -34,41 +39,23 @@ class DatasetFormat(str, Enum):
     VERTEX_GEMINI = "vertex_gemini"


-@dataclass
-class ModelTrainingData:
-    input: str
-    system_message: str
-    final_output: str
-    # These 3 are optional, and used for COT/Thinking style multi-message responses
-    thinking_instructions: str | None = None
-    thinking: str | None = None
-    thinking_final_answer_prompt: str | None = None
-
-    def supports_cot(self) -> bool:
-        return (
-            self.thinking_instructions is not None
-            and self.thinking is not None
-            and self.thinking_final_answer_prompt is not None
-        )
-
-
 class FormatGenerator(Protocol):
     """Protocol for format generators"""

     def __call__(
         self,
-
+        training_chat: list[ChatMessage],
     ) -> Dict[str, Any]: ...


-def build_training_data(
+def build_training_chat(
     task_run: TaskRun,
     system_message: str,
-
+    data_strategy: ChatStrategy,
     thinking_instructions: str | None = None,
-) ->
+) -> list[ChatMessage]:
     """
-    Generate
+    Generate chat message list for training.

     For final output, get the best task output from the task run, preferring repaired output if available.

@@ -79,126 +66,136 @@ def build_training_data(
         final_output = task_run.repaired_output.output

     thinking = None
-    thinking_final_answer_prompt = None
-    parent_task = task_run.parent_task()
-
-    if include_cot and task_run.has_thinking_training_data():
-        if not parent_task:
-            raise ValueError(
-                "TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
-            )
-
-        # Prefer reasoning to cot if both are present
-        intermediate_outputs = task_run.intermediate_outputs or {}
-        thinking = intermediate_outputs.get("reasoning") or intermediate_outputs.get(
-            "chain_of_thought"
-        )

-
-
-
-
-
-                "Thinking instructions are required when data_strategy is final_and_intermediate"
-            )
-
-    return ModelTrainingData(
-        input=task_run.input,
-        system_message=system_message,
-        final_output=final_output,
-        thinking=thinking,
-        thinking_instructions=thinking_instructions,
-        thinking_final_answer_prompt=thinking_final_answer_prompt,
+    chat_formatter = get_chat_formatter(
+        data_strategy,
+        system_message,
+        task_run.input,
+        thinking_instructions,
     )
+    # First turn already has it's content (user message)
+    chat_formatter.next_turn(None)
+
+    match data_strategy:
+        case ChatStrategy.single_turn:
+            chat_formatter.next_turn(final_output)
+        case ChatStrategy.two_message_cot:
+            thinking = get_thinking_data(task_run)
+            chat_formatter.next_turn(thinking)
+            chat_formatter.next_turn(final_output)
+        case ChatStrategy.two_message_cot_legacy:
+            thinking = get_thinking_data(task_run)
+            chat_formatter.next_turn(thinking)
+            chat_formatter.next_turn(final_output)
+        case ChatStrategy.single_turn_r1_thinking:
+            if thinking_instructions:
+                raise ValueError(
+                    "Thinking instructions are not supported when fine-tuning thinking models (R1, QwQ, etc). Please remove the thinking instructions."
+                )
+
+            thinking = get_thinking_data(task_run)
+            response_msg = serialize_r1_style_message(thinking, final_output)
+            chat_formatter.next_turn(response_msg)
+        case _:
+            raise_exhaustive_enum_error(data_strategy)
+
+    return chat_formatter.messages
+
+
+def get_thinking_data(task_run: TaskRun) -> str:
+    """
+    Raises an error if thinking data is not present.
+    """
+    thinking = task_run.thinking_training_data()
+    if thinking is None:
+        raise ValueError(
+            "Thinking data is required when fine-tuning thinking models. Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
+        )
+
+    return thinking
+
+
+def serialize_r1_style_message(thinking: str | None, final_output: str):
+    if thinking is None or len(thinking.strip()) == 0:
+        raise ValueError(
+            "Thinking data is required when fine-tuning thinking models (R1, QwQ, etc). Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
+        )
+
+    return f"<think>\n{thinking}\n</think>\n\n{final_output}"
+
+
+def generate_chat_message_list(
+    training_chat: list[ChatMessage],
+) -> list[dict[str, str | None]]:
+    """Generate OpenAI chat list. Not the full OpenAI body, just the list of messages."""
+
+    messages: list[dict[str, str | None]] = []
+
+    for msg in training_chat:
+        if msg.role not in ["user", "assistant", "system"]:
+            raise ValueError(f"Unsupported role for OpenAI chat format: {msg.role}")
+
+        messages.append(
+            {
+                "role": msg.role,
+                "content": msg.content,
+            }
+        )
+
+    return messages


 def generate_chat_message_response(
-
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate OpenAI chat format with plaintext response"""

-    messages: list[dict[str, str | None]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
+    messages: list[dict[str, str | None]] = generate_chat_message_list(training_chat)

-
-        messages.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
+    return {"messages": messages}

-    messages.append({"role": "assistant", "content": training_data.final_output})

-
+def last_message_structured_content(training_chat: list[ChatMessage]) -> Dict:
+    """Get the structured content of the last message"""
+    if len(training_chat) < 1:
+        raise ValueError("Training chat is empty")
+    try:
+        json_data = json.loads(training_chat[-1].content or "")
+    except json.JSONDecodeError as e:
+        raise ValueError(
+            f"Last message is not JSON (structured), and this format expects structured data: {e}"
+        )
+    if not isinstance(json_data, dict):
+        raise ValueError(
+            "Last message is not a JSON Dictionary (structured data), and this format expects structured_data."
+        )
+    return json_data


 def generate_json_schema_message(
-
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate OpenAI chat format with validated JSON response"""
     # Load and dump to ensure it's valid JSON and goes to 1 line
-    try:
-        json_data = json.loads(training_data.final_output)
-    except json.JSONDecodeError as e:
-        raise ValueError(
-            f"Invalid JSON in JSON Schema training set: {e}\nOutput Data: {training_data.final_output}"
-        ) from e
-    json_string = json.dumps(json_data, ensure_ascii=False)
-
-    messages: list[dict[str, str | None]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
-
-    if training_data.supports_cot():
-        messages.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
+    last_msg_data = last_message_structured_content(training_chat)

-
+    # re-format the json string in the last message for consistency
+    json_string = json.dumps(last_msg_data, ensure_ascii=False)
+    training_chat[-1].content = json_string

-    return
+    return generate_chat_message_response(training_chat)


 def generate_chat_message_toolcall(
-
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate OpenAI chat format with tool call response"""
-
-
-
-
-
-    messages
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
-
-    if training_data.supports_cot():
-        messages.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
+    last_message_data = last_message_structured_content(training_chat)
+
+    messages: list[dict[str, Any]] = generate_chat_message_list(training_chat)
+
+    # remove the last message, we're going to replace it with a toolcall
+    messages = messages[:-1]

     messages.append(
         {
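The single_turn_r1_thinking strategy collapses reasoning and answer into one assistant turn. A quick illustration of serialize_r1_style_message as added above, with illustrative values and an abridged error message:

def serialize_r1_style_message(thinking: str | None, final_output: str) -> str:
    # Same checks and f-string as the function added in this diff.
    if thinking is None or len(thinking.strip()) == 0:
        raise ValueError("Thinking data is required when fine-tuning thinking models")
    return f"<think>\n{thinking}\n</think>\n\n{final_output}"


print(serialize_r1_style_message("Check the sentiment words first.", "positive"))
# <think>
# Check the sentiment words first.
# </think>
#
# positive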
@@ -210,8 +207,7 @@ def generate_chat_message_toolcall
                     "type": "function",
                     "function": {
                         "name": "task_response",
-
-                        "arguments": json.dumps(arguments, ensure_ascii=False),
+                        "arguments": json.dumps(last_message_data, ensure_ascii=False),
                     },
                 }
             ],
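For reference, the assistant message these lines build looks roughly like the sketch below. The enclosing tool_calls key is not visible in this hunk, so it is an assumption consistent with the OpenAI chat format; the argument values are illustrative:

import json

last_message_data = {"rating": "positive"}  # illustrative structured output

tool_call_message = {
    "role": "assistant",
    "tool_calls": [  # key assumed; only the inner structure appears in the hunk
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "arguments": json.dumps(last_message_data, ensure_ascii=False),
            },
        }
    ],
}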
@@ -222,52 +218,26 @@ def generate_chat_message_toolcall(


 def generate_huggingface_chat_template(
-
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate HuggingFace chat template"""

-    conversations: list[dict[str, Any]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
-
-    if training_data.supports_cot():
-        conversations.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {"role": "user", "content": training_data.thinking_final_answer_prompt},
-            ]
-        )
-
-    conversations.append({"role": "assistant", "content": training_data.final_output})
+    conversations: list[dict[str, Any]] = generate_chat_message_list(training_chat)

     return {"conversations": conversations}


 def generate_huggingface_chat_template_toolcall(
-
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate HuggingFace chat template with tool calls"""
-    try:
-        arguments = json.loads(training_data.final_output)
-    except json.JSONDecodeError as e:
-        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
+    last_message_data = last_message_structured_content(training_chat)

     # See https://huggingface.co/docs/transformers/en/chat_templating
-    conversations: list[dict[str, Any]] = [
-
-
-    ]
-
-    if training_data.supports_cot():
-        conversations.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {"role": "user", "content": training_data.thinking_final_answer_prompt},
-            ]
-        )
+    conversations: list[dict[str, Any]] = generate_chat_message_list(training_chat)
+
+    # remove the last message, we're going to replace it with a toolcall
+    conversations = conversations[:-1]

     conversations.append(
         {
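One resulting JSONL row for the plain HuggingFace template, assuming a simple single-turn example; the shape comes from the return statement above, the content is illustrative:

import json

row = {
    "conversations": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Classify: 'great product'"},
        {"role": "assistant", "content": "positive"},
    ]
}
print(json.dumps(row, ensure_ascii=False))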
@@ -278,7 +248,7 @@ def generate_huggingface_chat_template_toolcall(
                 "function": {
                     "name": "task_response",
                     "id": str(uuid4()).replace("-", "")[:9],
-                    "arguments":
+                    "arguments": last_message_data,
                 },
             }
         ],
@@ -288,55 +258,41 @@ def generate_huggingface_chat_template_toolcall(
     return {"conversations": conversations}


+VERTEX_GEMINI_ROLE_MAP = {
+    "system": "system",
+    "user": "user",
+    "assistant": "model",
+}
+
+
 def generate_vertex_gemini(
-
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
-    """Generate Vertex Gemini
+    """Generate Vertex Gemini format (flash and pro)"""
     # See https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare

-
-
-            "role": "user",
-            "parts": [
-                {
-                    "text": training_data.input,
-                }
-            ],
-        }
-    ]
+    # System message get's it's own entry in top level UI
+    system_instruction = training_chat[0].content

-
-
-
-
-
-
-
-        {"role": "model", "parts": [{"text": training_data.thinking}]},
-        {
-            "role": "user",
-            "parts": [{"text": training_data.thinking_final_answer_prompt}],
-        },
-    ]
+    messages: list[Dict[str, Any]] = []
+    for msg in training_chat[1:]:
+        messages.append(
+            {
+                "role": VERTEX_GEMINI_ROLE_MAP[msg.role],
+                "parts": [{"text": msg.content}],
+            }
         )

-    contents.append(
-        {
-            "role": "model",
-            "parts": [{"text": training_data.final_output}],
-        }
-    )
-
     return {
         "systemInstruction": {
             "role": "system",
             "parts": [
                 {
-                    "text":
+                    "text": system_instruction,
                 }
             ],
         },
-        "contents":
+        "contents": messages,
     }

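Per the role map above ("assistant" becomes "model", and the system message is hoisted into systemInstruction), a single Vertex Gemini tuning row comes out shaped like this; the text values are illustrative:

row = {
    "systemInstruction": {
        "role": "system",
        "parts": [{"text": "You are a helpful assistant."}],
    },
    "contents": [
        {"role": "user", "parts": [{"text": "Classify: 'great product'"}]},
        {"role": "model", "parts": [{"text": "positive"}]},
    ],
}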
@@ -372,7 +328,7 @@ class DatasetFormatter:
         self,
         split_name: str,
         format_type: DatasetFormat,
-        data_strategy:
+        data_strategy: ChatStrategy,
         path: Path | None = None,
     ) -> Path:
         """
@@ -397,7 +353,7 @@ class DatasetFormatter:

         generator = FORMAT_GENERATORS[format_type]

-        include_cot = data_strategy
+        include_cot = data_strategy in THINKING_DATA_STRATEGIES

         # Write to a temp file if no path is provided
         output_path = (
@@ -418,13 +374,13 @@ class DatasetFormatter:
                         f"Task run {run_id} not found. This is required by this dataset."
                     )

-
+                training_chat = build_training_chat(
                     task_run=task_run,
                     system_message=self.system_message,
-
+                    data_strategy=data_strategy,
                     thinking_instructions=self.thinking_instructions,
                 )
-                example = generator(
+                example = generator(training_chat)
                 # Allow non-ascii characters in the dataset.
                 # Better readability for non-English users. If you don't support UTF-8... you should.
                 f.write(json.dumps(example, ensure_ascii=False) + "\n")
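Putting the DatasetFormatter changes together, a hedged usage sketch: the import paths, the DatasetFormat.VERTEX_GEMINI member, and the parameters split_name, format_type, data_strategy, and path are visible in this diff; the method name dump_to_file and the constructor arguments are assumptions.

from kiln_ai.adapters.fine_tune.dataset_formatter import (
    DatasetFormat,
    DatasetFormatter,
)
from kiln_ai.datamodel.datamodel_enums import ChatStrategy

# dataset_split: a kiln_ai.datamodel.DatasetSplit loaded elsewhere.
formatter = DatasetFormatter(  # constructor arguments are assumptions
    dataset=dataset_split,
    system_message="You are a helpful assistant.",
    thinking_instructions=None,
)
output_path = formatter.dump_to_file(  # method name is an assumption
    split_name="train",
    format_type=DatasetFormat.VERTEX_GEMINI,
    data_strategy=ChatStrategy.single_turn,
)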
kiln_ai/adapters/fine_tune/test_base_finetune.py

@@ -4,13 +4,13 @@ import pytest

 from kiln_ai.adapters.fine_tune.base_finetune import (
     BaseFinetuneAdapter,
-    FinetuneDataStrategy,
     FineTuneParameter,
     FineTuneStatus,
     FineTuneStatusType,
 )
 from kiln_ai.datamodel import DatasetSplit, Task
 from kiln_ai.datamodel import Finetune as FinetuneModel
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy


 class MockFinetune(BaseFinetuneAdapter):
@@ -162,7 +162,7 @@ async def test_create_and_start_success(mock_dataset):
         train_split_name="train",
         parameters={"epochs": 10}, # Required parameter
         system_message="Test system message",
-        data_strategy=
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )

@@ -176,7 +176,7 @@ async def test_create_and_start_success(mock_dataset):
     assert datamodel.parameters == {"epochs": 10}
     assert datamodel.system_message == "Test system message"
     assert datamodel.path.exists()
-    assert datamodel.data_strategy ==
+    assert datamodel.data_strategy == ChatStrategy.single_turn
     assert datamodel.thinking_instructions is None


@@ -192,7 +192,7 @@ async def test_create_and_start_with_all_params(mock_dataset):
         description="Custom Description",
         validation_split_name="test",
         system_message="Test system message",
-        data_strategy=
+        data_strategy=ChatStrategy.two_message_cot,
         thinking_instructions="Custom thinking instructions",
     )

@@ -202,7 +202,7 @@ async def test_create_and_start_with_all_params(mock_dataset):
     assert datamodel.parameters == {"epochs": 10, "learning_rate": 0.001}
     assert datamodel.system_message == "Test system message"
     assert adapter.datamodel == datamodel
-    assert datamodel.data_strategy ==
+    assert datamodel.data_strategy == ChatStrategy.two_message_cot
     assert datamodel.thinking_instructions == "Custom thinking instructions"

     # load the datamodel from the file, confirm it's saved
@@ -221,7 +221,7 @@ async def test_create_and_start_invalid_parameters(mock_dataset):
         parameters={"learning_rate": 0.001}, # Missing required 'epochs'
         system_message="Test system message",
         thinking_instructions=None,
-        data_strategy=
+        data_strategy=ChatStrategy.single_turn,
     )


@@ -240,7 +240,7 @@ async def test_create_and_start_no_parent_task():
         train_split_name="train",
         parameters={"epochs": 10},
         system_message="Test system message",
-        data_strategy=
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )

@@ -263,7 +263,7 @@ async def test_create_and_start_no_parent_task_path():
         train_split_name="train",
         parameters={"epochs": 10},
         system_message="Test system message",
-        data_strategy=
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )

@@ -282,7 +282,7 @@ async def test_create_and_start_invalid_train_split(mock_dataset):
         train_split_name="invalid_train", # Invalid train split
         parameters={"epochs": 10},
         system_message="Test system message",
-        data_strategy=
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )

@@ -302,6 +302,6 @@ async def test_create_and_start_invalid_validation_split(mock_dataset):
         validation_split_name="invalid_test", # Invalid validation split
         parameters={"epochs": 10},
         system_message="Test system message",
-        data_strategy=
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )
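The test updates above swap FinetuneDataStrategy for the new ChatStrategy enum. A sketch of its members as they appear across this diff; the member names are verbatim, while the string values are assumptions:

from enum import Enum


class ChatStrategy(str, Enum):
    single_turn = "single_turn"
    two_message_cot = "two_message_cot"
    two_message_cot_legacy = "two_message_cot_legacy"
    single_turn_r1_thinking = "single_turn_r1_thinking"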