kiln-ai 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (72)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +22 -44
  3. kiln_ai/adapters/chat/__init__.py +8 -0
  4. kiln_ai/adapters/chat/chat_formatter.py +234 -0
  5. kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
  6. kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
  7. kiln_ai/adapters/eval/base_eval.py +8 -6
  8. kiln_ai/adapters/eval/eval_runner.py +9 -65
  9. kiln_ai/adapters/eval/g_eval.py +26 -8
  10. kiln_ai/adapters/eval/test_base_eval.py +166 -15
  11. kiln_ai/adapters/eval/test_eval_runner.py +3 -0
  12. kiln_ai/adapters/eval/test_g_eval.py +1 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +153 -197
  15. kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
  16. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +402 -211
  17. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +3 -3
  18. kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
  19. kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
  20. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
  21. kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
  22. kiln_ai/adapters/ml_model_list.py +556 -45
  23. kiln_ai/adapters/model_adapters/base_adapter.py +100 -35
  24. kiln_ai/adapters/model_adapters/litellm_adapter.py +116 -100
  25. kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
  26. kiln_ai/adapters/model_adapters/test_base_adapter.py +299 -52
  27. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +121 -22
  28. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +44 -2
  29. kiln_ai/adapters/model_adapters/test_structured_output.py +48 -18
  30. kiln_ai/adapters/parsers/base_parser.py +0 -3
  31. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  32. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  33. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  34. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  35. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  36. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  37. kiln_ai/adapters/prompt_builders.py +14 -17
  38. kiln_ai/adapters/provider_tools.py +39 -4
  39. kiln_ai/adapters/repair/test_repair_task.py +27 -5
  40. kiln_ai/adapters/test_adapter_registry.py +88 -28
  41. kiln_ai/adapters/test_ml_model_list.py +158 -0
  42. kiln_ai/adapters/test_prompt_adaptors.py +17 -3
  43. kiln_ai/adapters/test_prompt_builders.py +27 -19
  44. kiln_ai/adapters/test_provider_tools.py +130 -12
  45. kiln_ai/datamodel/__init__.py +2 -2
  46. kiln_ai/datamodel/datamodel_enums.py +43 -4
  47. kiln_ai/datamodel/dataset_filters.py +69 -1
  48. kiln_ai/datamodel/dataset_split.py +4 -0
  49. kiln_ai/datamodel/eval.py +8 -0
  50. kiln_ai/datamodel/finetune.py +13 -7
  51. kiln_ai/datamodel/prompt_id.py +1 -0
  52. kiln_ai/datamodel/task.py +68 -7
  53. kiln_ai/datamodel/task_output.py +1 -1
  54. kiln_ai/datamodel/task_run.py +39 -7
  55. kiln_ai/datamodel/test_basemodel.py +5 -8
  56. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  57. kiln_ai/datamodel/test_dataset_split.py +2 -8
  58. kiln_ai/datamodel/test_example_models.py +54 -0
  59. kiln_ai/datamodel/test_models.py +80 -9
  60. kiln_ai/datamodel/test_task.py +168 -2
  61. kiln_ai/utils/async_job_runner.py +106 -0
  62. kiln_ai/utils/config.py +3 -2
  63. kiln_ai/utils/dataset_import.py +81 -19
  64. kiln_ai/utils/logging.py +165 -0
  65. kiln_ai/utils/test_async_job_runner.py +199 -0
  66. kiln_ai/utils/test_config.py +23 -0
  67. kiln_ai/utils/test_dataset_import.py +272 -10
  68. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
  69. kiln_ai-0.17.0.dist-info/RECORD +113 -0
  70. kiln_ai-0.15.0.dist-info/RECORD +0 -104
  71. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
  72. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/fine_tune/dataset_formatter.py

@@ -6,8 +6,13 @@ from pathlib import Path
 from typing import Any, Dict, Protocol
 from uuid import uuid4
 
-from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
-from kiln_ai.datamodel import DatasetSplit, FinetuneDataStrategy, TaskRun
+from kiln_ai.adapters.chat.chat_formatter import (
+    ChatMessage,
+    get_chat_formatter,
+)
+from kiln_ai.datamodel import DatasetSplit, TaskRun
+from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES, ChatStrategy
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
 
 class DatasetFormat(str, Enum):
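
The refactor below removes the ModelTrainingData dataclass in favor of a plain list of ChatMessage objects built by get_chat_formatter. The ChatMessage definition lives in the new kiln_ai/adapters/chat/chat_formatter.py (file 4 above) and is not part of this diff; here is a minimal sketch of the shape the code below relies on, where only the role and content attributes are confirmed by usage:

from dataclasses import dataclass

@dataclass
class ChatMessage:
    # Sketch only: the real class is in chat_formatter.py. These two attributes
    # are the ones exercised by the formatter code in this diff.
    role: str  # "system", "user", or "assistant", per generate_chat_message_list below
    content: str | None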
@@ -34,41 +39,23 @@ class DatasetFormat(str, Enum):
     VERTEX_GEMINI = "vertex_gemini"
 
 
-@dataclass
-class ModelTrainingData:
-    input: str
-    system_message: str
-    final_output: str
-    # These 3 are optional, and used for COT/Thinking style multi-message responses
-    thinking_instructions: str | None = None
-    thinking: str | None = None
-    thinking_final_answer_prompt: str | None = None
-
-    def supports_cot(self) -> bool:
-        return (
-            self.thinking_instructions is not None
-            and self.thinking is not None
-            and self.thinking_final_answer_prompt is not None
-        )
-
-
 class FormatGenerator(Protocol):
     """Protocol for format generators"""
 
     def __call__(
         self,
-        training_data: ModelTrainingData,
+        training_chat: list[ChatMessage],
     ) -> Dict[str, Any]: ...
 
 
-def build_training_data(
+def build_training_chat(
     task_run: TaskRun,
     system_message: str,
-    include_cot: bool,
+    data_strategy: ChatStrategy,
     thinking_instructions: str | None = None,
-) -> ModelTrainingData:
+) -> list[ChatMessage]:
     """
-    Generate data for training.
+    Generate chat message list for training.
 
     For final output, get the best task output from the task run, preferring repaired output if available.
 
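
Because FormatGenerator is a typing.Protocol, any callable with the new signature qualifies as a generator. A hedged sketch of a custom implementation (my_plaintext_record is illustrative and not part of the package):

from typing import Any, Dict

def my_plaintext_record(training_chat: list[ChatMessage]) -> Dict[str, Any]:
    # Mirrors the OpenAI-style generators below: one dict per chat message.
    return {
        "messages": [{"role": m.role, "content": m.content} for m in training_chat]
    }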
@@ -79,126 +66,136 @@ def build_training_data(
         final_output = task_run.repaired_output.output
 
     thinking = None
-    thinking_final_answer_prompt = None
-    parent_task = task_run.parent_task()
-
-    if include_cot and task_run.has_thinking_training_data():
-        if not parent_task:
-            raise ValueError(
-                "TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
-            )
-
-        # Prefer reasoning to cot if both are present
-        intermediate_outputs = task_run.intermediate_outputs or {}
-        thinking = intermediate_outputs.get("reasoning") or intermediate_outputs.get(
-            "chain_of_thought"
-        )
 
-        thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT
-
-        # Always use the passed thinking instructions, but check they are present for COT
-        if not thinking_instructions:
-            raise ValueError(
-                "Thinking instructions are required when data_strategy is final_and_intermediate"
-            )
-
-    return ModelTrainingData(
-        input=task_run.input,
-        system_message=system_message,
-        final_output=final_output,
-        thinking=thinking,
-        thinking_instructions=thinking_instructions,
-        thinking_final_answer_prompt=thinking_final_answer_prompt,
+    chat_formatter = get_chat_formatter(
+        data_strategy,
+        system_message,
+        task_run.input,
+        thinking_instructions,
     )
+    # First turn already has it's content (user message)
+    chat_formatter.next_turn(None)
+
+    match data_strategy:
+        case ChatStrategy.single_turn:
+            chat_formatter.next_turn(final_output)
+        case ChatStrategy.two_message_cot:
+            thinking = get_thinking_data(task_run)
+            chat_formatter.next_turn(thinking)
+            chat_formatter.next_turn(final_output)
+        case ChatStrategy.two_message_cot_legacy:
+            thinking = get_thinking_data(task_run)
+            chat_formatter.next_turn(thinking)
+            chat_formatter.next_turn(final_output)
+        case ChatStrategy.single_turn_r1_thinking:
+            if thinking_instructions:
+                raise ValueError(
+                    "Thinking instructions are not supported when fine-tuning thinking models (R1, QwQ, etc). Please remove the thinking instructions."
+                )
+
+            thinking = get_thinking_data(task_run)
+            response_msg = serialize_r1_style_message(thinking, final_output)
+            chat_formatter.next_turn(response_msg)
+        case _:
+            raise_exhaustive_enum_error(data_strategy)
+
+    return chat_formatter.messages
+
+
+def get_thinking_data(task_run: TaskRun) -> str:
+    """
+    Raises an error if thinking data is not present.
+    """
+    thinking = task_run.thinking_training_data()
+    if thinking is None:
+        raise ValueError(
+            "Thinking data is required when fine-tuning thinking models. Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
+        )
+
+    return thinking
+
+
+def serialize_r1_style_message(thinking: str | None, final_output: str):
+    if thinking is None or len(thinking.strip()) == 0:
+        raise ValueError(
+            "Thinking data is required when fine-tuning thinking models (R1, QwQ, etc). Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
+        )
+
+    return f"<think>\n{thinking}\n</think>\n\n{final_output}"
+
+
+def generate_chat_message_list(
+    training_chat: list[ChatMessage],
+) -> list[dict[str, str | None]]:
+    """Generate OpenAI chat list. Not the full OpenAI body, just the list of messages."""
+
+    messages: list[dict[str, str | None]] = []
+
+    for msg in training_chat:
+        if msg.role not in ["user", "assistant", "system"]:
+            raise ValueError(f"Unsupported role for OpenAI chat format: {msg.role}")
+
+        messages.append(
+            {
+                "role": msg.role,
+                "content": msg.content,
+            }
+        )
+
+    return messages
 
 
 def generate_chat_message_response(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate OpenAI chat format with plaintext response"""
 
-    messages: list[dict[str, str | None]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
+    messages: list[dict[str, str | None]] = generate_chat_message_list(training_chat)
 
-    if training_data.supports_cot():
-        messages.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
+    return {"messages": messages}
 
-    messages.append({"role": "assistant", "content": training_data.final_output})
 
-    return {"messages": messages}
+def last_message_structured_content(training_chat: list[ChatMessage]) -> Dict:
+    """Get the structured content of the last message"""
+    if len(training_chat) < 1:
+        raise ValueError("Training chat is empty")
+    try:
+        json_data = json.loads(training_chat[-1].content or "")
+    except json.JSONDecodeError as e:
+        raise ValueError(
+            f"Last message is not JSON (structured), and this format expects structured data: {e}"
+        )
+    if not isinstance(json_data, dict):
+        raise ValueError(
+            "Last message is not a JSON Dictionary (structured data), and this format expects structured_data."
+        )
+    return json_data
 
 
 def generate_json_schema_message(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate OpenAI chat format with validated JSON response"""
     # Load and dump to ensure it's valid JSON and goes to 1 line
-    try:
-        json_data = json.loads(training_data.final_output)
-    except json.JSONDecodeError as e:
-        raise ValueError(
-            f"Invalid JSON in JSON Schema training set: {e}\nOutput Data: {training_data.final_output}"
-        ) from e
-    json_string = json.dumps(json_data, ensure_ascii=False)
-
-    messages: list[dict[str, str | None]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
-
-    if training_data.supports_cot():
-        messages.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
+    last_msg_data = last_message_structured_content(training_chat)
 
-    messages.append({"role": "assistant", "content": json_string})
+    # re-format the json string in the last message for consistency
+    json_string = json.dumps(last_msg_data, ensure_ascii=False)
+    training_chat[-1].content = json_string
 
-    return {"messages": messages}
+    return generate_chat_message_response(training_chat)
 
 
 def generate_chat_message_toolcall(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate OpenAI chat format with tool call response"""
-    try:
-        arguments = json.loads(training_data.final_output)
-    except json.JSONDecodeError as e:
-        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
-
-    messages: list[dict[str, Any]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
-
-    if training_data.supports_cot():
-        messages.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {
-                    "role": "user",
-                    "content": training_data.thinking_final_answer_prompt,
-                },
-            ]
-        )
+    last_message_data = last_message_structured_content(training_chat)
+
+    messages: list[dict[str, Any]] = generate_chat_message_list(training_chat)
+
+    # remove the last message, we're going to replace it with a toolcall
+    messages = messages[:-1]
 
     messages.append(
         {
@@ -210,8 +207,7 @@ def generate_chat_message_toolcall(
                     "type": "function",
                     "function": {
                         "name": "task_response",
-                        # Yes we parse then dump again. This ensures it's valid JSON, and ensures it goes to 1 line
-                        "arguments": json.dumps(arguments, ensure_ascii=False),
+                        "arguments": json.dumps(last_message_data, ensure_ascii=False),
                     },
                 }
             ],
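
As a point of reference, serialize_r1_style_message (added above) folds the reasoning and the final answer into a single assistant turn, so R1-style records carry their chain of thought inline rather than as separate messages. A quick illustration with made-up strings:

# Illustrative values only.
serialize_r1_style_message("First, add 2 and 2...", "The answer is 4.")
# returns: "<think>\nFirst, add 2 and 2...\n</think>\n\nThe answer is 4."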
@@ -222,52 +218,26 @@ def generate_chat_message_toolcall(
 
 
 def generate_huggingface_chat_template(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate HuggingFace chat template"""
 
-    conversations: list[dict[str, Any]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
-
-    if training_data.supports_cot():
-        conversations.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {"role": "user", "content": training_data.thinking_final_answer_prompt},
-            ]
-        )
-
-    conversations.append({"role": "assistant", "content": training_data.final_output})
+    conversations: list[dict[str, Any]] = generate_chat_message_list(training_chat)
 
     return {"conversations": conversations}
 
 
 def generate_huggingface_chat_template_toolcall(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
     """Generate HuggingFace chat template with tool calls"""
-    try:
-        arguments = json.loads(training_data.final_output)
-    except json.JSONDecodeError as e:
-        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
+    last_message_data = last_message_structured_content(training_chat)
 
     # See https://huggingface.co/docs/transformers/en/chat_templating
-    conversations: list[dict[str, Any]] = [
-        {"role": "system", "content": training_data.system_message},
-        {"role": "user", "content": training_data.input},
-    ]
-
-    if training_data.supports_cot():
-        conversations.extend(
-            [
-                {"role": "user", "content": training_data.thinking_instructions},
-                {"role": "assistant", "content": training_data.thinking},
-                {"role": "user", "content": training_data.thinking_final_answer_prompt},
-            ]
-        )
+    conversations: list[dict[str, Any]] = generate_chat_message_list(training_chat)
+
+    # remove the last message, we're going to replace it with a toolcall
+    conversations = conversations[:-1]
 
     conversations.append(
         {
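
After this change the HuggingFace template generator is a thin wrapper over generate_chat_message_list. Assuming the single_turn formatter emits system, user, and assistant turns in that order (consistent with build_training_chat above), a record would look roughly like this Python literal (values illustrative):

# Sketch of generate_huggingface_chat_template output for a single_turn chat.
{
    "conversations": [
        {"role": "system", "content": "You are a calculator."},
        {"role": "user", "content": "2+2?"},
        {"role": "assistant", "content": "4"},
    ]
}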
@@ -278,7 +248,7 @@ def generate_huggingface_chat_template_toolcall(
                     "function": {
                         "name": "task_response",
                         "id": str(uuid4()).replace("-", "")[:9],
-                        "arguments": arguments,
+                        "arguments": last_message_data,
                     },
                 }
             ],
@@ -288,55 +258,41 @@ def generate_huggingface_chat_template_toolcall(
     return {"conversations": conversations}
 
 
+VERTEX_GEMINI_ROLE_MAP = {
+    "system": "system",
+    "user": "user",
+    "assistant": "model",
+}
+
+
 def generate_vertex_gemini(
-    training_data: ModelTrainingData,
+    training_chat: list[ChatMessage],
 ) -> Dict[str, Any]:
-    """Generate Vertex Gemini 1.5 format (flash and pro)"""
+    """Generate Vertex Gemini format (flash and pro)"""
     # See https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare
 
-    contents = [
-        {
-            "role": "user",
-            "parts": [
-                {
-                    "text": training_data.input,
-                }
-            ],
-        }
-    ]
+    # System message get's it's own entry in top level UI
+    system_instruction = training_chat[0].content
 
-    if training_data.supports_cot():
-        contents.extend(
-            [
-                {
-                    "role": "user",
-                    "parts": [{"text": training_data.thinking_instructions}],
-                },
-                {"role": "model", "parts": [{"text": training_data.thinking}]},
-                {
-                    "role": "user",
-                    "parts": [{"text": training_data.thinking_final_answer_prompt}],
-                },
-            ]
+    messages: list[Dict[str, Any]] = []
+    for msg in training_chat[1:]:
+        messages.append(
+            {
+                "role": VERTEX_GEMINI_ROLE_MAP[msg.role],
+                "parts": [{"text": msg.content}],
+            }
         )
 
-    contents.append(
-        {
-            "role": "model",
-            "parts": [{"text": training_data.final_output}],
-        }
-    )
-
     return {
         "systemInstruction": {
             "role": "system",
             "parts": [
                 {
-                    "text": training_data.system_message,
+                    "text": system_instruction,
                 }
             ],
         },
-        "contents": contents,
+        "contents": messages,
     }
 
 
@@ -372,7 +328,7 @@ class DatasetFormatter:
         self,
         split_name: str,
         format_type: DatasetFormat,
-        data_strategy: FinetuneDataStrategy,
+        data_strategy: ChatStrategy,
         path: Path | None = None,
     ) -> Path:
         """
@@ -397,7 +353,7 @@ class DatasetFormatter:
 
         generator = FORMAT_GENERATORS[format_type]
 
-        include_cot = data_strategy == FinetuneDataStrategy.final_and_intermediate
+        include_cot = data_strategy in THINKING_DATA_STRATEGIES
 
         # Write to a temp file if no path is provided
         output_path = (
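
THINKING_DATA_STRATEGIES is defined in kiln_ai/datamodel/datamodel_enums.py and does not appear in this diff. Judging from the ChatStrategy cases handled in build_training_chat, it presumably covers the three thinking variants; the following is an assumption to illustrate the membership test, not the package's actual definition:

# Assumed contents; verify against datamodel_enums.py (+43 -4 in this release).
THINKING_DATA_STRATEGIES = [
    ChatStrategy.two_message_cot,
    ChatStrategy.two_message_cot_legacy,
    ChatStrategy.single_turn_r1_thinking,
]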
@@ -418,13 +374,13 @@ class DatasetFormatter:
                     f"Task run {run_id} not found. This is required by this dataset."
                 )
 
-            training_data = build_training_data(
+            training_chat = build_training_chat(
                 task_run=task_run,
                 system_message=self.system_message,
-                include_cot=include_cot,
+                data_strategy=data_strategy,
                 thinking_instructions=self.thinking_instructions,
             )
-            example = generator(training_data)
+            example = generator(training_chat)
             # Allow non-ascii characters in the dataset.
             # Better readability for non-English users. If you don't support UTF-8... you should.
             f.write(json.dumps(example, ensure_ascii=False) + "\n")
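
Putting the new surface together: dump_to_file now takes a ChatStrategy instead of the boolean include_cot. A hedged usage sketch; the DatasetFormatter constructor arguments are inferred from the self.* fields referenced above, and VERTEX_GEMINI is the only DatasetFormat member visible in this diff:

# Sketch only: constructor arguments are assumptions based on fields used above.
formatter = DatasetFormatter(
    dataset=dataset_split,  # a DatasetSplit containing a "train" split
    system_message="You are a calculator.",
    thinking_instructions=None,
)
jsonl_path = formatter.dump_to_file(
    split_name="train",
    format_type=DatasetFormat.VERTEX_GEMINI,
    data_strategy=ChatStrategy.single_turn,
)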
kiln_ai/adapters/fine_tune/test_base_finetune.py

@@ -4,13 +4,13 @@ import pytest
 
 from kiln_ai.adapters.fine_tune.base_finetune import (
     BaseFinetuneAdapter,
-    FinetuneDataStrategy,
     FineTuneParameter,
     FineTuneStatus,
     FineTuneStatusType,
 )
 from kiln_ai.datamodel import DatasetSplit, Task
 from kiln_ai.datamodel import Finetune as FinetuneModel
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 
 
 class MockFinetune(BaseFinetuneAdapter):
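
The remaining hunks are a mechanical rename of the data-strategy enum. The substitutions imply the following mapping, shown here as a reference dict rather than anything defined in the package:

# Old FinetuneDataStrategy value -> new ChatStrategy value, per the hunks below.
STRATEGY_RENAME = {
    "final_only": ChatStrategy.single_turn,
    "final_and_intermediate": ChatStrategy.two_message_cot,
}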
@@ -162,7 +162,7 @@ async def test_create_and_start_success(mock_dataset):
         train_split_name="train",
         parameters={"epochs": 10},  # Required parameter
         system_message="Test system message",
-        data_strategy=FinetuneDataStrategy.final_only,
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )
 
@@ -176,7 +176,7 @@ async def test_create_and_start_success(mock_dataset):
     assert datamodel.parameters == {"epochs": 10}
     assert datamodel.system_message == "Test system message"
     assert datamodel.path.exists()
-    assert datamodel.data_strategy == FinetuneDataStrategy.final_only
+    assert datamodel.data_strategy == ChatStrategy.single_turn
     assert datamodel.thinking_instructions is None
 
 
@@ -192,7 +192,7 @@ async def test_create_and_start_with_all_params(mock_dataset):
         description="Custom Description",
         validation_split_name="test",
         system_message="Test system message",
-        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+        data_strategy=ChatStrategy.two_message_cot,
         thinking_instructions="Custom thinking instructions",
     )
 
@@ -202,7 +202,7 @@ async def test_create_and_start_with_all_params(mock_dataset):
     assert datamodel.parameters == {"epochs": 10, "learning_rate": 0.001}
     assert datamodel.system_message == "Test system message"
     assert adapter.datamodel == datamodel
-    assert datamodel.data_strategy == FinetuneDataStrategy.final_and_intermediate
+    assert datamodel.data_strategy == ChatStrategy.two_message_cot
     assert datamodel.thinking_instructions == "Custom thinking instructions"
 
     # load the datamodel from the file, confirm it's saved
@@ -221,7 +221,7 @@ async def test_create_and_start_invalid_parameters(mock_dataset):
         parameters={"learning_rate": 0.001},  # Missing required 'epochs'
         system_message="Test system message",
         thinking_instructions=None,
-        data_strategy=FinetuneDataStrategy.final_only,
+        data_strategy=ChatStrategy.single_turn,
     )
 
 
@@ -240,7 +240,7 @@ async def test_create_and_start_no_parent_task():
         train_split_name="train",
         parameters={"epochs": 10},
         system_message="Test system message",
-        data_strategy=FinetuneDataStrategy.final_only,
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )
 
@@ -263,7 +263,7 @@ async def test_create_and_start_no_parent_task_path():
         train_split_name="train",
         parameters={"epochs": 10},
         system_message="Test system message",
-        data_strategy=FinetuneDataStrategy.final_only,
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )
 
@@ -282,7 +282,7 @@ async def test_create_and_start_invalid_train_split(mock_dataset):
         train_split_name="invalid_train",  # Invalid train split
         parameters={"epochs": 10},
         system_message="Test system message",
-        data_strategy=FinetuneDataStrategy.final_only,
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )
 
@@ -302,6 +302,6 @@ async def test_create_and_start_invalid_validation_split(mock_dataset):
         validation_split_name="invalid_test",  # Invalid validation split
         parameters={"epochs": 10},
         system_message="Test system message",
-        data_strategy=FinetuneDataStrategy.final_only,
+        data_strategy=ChatStrategy.single_turn,
         thinking_instructions=None,
     )