kiln-ai 0.16.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +22 -44
  3. kiln_ai/adapters/chat/__init__.py +8 -0
  4. kiln_ai/adapters/chat/chat_formatter.py +233 -0
  5. kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
  6. kiln_ai/adapters/data_gen/data_gen_prompts.py +121 -36
  7. kiln_ai/adapters/data_gen/data_gen_task.py +49 -36
  8. kiln_ai/adapters/data_gen/test_data_gen_task.py +330 -40
  9. kiln_ai/adapters/eval/base_eval.py +7 -6
  10. kiln_ai/adapters/eval/eval_runner.py +9 -2
  11. kiln_ai/adapters/eval/g_eval.py +40 -17
  12. kiln_ai/adapters/eval/test_base_eval.py +174 -17
  13. kiln_ai/adapters/eval/test_eval_runner.py +3 -0
  14. kiln_ai/adapters/eval/test_g_eval.py +116 -5
  15. kiln_ai/adapters/fine_tune/base_finetune.py +3 -8
  16. kiln_ai/adapters/fine_tune/dataset_formatter.py +135 -273
  17. kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
  18. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +287 -353
  19. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
  20. kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
  21. kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
  22. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +6 -11
  23. kiln_ai/adapters/fine_tune/together_finetune.py +13 -2
  24. kiln_ai/adapters/ml_model_list.py +370 -84
  25. kiln_ai/adapters/model_adapters/base_adapter.py +73 -26
  26. kiln_ai/adapters/model_adapters/litellm_adapter.py +88 -97
  27. kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
  28. kiln_ai/adapters/model_adapters/test_base_adapter.py +235 -61
  29. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +104 -21
  30. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -0
  31. kiln_ai/adapters/model_adapters/test_structured_output.py +44 -12
  32. kiln_ai/adapters/parsers/parser_registry.py +0 -2
  33. kiln_ai/adapters/parsers/r1_parser.py +0 -1
  34. kiln_ai/adapters/prompt_builders.py +0 -16
  35. kiln_ai/adapters/provider_tools.py +27 -9
  36. kiln_ai/adapters/remote_config.py +66 -0
  37. kiln_ai/adapters/repair/repair_task.py +1 -6
  38. kiln_ai/adapters/repair/test_repair_task.py +24 -3
  39. kiln_ai/adapters/test_adapter_registry.py +88 -28
  40. kiln_ai/adapters/test_ml_model_list.py +176 -0
  41. kiln_ai/adapters/test_prompt_adaptors.py +17 -7
  42. kiln_ai/adapters/test_prompt_builders.py +3 -16
  43. kiln_ai/adapters/test_provider_tools.py +69 -20
  44. kiln_ai/adapters/test_remote_config.py +100 -0
  45. kiln_ai/datamodel/__init__.py +0 -2
  46. kiln_ai/datamodel/datamodel_enums.py +38 -13
  47. kiln_ai/datamodel/eval.py +32 -0
  48. kiln_ai/datamodel/finetune.py +12 -8
  49. kiln_ai/datamodel/task.py +68 -7
  50. kiln_ai/datamodel/task_output.py +0 -2
  51. kiln_ai/datamodel/task_run.py +0 -2
  52. kiln_ai/datamodel/test_basemodel.py +2 -1
  53. kiln_ai/datamodel/test_dataset_split.py +0 -8
  54. kiln_ai/datamodel/test_eval_model.py +146 -4
  55. kiln_ai/datamodel/test_models.py +33 -10
  56. kiln_ai/datamodel/test_task.py +168 -2
  57. kiln_ai/utils/config.py +3 -2
  58. kiln_ai/utils/dataset_import.py +1 -1
  59. kiln_ai/utils/logging.py +166 -0
  60. kiln_ai/utils/test_config.py +23 -0
  61. kiln_ai/utils/test_dataset_import.py +30 -0
  62. {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/METADATA +2 -2
  63. kiln_ai-0.18.0.dist-info/RECORD +115 -0
  64. kiln_ai-0.16.0.dist-info/RECORD +0 -108
  65. {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/WHEEL +0 -0
  66. {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/licenses/LICENSE.txt +0 -0
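
The headline change running through the hunks below (which cover kiln_ai/adapters/fine_tune/dataset_formatter.py and test_base_finetune.py) is the retirement of the FinetuneDataStrategy enum and the ModelTrainingData dataclass in favor of a ChatStrategy enum plus ChatMessage lists built by the new kiln_ai.adapters.chat package. Judging from the renames in the test hunks at the bottom, and from the matching R1 error messages in build_training_chat, the old strategy values appear to map to the new ones as sketched here; treat this as a reading aid inferred from the diff, not an official migration table:

    # Apparent FinetuneDataStrategy -> ChatStrategy renames, inferred from the
    # test updates below. ChatStrategy.two_message_cot_legacy also appears in
    # the new code but has no direct predecessor visible in this diff.
    STRATEGY_RENAMES = {
        "final_only": "single_turn",
        "final_and_intermediate": "two_message_cot",
        "final_and_intermediate_r1_compatible": "single_turn_r1_thinking",
    }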
--- kiln_ai-0.16.0/kiln_ai/adapters/fine_tune/dataset_formatter.py
+++ kiln_ai-0.18.0/kiln_ai/adapters/fine_tune/dataset_formatter.py
@@ -1,14 +1,14 @@
 import json
 import tempfile
-from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, Protocol
 from uuid import uuid4
 
-from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
-from kiln_ai.datamodel import DatasetSplit, FinetuneDataStrategy, TaskRun
-from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES
+from kiln_ai.adapters.chat.chat_formatter import ChatMessage, get_chat_formatter
+from kiln_ai.datamodel import DatasetSplit, TaskRun
+from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES, ChatStrategy
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
 
 class DatasetFormat(str, Enum):
@@ -35,45 +35,23 @@ class DatasetFormat(str, Enum):
     VERTEX_GEMINI = "vertex_gemini"
 
 
-@dataclass
-class ModelTrainingData:
-    input: str
-    system_message: str
-    final_output: str
-    # These 3 are optional, and used for COT/Thinking style multi-message responses
-    thinking_instructions: str | None = None
-    thinking: str | None = None
-    thinking_final_answer_prompt: str | None = None
-    thinking_r1_style: bool = False
-
-    def supports_cot(self) -> bool:
-        if self.thinking_r1_style:
-            raise ValueError("R1 style does not support COT")
-
-        return (
-            self.thinking_instructions is not None
-            and self.thinking is not None
-            and self.thinking_final_answer_prompt is not None
-        )
-
-
 class FormatGenerator(Protocol):
     """Protocol for format generators"""
 
     def __call__(
         self,
-        training_data: ModelTrainingData,
+        training_chat: list[ChatMessage],
     ) -> Dict[str, Any]: ...
 
 
-def build_training_data(
+def build_training_chat(
     task_run: TaskRun,
     system_message: str,
-    data_strategy: FinetuneDataStrategy,
+    data_strategy: ChatStrategy,
     thinking_instructions: str | None = None,
-) -> ModelTrainingData:
+) -> list[ChatMessage]:
     """
-    Generate data for training.
+    Generate chat message list for training.
 
     For final output, get the best task output from the task run, preferring repaired output if available.
 
@@ -84,52 +62,53 @@ def build_training_data(
         final_output = task_run.repaired_output.output
 
     thinking = None
-    thinking_final_answer_prompt = None
-    thinking_r1_style = False
-    parent_task = task_run.parent_task()
 
-    if data_strategy in THINKING_DATA_STRATEGIES:
-        # Prefer reasoning to cot if both are present
-        thinking = task_run.thinking_training_data()
-
-        if data_strategy == FinetuneDataStrategy.final_and_intermediate_r1_compatible:
-            if not task_run.has_thinking_training_data() or not thinking:
-                raise ValueError(
-                    "Thinking data is required when fine-tuning thinking models (R1, QwQ, etc). Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
-                )
+    chat_formatter = get_chat_formatter(
+        data_strategy,
+        system_message,
+        task_run.input,
+        thinking_instructions,
+    )
+    # First turn already has it's content (user message)
+    chat_formatter.next_turn(None)
+
+    match data_strategy:
+        case ChatStrategy.single_turn:
+            chat_formatter.next_turn(final_output)
+        case ChatStrategy.two_message_cot:
+            thinking = get_thinking_data(task_run)
+            chat_formatter.next_turn(thinking)
+            chat_formatter.next_turn(final_output)
+        case ChatStrategy.two_message_cot_legacy:
+            thinking = get_thinking_data(task_run)
+            chat_formatter.next_turn(thinking)
+            chat_formatter.next_turn(final_output)
+        case ChatStrategy.single_turn_r1_thinking:
             if thinking_instructions:
                 raise ValueError(
                     "Thinking instructions are not supported when fine-tuning thinking models (R1, QwQ, etc). Please remove the thinking instructions."
                 )
-            thinking_r1_style = True
-        elif (
-            data_strategy == FinetuneDataStrategy.final_and_intermediate
-            and task_run.has_thinking_training_data()
-        ):
-            if not parent_task:
-                raise ValueError(
-                    "TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
-                )
 
-            thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT
+            thinking = get_thinking_data(task_run)
+            response_msg = serialize_r1_style_message(thinking, final_output)
+            chat_formatter.next_turn(response_msg)
+        case _:
+            raise_exhaustive_enum_error(data_strategy)
 
-        # Always use the passed thinking instructions, but check they are present for COT
-        if not thinking_instructions:
-            raise ValueError(
-                "Thinking instructions are required when data_strategy is final_and_intermediate"
-            )
-    else:
-        raise ValueError(f"Unsupported data strategy: {data_strategy}")
-
-    return ModelTrainingData(
-        input=task_run.input,
-        system_message=system_message,
-        final_output=final_output,
-        thinking=thinking,
-        thinking_instructions=thinking_instructions,
-        thinking_final_answer_prompt=thinking_final_answer_prompt,
-        thinking_r1_style=thinking_r1_style,
-    )
+    return chat_formatter.messages
+
+
+def get_thinking_data(task_run: TaskRun) -> str:
+    """
+    Raises an error if thinking data is not present.
+    """
+    thinking = task_run.thinking_training_data()
+    if thinking is None:
+        raise ValueError(
+            "Thinking data is required when fine-tuning thinking models. Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
+        )
+
+    return thinking
 
 
 def serialize_r1_style_message(thinking: str | None, final_output: str):
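
To see the new flow end to end: build_training_chat now delegates turn construction to a chat formatter and selects turns with a match on ChatStrategy. A minimal sketch of calling it, assuming task_run is an existing TaskRun with saved output (and thinking data for the COT strategies), and assuming ChatMessage exposes plain role/content attributes, as the formatter functions below read them:

    from kiln_ai.adapters.fine_tune.dataset_formatter import build_training_chat
    from kiln_ai.datamodel.datamodel_enums import ChatStrategy

    # Two-message chain of thought: user input, a thinking turn, then the final answer.
    messages = build_training_chat(
        task_run=task_run,  # an existing kiln_ai.datamodel.TaskRun
        system_message="You are a helpful assistant.",
        data_strategy=ChatStrategy.two_message_cot,
        thinking_instructions="Think step by step.",
    )
    for msg in messages:
        print(msg.role, msg.content)  # system / user / assistant turns, in order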
@@ -141,125 +120,78 @@ def serialize_r1_style_message(thinking: str | None, final_output: str):
     return f"<think>\n{thinking}\n</think>\n\n{final_output}"
 
 
-def generate_chat_message_response(
-    training_data: ModelTrainingData,
-) -> Dict[str, Any]:
-    """Generate OpenAI chat format with plaintext response"""
+def generate_chat_message_list(
+    training_chat: list[ChatMessage],
+) -> list[dict[str, str | None]]:
+    """Generate OpenAI chat list. Not the full OpenAI body, just the list of messages."""
 
-    messages: list[dict[str, str | None]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
+    messages: list[dict[str, str | None]] = []
 
-    if training_data.thinking_r1_style:
-        messages.extend(
-            [
-                {
-                    "role": "assistant",
-                    "content": serialize_r1_style_message(
-                        thinking=training_data.thinking,
-                        final_output=training_data.final_output,
-                    ),
-                }
-            ]
-        )
+    for msg in training_chat:
+        if msg.role not in ["user", "assistant", "system"]:
+            raise ValueError(f"Unsupported role for OpenAI chat format: {msg.role}")
 
-        return {"messages": messages}
-    elif training_data.supports_cot():
-        messages.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
+        messages.append(
+            {
+                "role": msg.role,
+                "content": msg.content,
+            }
         )
 
-    messages.append({"role": "assistant", "content": training_data.final_output})
+    return messages
+
+
+def generate_chat_message_response(
+    training_chat: list[ChatMessage],
+) -> Dict[str, Any]:
+    """Generate OpenAI chat format with plaintext response"""
+
+    messages: list[dict[str, str | None]] = generate_chat_message_list(training_chat)
 
     return {"messages": messages}
 
 
-def generate_json_schema_message(
-    training_data: ModelTrainingData,
-) -> Dict[str, Any]:
-    """Generate OpenAI chat format with validated JSON response"""
-    # Load and dump to ensure it's valid JSON and goes to 1 line
+def last_message_structured_content(training_chat: list[ChatMessage]) -> Dict:
+    """Get the structured content of the last message"""
+    if len(training_chat) < 1:
+        raise ValueError("Training chat is empty")
     try:
-        json_data = json.loads(training_data.final_output)
+        json_data = json.loads(training_chat[-1].content or "")
     except json.JSONDecodeError as e:
         raise ValueError(
-            f"Invalid JSON in JSON Schema training set: {e}\nOutput Data: {training_data.final_output}"
-        ) from e
-    json_string = json.dumps(json_data, ensure_ascii=False)
-
-    messages: list[dict[str, str | None]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
-
-    if training_data.thinking_r1_style:
-        messages.extend(
-            [
-                {
-                    "role": "assistant",
-                    "content": serialize_r1_style_message(
-                        thinking=training_data.thinking,
-                        final_output=training_data.final_output,
-                    ),
-                }
-            ]
+            f"Last message is not JSON (structured), and this format expects structured data: {e}"
         )
-
-        return {"messages": messages}
-    elif training_data.supports_cot():
-        messages.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
+    if not isinstance(json_data, dict):
+        raise ValueError(
+            "Last message is not a JSON Dictionary (structured data), and this format expects structured_data."
        )
+    return json_data
 
-        messages.append({"role": "assistant", "content": json_string})
 
-    return {"messages": messages}
+def generate_json_schema_message(
+    training_chat: list[ChatMessage],
+) -> Dict[str, Any]:
+    """Generate OpenAI chat format with validated JSON response"""
+    # Load and dump to ensure it's valid JSON and goes to 1 line
+    last_msg_data = last_message_structured_content(training_chat)
+
+    # re-format the json string in the last message for consistency
+    json_string = json.dumps(last_msg_data, ensure_ascii=False)
+    training_chat[-1].content = json_string
+
+    return generate_chat_message_response(training_chat)
 
 
 def generate_chat_message_toolcall(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate OpenAI chat format with tool call response"""
-    try:
-        arguments = json.loads(training_data.final_output)
-    except json.JSONDecodeError as e:
-        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
+    last_message_data = last_message_structured_content(training_chat)
 
-    messages: list[dict[str, Any]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
+    messages: list[dict[str, Any]] = generate_chat_message_list(training_chat)
 
-    if training_data.thinking_r1_style:
-        raise ValueError(
-            "R1 style thinking is not supported for tool call downloads. Please use a different training strategy."
-        )
-    elif training_data.supports_cot():
-        messages.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
+    # remove the last message, we're going to replace it with a toolcall
+    messages = messages[:-1]
 
     messages.append(
         {
@@ -271,8 +203,7 @@ def generate_chat_message_toolcall(
                     "type": "function",
                     "function": {
                         "name": "task_response",
-                        # Yes we parse then dump again. This ensures it's valid JSON, and ensures it goes to 1 line
-                        "arguments": json.dumps(arguments, ensure_ascii=False),
+                        "arguments": json.dumps(last_message_data, ensure_ascii=False),
                     },
                 }
             ],
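
Putting the toolcall hunks together, one JSONL record emitted by the rewritten generate_chat_message_toolcall should look roughly like the literal below (values invented for illustration; fields this diff does not show, such as the assistant message's content being None, are assumptions):

    example = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Summarize the report."},
            {
                "role": "assistant",
                "content": None,  # assumed; not visible in this hunk
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "arguments": '{"summary": "..."}',
                        },
                    }
                ],
            },
        ]
    }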
@@ -283,76 +214,26 @@
 
 
 def generate_huggingface_chat_template(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate HuggingFace chat template"""
 
-    conversations: list[dict[str, Any]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
-
-    if training_data.thinking_r1_style:
-        conversations.extend(
-            [
-                {
-                    "role": "assistant",
-                    "content": serialize_r1_style_message(
-                        thinking=training_data.thinking,
-                        final_output=training_data.final_output,
-                    ),
-                }
-            ]
-        )
-        return {"conversations": conversations}
-
-    if training_data.supports_cot():
-        conversations.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
-
-    conversations.append({"role": "assistant", "content": training_data.final_output})
+    conversations: list[dict[str, Any]] = generate_chat_message_list(training_chat)
 
     return {"conversations": conversations}
 
 
 def generate_huggingface_chat_template_toolcall(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate HuggingFace chat template with tool calls"""
-    try:
-        arguments = json.loads(training_data.final_output)
-    except json.JSONDecodeError as e:
-        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
+    last_message_data = last_message_structured_content(training_chat)
 
     # See https://huggingface.co/docs/transformers/en/chat_templating
-    conversations: list[dict[str, Any]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
+    conversations: list[dict[str, Any]] = generate_chat_message_list(training_chat)
 
-    if training_data.thinking_r1_style:
-        raise ValueError(
-            "R1 style thinking is not supported for tool call downloads. Please use a different training strategy."
-        )
-    elif training_data.supports_cot():
-        conversations.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
+    # remove the last message, we're going to replace it with a toolcall
+    conversations = conversations[:-1]
 
     conversations.append(
         {
@@ -363,7 +244,7 @@ def generate_huggingface_chat_template_toolcall(
                 "function": {
                     "name": "task_response",
                     "id": str(uuid4()).replace("-", "")[:9],
-                    "arguments": arguments,
+                    "arguments": last_message_data,
                 },
             }
         ],
@@ -373,60 +254,41 @@
     return {"conversations": conversations}
 
 
+VERTEX_GEMINI_ROLE_MAP = {
+    "system": "system",
+    "user": "user",
+    "assistant": "model",
+}
+
+
 def generate_vertex_gemini(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
-    """Generate Vertex Gemini 1.5 format (flash and pro)"""
+    """Generate Vertex Gemini format (flash and pro)"""
     # See https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare
 
-    system_instruction = {
-        "role": "system",
-        "parts": [
+    # System message get's it's own entry in top level UI
+    system_instruction = training_chat[0].content
+
+    messages: list[Dict[str, Any]] = []
+    for msg in training_chat[1:]:
+        messages.append(
             {
-                "text": training_data.system_message,
+                "role": VERTEX_GEMINI_ROLE_MAP[msg.role],
+                "parts": [{"text": msg.content}],
             }
-        ],
-    }
-    contents = [
-        {
-            "role": "user",
+        )
+
+    return {
+        "systemInstruction": {
+            "role": "system",
             "parts": [
                 {
-                    "text": training_data.input,
+                    "text": system_instruction,
                 }
             ],
-        }
-    ]
-
-    if training_data.thinking_r1_style:
-        raise ValueError(
-            "R1 style thinking is not supported for Vertex Gemini. Please use a different training strategy."
-        )
-    elif training_data.supports_cot():
-        contents.extend(
-            [
-                {
-                    "role": "user",
-                    "parts": [{"text": training_data.thinking_instructions}],
-                },
-                {"role": "model", "parts": [{"text": training_data.thinking}]},
-                {
-                    "role": "user",
-                    "parts": [{"text": training_data.thinking_final_answer_prompt}],
-                },
-            ]
-        )
-
-    contents.append(
-        {
-            "role": "model",
-            "parts": [{"text": training_data.final_output}],
-        }
-    )
-
-    return {
-        "systemInstruction": system_instruction,
-        "contents": contents,
+        },
+        "contents": messages,
     }
 
 
@@ -462,7 +324,7 @@ class DatasetFormatter:
         self,
         split_name: str,
         format_type: DatasetFormat,
-        data_strategy: FinetuneDataStrategy,
+        data_strategy: ChatStrategy,
         path: Path | None = None,
     ) -> Path:
         """
@@ -508,13 +370,13 @@
                         f"Task run {run_id} not found. This is required by this dataset."
                     )
 
-                training_data = build_training_data(
+                training_chat = build_training_chat(
                     task_run=task_run,
                     system_message=self.system_message,
                     data_strategy=data_strategy,
                     thinking_instructions=self.thinking_instructions,
                 )
-                example = generator(training_data)
+                example = generator(training_chat)
                 # Allow non-ascii characters in the dataset.
                 # Better readability for non-English users. If you don't support UTF-8... you should.
                 f.write(json.dumps(example, ensure_ascii=False) + "\n")
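
For callers of DatasetFormatter, the visible change is the type of data_strategy. A hedged usage sketch: only the parameters in the hunk above (split_name, format_type, data_strategy, path) are confirmed by this diff; the constructor arguments and the method name dump_to_file are assumptions based on the attributes referenced here:

    formatter = DatasetFormatter(
        dataset=dataset_split,  # a kiln_ai.datamodel.DatasetSplit (assumed)
        system_message="You are a helpful assistant.",
        thinking_instructions=None,
    )
    # DatasetFormat.VERTEX_GEMINI is confirmed by the enum hunk near the top.
    path = formatter.dump_to_file(  # method name assumed, not shown in this diff
        split_name="train",
        format_type=DatasetFormat.VERTEX_GEMINI,
        data_strategy=ChatStrategy.single_turn,
    )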
--- kiln_ai-0.16.0/kiln_ai/adapters/fine_tune/test_base_finetune.py
+++ kiln_ai-0.18.0/kiln_ai/adapters/fine_tune/test_base_finetune.py
@@ -4,13 +4,13 @@ import pytest
 
 from kiln_ai.adapters.fine_tune.base_finetune import (
     BaseFinetuneAdapter,
-    FinetuneDataStrategy,
     FineTuneParameter,
     FineTuneStatus,
     FineTuneStatusType,
 )
 from kiln_ai.datamodel import DatasetSplit, Task
 from kiln_ai.datamodel import Finetune as FinetuneModel
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 
 
 class MockFinetune(BaseFinetuneAdapter):
@@ -162,7 +162,7 @@ async def test_create_and_start_success(mock_dataset):
         train_split_name="train",
         parameters={"epochs": 10},  # Required parameter
         system_message="Test system message",
-        data_strategy=FinetuneDataStrategy.final_only,
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )
 
@@ -176,7 +176,7 @@ async def test_create_and_start_success(mock_dataset):
     assert datamodel.parameters == {"epochs": 10}
     assert datamodel.system_message == "Test system message"
     assert datamodel.path.exists()
-    assert datamodel.data_strategy == FinetuneDataStrategy.final_only
+    assert datamodel.data_strategy == ChatStrategy.single_turn
     assert datamodel.thinking_instructions is None
 
 
@@ -192,7 +192,7 @@ async def test_create_and_start_with_all_params(mock_dataset):
         description="Custom Description",
         validation_split_name="test",
         system_message="Test system message",
-        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+        data_strategy=ChatStrategy.two_message_cot,
         thinking_instructions="Custom thinking instructions",
     )
 
@@ -202,7 +202,7 @@ async def test_create_and_start_with_all_params(mock_dataset):
     assert datamodel.parameters == {"epochs": 10, "learning_rate": 0.001}
     assert datamodel.system_message == "Test system message"
     assert adapter.datamodel == datamodel
-    assert datamodel.data_strategy == FinetuneDataStrategy.final_and_intermediate
+    assert datamodel.data_strategy == ChatStrategy.two_message_cot
     assert datamodel.thinking_instructions == "Custom thinking instructions"
 
     # load the datamodel from the file, confirm it's saved
@@ -221,7 +221,7 @@ async def test_create_and_start_invalid_parameters(mock_dataset):
             parameters={"learning_rate": 0.001},  # Missing required 'epochs'
             system_message="Test system message",
             thinking_instructions=None,
-            data_strategy=FinetuneDataStrategy.final_only,
+            data_strategy=ChatStrategy.single_turn,
         )
 
 
@@ -240,7 +240,7 @@ async def test_create_and_start_no_parent_task():
             train_split_name="train",
             parameters={"epochs": 10},
             system_message="Test system message",
-            data_strategy=FinetuneDataStrategy.final_only,
+            data_strategy=ChatStrategy.single_turn,
             thinking_instructions=None,
         )
 
@@ -263,7 +263,7 @@ async def test_create_and_start_no_parent_task_path():
             train_split_name="train",
             parameters={"epochs": 10},
             system_message="Test system message",
-            data_strategy=FinetuneDataStrategy.final_only,
+            data_strategy=ChatStrategy.single_turn,
             thinking_instructions=None,
         )
 
@@ -282,7 +282,7 @@ async def test_create_and_start_invalid_train_split(mock_dataset):
             train_split_name="invalid_train",  # Invalid train split
             parameters={"epochs": 10},
             system_message="Test system message",
-            data_strategy=FinetuneDataStrategy.final_only,
+            data_strategy=ChatStrategy.single_turn,
             thinking_instructions=None,
         )
 
@@ -302,6 +302,6 @@ async def test_create_and_start_invalid_validation_split(mock_dataset):
             validation_split_name="invalid_test",  # Invalid validation split
             parameters={"epochs": 10},
             system_message="Test system message",
-            data_strategy=FinetuneDataStrategy.final_only,
+            data_strategy=ChatStrategy.single_turn,
             thinking_instructions=None,
         )