kiln-ai 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/eval/base_eval.py +7 -2
- kiln_ai/adapters/eval/eval_runner.py +5 -64
- kiln_ai/adapters/eval/g_eval.py +3 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +6 -3
- kiln_ai/adapters/fine_tune/dataset_formatter.py +128 -38
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +2 -1
- kiln_ai/adapters/fine_tune/test_base_finetune.py +7 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +267 -10
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +586 -0
- kiln_ai/adapters/fine_tune/vertex_finetune.py +217 -0
- kiln_ai/adapters/ml_model_list.py +817 -62
- kiln_ai/adapters/model_adapters/base_adapter.py +33 -10
- kiln_ai/adapters/model_adapters/litellm_adapter.py +51 -12
- kiln_ai/adapters/model_adapters/test_base_adapter.py +74 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +65 -1
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +3 -2
- kiln_ai/adapters/model_adapters/test_structured_output.py +4 -6
- kiln_ai/adapters/parsers/base_parser.py +0 -3
- kiln_ai/adapters/parsers/parser_registry.py +5 -3
- kiln_ai/adapters/parsers/r1_parser.py +17 -2
- kiln_ai/adapters/parsers/request_formatters.py +40 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
- kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
- kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
- kiln_ai/adapters/prompt_builders.py +14 -1
- kiln_ai/adapters/provider_tools.py +25 -1
- kiln_ai/adapters/repair/test_repair_task.py +3 -2
- kiln_ai/adapters/test_prompt_builders.py +24 -3
- kiln_ai/adapters/test_provider_tools.py +86 -1
- kiln_ai/datamodel/__init__.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +14 -0
- kiln_ai/datamodel/dataset_filters.py +69 -1
- kiln_ai/datamodel/dataset_split.py +4 -0
- kiln_ai/datamodel/eval.py +8 -0
- kiln_ai/datamodel/finetune.py +1 -0
- kiln_ai/datamodel/json_schema.py +24 -7
- kiln_ai/datamodel/prompt_id.py +1 -0
- kiln_ai/datamodel/task_output.py +10 -6
- kiln_ai/datamodel/task_run.py +68 -12
- kiln_ai/datamodel/test_basemodel.py +3 -7
- kiln_ai/datamodel/test_dataset_filters.py +82 -0
- kiln_ai/datamodel/test_dataset_split.py +2 -0
- kiln_ai/datamodel/test_example_models.py +158 -3
- kiln_ai/datamodel/test_json_schema.py +22 -3
- kiln_ai/datamodel/test_model_perf.py +3 -2
- kiln_ai/datamodel/test_models.py +50 -2
- kiln_ai/utils/async_job_runner.py +106 -0
- kiln_ai/utils/dataset_import.py +80 -18
- kiln_ai/utils/test_async_job_runner.py +199 -0
- kiln_ai/utils/test_dataset_import.py +242 -10
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/METADATA +3 -2
- kiln_ai-0.16.0.dist-info/RECORD +108 -0
- kiln_ai/adapters/test_generate_docs.py +0 -69
- kiln_ai-0.14.0.dist-info/RECORD +0 -103
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/eval/base_eval.py
CHANGED

@@ -2,11 +2,13 @@ import json
 from abc import abstractmethod
 from typing import Dict
 
+import jsonschema
+
 from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
 from kiln_ai.datamodel.eval import Eval, EvalConfig, EvalScores
-from kiln_ai.datamodel.json_schema import validate_schema
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 from kiln_ai.datamodel.task import RunConfig, TaskOutputRatingType, TaskRun
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
@@ -72,7 +74,10 @@ class BaseEval:
         run_output = await run_adapter.invoke(parsed_input)
 
         eval_output, intermediate_outputs = await self.run_eval(run_output)
-        validate_schema(eval_output, self.score_schema)
+
+        validate_schema_with_value_error(
+            eval_output, self.score_schema, "Eval output does not match score schema."
+        )
 
         return run_output, eval_output, intermediate_outputs
 
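Note: validate_schema_with_value_error is defined in kiln_ai/datamodel/json_schema.py (+24 -7), which this excerpt does not show. A minimal sketch of what such a wrapper plausibly looks like, assuming it delegates to the jsonschema package that base_eval.py now also imports; the parameter names are assumptions:

import jsonschema


def validate_schema_with_value_error(data: dict, schema: dict, error_prefix: str) -> None:
    # Hypothetical sketch: surface schema violations as ValueError so callers
    # can handle eval-output mismatches uniformly.
    try:
        jsonschema.validate(instance=data, schema=schema)
    except jsonschema.ValidationError as e:
        raise ValueError(f"{error_prefix} {e.message}") from e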
kiln_ai/adapters/eval/eval_runner.py
CHANGED

@@ -1,4 +1,3 @@
-import asyncio
 import logging
 from dataclasses import dataclass
 from typing import AsyncGenerator, Dict, List, Literal, Set
@@ -10,6 +9,7 @@ from kiln_ai.datamodel.dataset_filters import dataset_filter_from_id
 from kiln_ai.datamodel.eval import EvalConfig, EvalRun, EvalScores
 from kiln_ai.datamodel.task import TaskRunConfig
 from kiln_ai.datamodel.task_run import TaskRun
+from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress
 
 logger = logging.getLogger(__name__)
 
@@ -23,13 +23,6 @@ class EvalJob:
     task_run_config: TaskRunConfig | None = None
 
 
-@dataclass
-class EvalProgress:
-    complete: int | None = None
-    total: int | None = None
-    errors: int | None = None
-
-
 class EvalRunner:
     """
     Runs an eval. Async execution is supported to make it faster when using remote/fast model providers.
@@ -161,67 +154,15 @@ class EvalRunner:
             if task_run.id not in already_run[eval_config.id][run_config.id]
         ]
 
-    async def run(self, concurrency: int = 25) -> AsyncGenerator[EvalProgress, None]:
+    async def run(self, concurrency: int = 25) -> AsyncGenerator[Progress, None]:
         """
        Runs the configured eval run with parallel workers and yields progress updates.
         """
         jobs = self.collect_tasks()
 
-        complete = 0
-        total = len(jobs)
-        errors = 0
-
-        # Send initial status
-        yield EvalProgress(complete=complete, total=total, errors=errors)
-
-        worker_queue: asyncio.Queue[EvalJob] = asyncio.Queue()
-        for job in jobs:
-            worker_queue.put_nowait(job)
-
-        # simple status queue to return progress. True=success, False=error
-        status_queue: asyncio.Queue[bool] = asyncio.Queue()
-
-        workers = []
-        for i in range(concurrency):
-            task = asyncio.create_task(self.run_worker(worker_queue, status_queue))
-            workers.append(task)
-
-        # Send status updates until workers are done, and they are all sent
-        while not status_queue.empty() or not all(worker.done() for worker in workers):
-            try:
-                # Use timeout to prevent hanging if all workers complete
-                # between our while condition check and get()
-                success = await asyncio.wait_for(status_queue.get(), timeout=0.1)
-                if success:
-                    complete += 1
-                else:
-                    errors += 1
-
-                yield EvalProgress(complete=complete, total=total, errors=errors)
-            except asyncio.TimeoutError:
-                # Timeout is expected, just continue to recheck worker status
-                # Don't love this but beats sentinels for reliability
-                continue
-
-        # These are redundant, but keeping them will catch async errors
-        await asyncio.gather(*workers)
-        await worker_queue.join()
-
-    async def run_worker(
-        self, worker_queue: asyncio.Queue[EvalJob], status_queue: asyncio.Queue[bool]
-    ):
-        while True:
-            try:
-                job = worker_queue.get_nowait()
-            except asyncio.QueueEmpty:
-                # worker can end when the queue is empty
-                break
-            try:
-                success = await self.run_job(job)
-                await status_queue.put(success)
-            finally:
-                # Always mark the dequeued task as done, even on exceptions
-                worker_queue.task_done()
+        runner = AsyncJobRunner(concurrency=concurrency)
+        async for progress in runner.run(jobs, self.run_job):
+            yield progress
 
     async def run_job(self, job: EvalJob) -> bool:
         try:
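Note: the worker-pool plumbing deleted above now lives in the new kiln_ai/utils/async_job_runner.py (+106 lines), which this excerpt does not show. Reassembling the deleted code around the new call site gives a sketch of what the extracted runner plausibly looks like; only the AsyncJobRunner and Progress names and the run(jobs, run_job) call are confirmed by the diff, the rest is an assumption:

import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar

T = TypeVar("T")


@dataclass
class Progress:
    complete: int | None = None
    total: int | None = None
    errors: int | None = None


class AsyncJobRunner:
    def __init__(self, concurrency: int = 25):
        self.concurrency = concurrency

    async def run(
        self, jobs: List[T], run_job: Callable[[T], Awaitable[bool]]
    ) -> AsyncGenerator[Progress, None]:
        complete, errors, total = 0, 0, len(jobs)
        # Send initial status before any work starts
        yield Progress(complete=complete, total=total, errors=errors)

        worker_queue: asyncio.Queue[T] = asyncio.Queue()
        for job in jobs:
            worker_queue.put_nowait(job)
        # True=success, False=error, mirroring the removed EvalRunner code
        status_queue: asyncio.Queue[bool] = asyncio.Queue()

        async def worker() -> None:
            while True:
                try:
                    job = worker_queue.get_nowait()
                except asyncio.QueueEmpty:
                    break  # queue drained, this worker can exit
                try:
                    await status_queue.put(await run_job(job))
                finally:
                    worker_queue.task_done()

        workers = [asyncio.create_task(worker()) for _ in range(self.concurrency)]
        # Drain status updates until all workers finish
        while not status_queue.empty() or not all(w.done() for w in workers):
            try:
                success = await asyncio.wait_for(status_queue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                continue  # timeout is expected; recheck worker status
            if success:
                complete += 1
            else:
                errors += 1
            yield Progress(complete=complete, total=total, errors=errors)
        await asyncio.gather(*workers)

Centralizing this removes the asyncio plumbing from eval_runner.py, and the new module ships with its own tests (test_async_job_runner.py, +199 lines).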
kiln_ai/adapters/eval/g_eval.py
CHANGED

@@ -43,9 +43,9 @@ class GEvalTask(Task, parent_of={}):
 
         # Build the COT eval instructions
         cot_instructions = "First, think step by step about the model's performance following these evaluation steps:\n\n"
-        steps = eval_config.properties.get("eval_steps", None)
-        if not steps:
-            raise ValueError("eval_steps must be a list")
+        steps = eval_config.properties.get("eval_steps", [])
+        if not isinstance(steps, list):
+            raise ValueError("eval_steps must be a list.")
         for i, step in enumerate(steps):
             cot_instructions += f"{i + 1}) {step}\n"
 
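For reference, the loop above renders the steps into numbered chain-of-thought instructions:

steps = ["Assess factual accuracy", "Assess tone"]
cot_instructions = "First, think step by step about the model's performance following these evaluation steps:\n\n"
for i, step in enumerate(steps):
    cot_instructions += f"{i + 1}) {step}\n"
# cot_instructions now ends with:
# 1) Assess factual accuracy
# 2) Assess tone

With the new empty-list default, a config without eval_steps yields an empty step list rather than a non-list value.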
kiln_ai/adapters/fine_tune/base_finetune.py
CHANGED

@@ -166,9 +166,12 @@ class BaseFinetuneAdapter(ABC):
 
             # Strict type checking for numeric types
             if expected_type is float and not isinstance(value, float):
-                raise ValueError(
-                    f"Parameter {parameter.name} must be a float, got {type(value)}"
-                )
+                if isinstance(value, int):
+                    value = float(value)
+                else:
+                    raise ValueError(
+                        f"Parameter {parameter.name} must be a float, got {type(value)}"
+                    )
             elif expected_type is int and not isinstance(value, int):
                 raise ValueError(
                     f"Parameter {parameter.name} must be an integer, got {type(value)}"
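The change above loosens strict float checking in validate_parameters: an int supplied for a float-typed parameter is now widened instead of rejected. In isolation:

# The new coercion rule, reduced to its core:
value = 1  # int supplied where a float is expected
if not isinstance(value, float):
    if isinstance(value, int):
        value = float(value)  # accepted and widened to 1.0
    else:
        raise ValueError("must be a float")
assert value == 1.0

The matching test appears at the end of this diff (test_base_finetune.py).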
kiln_ai/adapters/fine_tune/dataset_formatter.py
CHANGED

@@ -8,6 +8,7 @@ from uuid import uuid4
 
 from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
 from kiln_ai.datamodel import DatasetSplit, FinetuneDataStrategy, TaskRun
+from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES
 
 
 class DatasetFormat(str, Enum):
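THINKING_DATA_STRATEGIES is defined in the new kiln_ai/datamodel/datamodel_enums.py (+14 lines), which this excerpt does not show. Judging from its use below, it presumably groups the strategies that carry reasoning traces, along the lines of:

# Assumed definition; only the two member names are confirmed by this diff.
from kiln_ai.datamodel import FinetuneDataStrategy

THINKING_DATA_STRATEGIES = [
    FinetuneDataStrategy.final_and_intermediate,
    FinetuneDataStrategy.final_and_intermediate_r1_compatible,
]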
@@ -30,8 +31,8 @@ class DatasetFormat(str, Enum):
         "huggingface_chat_template_toolcall_jsonl"
     )
 
-    """Vertex Gemini 1.5 format (flash and pro)"""
-    VERTEX_GEMINI_1_5 = "vertex_gemini_1_5"
+    """Vertex Gemini format"""
+    VERTEX_GEMINI = "vertex_gemini"
 
 
 @dataclass
@@ -43,8 +44,12 @@ class ModelTrainingData:
     thinking_instructions: str | None = None
     thinking: str | None = None
     thinking_final_answer_prompt: str | None = None
+    thinking_r1_style: bool = False
 
     def supports_cot(self) -> bool:
+        if self.thinking_r1_style:
+            raise ValueError("R1 style does not support COT")
+
         return (
             self.thinking_instructions is not None
             and self.thinking is not None
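The guard makes misuse loud: R1-style training data must never take the chain-of-thought branch, so supports_cot() refuses to answer for it rather than quietly returning False. Illustration (constructor arguments beyond the fields visible in this diff are assumed):

data = ModelTrainingData(
    input="What is 2+2?",
    system_message="You are a calculator.",
    final_output="4",
    thinking="The user wants simple addition.",
    thinking_r1_style=True,
)
data.supports_cot()  # raises ValueError: R1 style does not support COT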
@@ -64,7 +69,7 @@ class FormatGenerator(Protocol):
 def build_training_data(
     task_run: TaskRun,
     system_message: str,
-    include_cot: bool,
+    data_strategy: FinetuneDataStrategy,
     thinking_instructions: str | None = None,
 ) -> ModelTrainingData:
     """
@@ -80,27 +85,41 @@ def build_training_data(
 
     thinking = None
     thinking_final_answer_prompt = None
+    thinking_r1_style = False
     parent_task = task_run.parent_task()
 
-    if include_cot:
-        if not parent_task:
-            raise ValueError(
-                "TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
-            )
-
+    if data_strategy in THINKING_DATA_STRATEGIES:
         # Prefer reasoning to cot if both are present
-
-
-
-
+        thinking = task_run.thinking_training_data()
+
+        if data_strategy == FinetuneDataStrategy.final_and_intermediate_r1_compatible:
+            if not task_run.has_thinking_training_data() or not thinking:
+                raise ValueError(
+                    "Thinking data is required when fine-tuning thinking models (R1, QwQ, etc). Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
+                )
+            if thinking_instructions:
+                raise ValueError(
+                    "Thinking instructions are not supported when fine-tuning thinking models (R1, QwQ, etc). Please remove the thinking instructions."
+                )
+            thinking_r1_style = True
+        elif (
+            data_strategy == FinetuneDataStrategy.final_and_intermediate
+            and task_run.has_thinking_training_data()
+        ):
+            if not parent_task:
+                raise ValueError(
+                    "TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
+                )
 
-        thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT
+            thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT
 
-        # Always use the passed thinking instructions, but check they are present for COT
-        if not thinking_instructions:
-            raise ValueError(
-                "Thinking instructions are required when include_cot is True"
-            )
+            # Always use the passed thinking instructions, but check they are present for COT
+            if not thinking_instructions:
+                raise ValueError(
+                    "Thinking instructions are required when data_strategy is final_and_intermediate"
+                )
+        else:
+            raise ValueError(f"Unsupported data strategy: {data_strategy}")
 
     return ModelTrainingData(
         input=task_run.input,
@@ -109,9 +128,19 @@ def build_training_data(
         thinking=thinking,
         thinking_instructions=thinking_instructions,
         thinking_final_answer_prompt=thinking_final_answer_prompt,
+        thinking_r1_style=thinking_r1_style,
     )
 
 
+def serialize_r1_style_message(thinking: str | None, final_output: str):
+    if thinking is None or len(thinking.strip()) == 0:
+        raise ValueError(
+            "Thinking data is required when fine-tuning thinking models (R1, QwQ, etc). Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
+        )
+
+    return f"<think>\n{thinking}\n</think>\n\n{final_output}"
+
+
 def generate_chat_message_response(
     training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
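The new serializer inlines the reasoning into a single assistant message using R1-style think tags, and refuses entries without reasoning:

serialize_r1_style_message(thinking="Check the math.", final_output="4")
# -> '<think>\nCheck the math.\n</think>\n\n4'

serialize_r1_style_message(thinking="  ", final_output="4")
# -> raises ValueError (thinking data is required for every entry)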
@@ -122,7 +151,21 @@ def generate_chat_message_response(
         {"role": "user", "content": training_data.input},
     ]
 
-    if training_data.supports_cot():
+    if training_data.thinking_r1_style:
+        messages.extend(
+            [
+                {
+                    "role": "assistant",
+                    "content": serialize_r1_style_message(
+                        thinking=training_data.thinking,
+                        final_output=training_data.final_output,
+                    ),
+                }
+            ]
+        )
+
+        return {"messages": messages}
+    elif training_data.supports_cot():
         messages.extend(
             [
                 {"role": "user", "content": training_data.thinking_instructions},
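For an R1-style record the generator now emits the reasoning and the final answer as one assistant turn. Assuming the messages list opens with the system message, as the other generators in this file do, a serialized record looks like:

{
    "messages": [
        {"role": "system", "content": "You are a calculator."},
        {"role": "user", "content": "What is 2+2?"},
        {
            "role": "assistant",
            "content": "<think>\nThe user wants simple addition.\n</think>\n\n4",
        },
    ]
}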
@@ -157,7 +200,21 @@ def generate_json_schema_message(
         {"role": "user", "content": training_data.input},
     ]
 
-    if training_data.supports_cot():
+    if training_data.thinking_r1_style:
+        messages.extend(
+            [
+                {
+                    "role": "assistant",
+                    "content": serialize_r1_style_message(
+                        thinking=training_data.thinking,
+                        final_output=training_data.final_output,
+                    ),
+                }
+            ]
+        )
+
+        return {"messages": messages}
+    elif training_data.supports_cot():
         messages.extend(
             [
                 {"role": "user", "content": training_data.thinking_instructions},
@@ -188,7 +245,11 @@ def generate_chat_message_toolcall(
         {"role": "user", "content": training_data.input},
     ]
 
-    if training_data.supports_cot():
+    if training_data.thinking_r1_style:
+        raise ValueError(
+            "R1 style thinking is not supported for tool call downloads. Please use a different training strategy."
+        )
+    elif training_data.supports_cot():
         messages.extend(
             [
                 {"role": "user", "content": training_data.thinking_instructions},
@@ -231,12 +292,29 @@ def generate_huggingface_chat_template(
         {"role": "user", "content": training_data.input},
     ]
 
+    if training_data.thinking_r1_style:
+        conversations.extend(
+            [
+                {
+                    "role": "assistant",
+                    "content": serialize_r1_style_message(
+                        thinking=training_data.thinking,
+                        final_output=training_data.final_output,
+                    ),
+                }
+            ]
+        )
+        return {"conversations": conversations}
+
     if training_data.supports_cot():
         conversations.extend(
             [
                 {"role": "user", "content": training_data.thinking_instructions},
                 {"role": "assistant", "content": training_data.thinking},
-                {"role": "user", "content": training_data.thinking_final_answer_prompt},
+                {
+                    "role": "user",
+                    "content": training_data.thinking_final_answer_prompt,
+                },
             ]
         )
 
@@ -260,12 +338,19 @@ def generate_huggingface_chat_template_toolcall(
         {"role": "user", "content": training_data.input},
     ]
 
-    if training_data.supports_cot():
+    if training_data.thinking_r1_style:
+        raise ValueError(
+            "R1 style thinking is not supported for tool call downloads. Please use a different training strategy."
+        )
+    elif training_data.supports_cot():
         conversations.extend(
             [
                 {"role": "user", "content": training_data.thinking_instructions},
                 {"role": "assistant", "content": training_data.thinking},
-                {"role": "user", "content": training_data.thinking_final_answer_prompt},
+                {
+                    "role": "user",
+                    "content": training_data.thinking_final_answer_prompt,
+                },
             ]
         )
 
@@ -288,12 +373,20 @@ def generate_huggingface_chat_template_toolcall(
     return {"conversations": conversations}
 
 
-def generate_vertex_gemini_1_5(
+def generate_vertex_gemini(
     training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate Vertex Gemini 1.5 format (flash and pro)"""
     # See https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare
 
+    system_instruction = {
+        "role": "system",
+        "parts": [
+            {
+                "text": training_data.system_message,
+            }
+        ],
+    }
     contents = [
         {
             "role": "user",
@@ -305,7 +398,11 @@ def generate_vertex_gemini_1_5(
         }
     ]
 
-    if training_data.supports_cot():
+    if training_data.thinking_r1_style:
+        raise ValueError(
+            "R1 style thinking is not supported for Vertex Gemini. Please use a different training strategy."
+        )
+    elif training_data.supports_cot():
         contents.extend(
             [
                 {
@@ -328,14 +425,7 @@ def generate_vertex_gemini_1_5(
     )
 
     return {
-        "systemInstruction": {
-            "role": "system",
-            "parts": [
-                {
-                    "text": training_data.system_message,
-                }
-            ],
-        },
+        "systemInstruction": system_instruction,
         "contents": contents,
     }
 
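Taken together, the two Vertex hunks hoist the system instruction into a shared variable and emit one JSON record per example. An illustrative record; the user/model parts structure follows the Vertex supervised-tuning documentation linked in the code and is not fully visible in this diff:

{
    "systemInstruction": {
        "role": "system",
        "parts": [{"text": "You are a calculator."}],
    },
    "contents": [
        {"role": "user", "parts": [{"text": "What is 2+2?"}]},
        {"role": "model", "parts": [{"text": "4"}]},
    ],
}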
@@ -346,7 +436,7 @@ FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
     DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: generate_chat_message_toolcall,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: generate_huggingface_chat_template,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: generate_huggingface_chat_template_toolcall,
-    DatasetFormat.VERTEX_GEMINI_1_5: generate_vertex_gemini_1_5,
+    DatasetFormat.VERTEX_GEMINI: generate_vertex_gemini,
 }
 
 
@@ -397,7 +487,7 @@ class DatasetFormatter:
 
         generator = FORMAT_GENERATORS[format_type]
 
-        include_cot = data_strategy == FinetuneDataStrategy.final_and_intermediate
+        include_cot = data_strategy in THINKING_DATA_STRATEGIES
 
         # Write to a temp file if no path is provided
         output_path = (
@@ -421,7 +511,7 @@ class DatasetFormatter:
             training_data = build_training_data(
                 task_run=task_run,
                 system_message=self.system_message,
-                include_cot=include_cot,
+                data_strategy=data_strategy,
                 thinking_instructions=self.thinking_instructions,
             )
             example = generator(training_data)
kiln_ai/adapters/fine_tune/finetune_registry.py
CHANGED

@@ -4,10 +4,12 @@ from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter
 from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
 from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
 from kiln_ai.adapters.fine_tune.together_finetune import TogetherFinetune
+from kiln_ai.adapters.fine_tune.vertex_finetune import VertexFinetune
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 
 finetune_registry: dict[ModelProviderName, Type[BaseFinetuneAdapter]] = {
     ModelProviderName.openai: OpenAIFinetune,
     ModelProviderName.fireworks_ai: FireworksFinetune,
     ModelProviderName.together_ai: TogetherFinetune,
+    ModelProviderName.vertex: VertexFinetune,
 }
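With the Vertex adapter registered (vertex_finetune.py is new in this release, +217 lines), provider lookup stays a plain dict access:

from kiln_ai.adapters.fine_tune.finetune_registry import finetune_registry
from kiln_ai.adapters.ml_model_list import ModelProviderName

adapter_class = finetune_registry[ModelProviderName.vertex]  # VertexFinetune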
kiln_ai/adapters/fine_tune/fireworks_finetune.py
CHANGED

@@ -198,7 +198,8 @@ class FireworksFinetune(BaseFinetuneAdapter):
         if not api_key or not account_id:
             raise ValueError("Fireworks API key or account ID not set")
         url = f"https://api.fireworks.ai/v1/accounts/{account_id}/datasets"
-        dataset_id = str(uuid4())
+        # First char can't be a digit: https://discord.com/channels/1137072072808472616/1363214412395184350/1363214412395184350
+        dataset_id = "kiln-" + str(uuid4())
         payload = {
             "datasetId": dataset_id,
             "dataset": {
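The prefix guarantees the generated dataset ID never starts with a digit, per the Fireworks constraint referenced in the comment:

from uuid import uuid4

dataset_id = "kiln-" + str(uuid4())
# e.g. "kiln-550e8400-e29b-41d4-a716-446655440000"
assert not dataset_id[0].isdigit()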
kiln_ai/adapters/fine_tune/test_base_finetune.py
CHANGED

@@ -98,6 +98,13 @@ def test_validate_parameters_valid():
     }
     MockFinetune.validate_parameters(valid_params)  # Should not raise
 
+    # Test valid parameters (float as int)
+    valid_params = {
+        "learning_rate": 1,
+        "epochs": 10,
+    }
+    MockFinetune.validate_parameters(valid_params)  # Should not raise
+
 
 def test_validate_parameters_missing_required():
     # Test missing required parameter