kiln-ai 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (58)
  1. kiln_ai/adapters/eval/base_eval.py +7 -2
  2. kiln_ai/adapters/eval/eval_runner.py +5 -64
  3. kiln_ai/adapters/eval/g_eval.py +3 -3
  4. kiln_ai/adapters/fine_tune/base_finetune.py +6 -3
  5. kiln_ai/adapters/fine_tune/dataset_formatter.py +128 -38
  6. kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
  7. kiln_ai/adapters/fine_tune/fireworks_finetune.py +2 -1
  8. kiln_ai/adapters/fine_tune/test_base_finetune.py +7 -0
  9. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +267 -10
  10. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
  11. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +586 -0
  12. kiln_ai/adapters/fine_tune/vertex_finetune.py +217 -0
  13. kiln_ai/adapters/ml_model_list.py +817 -62
  14. kiln_ai/adapters/model_adapters/base_adapter.py +33 -10
  15. kiln_ai/adapters/model_adapters/litellm_adapter.py +51 -12
  16. kiln_ai/adapters/model_adapters/test_base_adapter.py +74 -2
  17. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +65 -1
  18. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +3 -2
  19. kiln_ai/adapters/model_adapters/test_structured_output.py +4 -6
  20. kiln_ai/adapters/parsers/base_parser.py +0 -3
  21. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  22. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  23. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  24. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  25. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  26. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  27. kiln_ai/adapters/prompt_builders.py +14 -1
  28. kiln_ai/adapters/provider_tools.py +25 -1
  29. kiln_ai/adapters/repair/test_repair_task.py +3 -2
  30. kiln_ai/adapters/test_prompt_builders.py +24 -3
  31. kiln_ai/adapters/test_provider_tools.py +86 -1
  32. kiln_ai/datamodel/__init__.py +2 -0
  33. kiln_ai/datamodel/datamodel_enums.py +14 -0
  34. kiln_ai/datamodel/dataset_filters.py +69 -1
  35. kiln_ai/datamodel/dataset_split.py +4 -0
  36. kiln_ai/datamodel/eval.py +8 -0
  37. kiln_ai/datamodel/finetune.py +1 -0
  38. kiln_ai/datamodel/json_schema.py +24 -7
  39. kiln_ai/datamodel/prompt_id.py +1 -0
  40. kiln_ai/datamodel/task_output.py +10 -6
  41. kiln_ai/datamodel/task_run.py +68 -12
  42. kiln_ai/datamodel/test_basemodel.py +3 -7
  43. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  44. kiln_ai/datamodel/test_dataset_split.py +2 -0
  45. kiln_ai/datamodel/test_example_models.py +158 -3
  46. kiln_ai/datamodel/test_json_schema.py +22 -3
  47. kiln_ai/datamodel/test_model_perf.py +3 -2
  48. kiln_ai/datamodel/test_models.py +50 -2
  49. kiln_ai/utils/async_job_runner.py +106 -0
  50. kiln_ai/utils/dataset_import.py +80 -18
  51. kiln_ai/utils/test_async_job_runner.py +199 -0
  52. kiln_ai/utils/test_dataset_import.py +242 -10
  53. {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/METADATA +3 -2
  54. kiln_ai-0.16.0.dist-info/RECORD +108 -0
  55. kiln_ai/adapters/test_generate_docs.py +0 -69
  56. kiln_ai-0.14.0.dist-info/RECORD +0 -103
  57. {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/WHEEL +0 -0
  58. {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/eval/base_eval.py
@@ -2,11 +2,13 @@ import json
 from abc import abstractmethod
 from typing import Dict
 
+import jsonschema
+
 from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
 from kiln_ai.datamodel.eval import Eval, EvalConfig, EvalScores
-from kiln_ai.datamodel.json_schema import validate_schema
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 from kiln_ai.datamodel.task import RunConfig, TaskOutputRatingType, TaskRun
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
@@ -72,7 +74,10 @@ class BaseEval:
         run_output = await run_adapter.invoke(parsed_input)
 
         eval_output, intermediate_outputs = await self.run_eval(run_output)
-        validate_schema(eval_output, self.score_schema)
+
+        validate_schema_with_value_error(
+            eval_output, self.score_schema, "Eval output does not match score schema."
+        )
 
         return run_output, eval_output, intermediate_outputs
 
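The body of the new validate_schema_with_value_error helper is not shown in this diff (the kiln_ai/datamodel/json_schema.py changes are listed above at +24/-7). As a reading aid only, here is a minimal sketch of what such a helper could look like, inferred from the new jsonschema import and the call site; the body and exact parameter names are assumptions:

    import jsonschema

    def validate_schema_with_value_error(
        instance: dict, schema: dict, error_message_prefix: str
    ) -> None:
        # Hypothetical sketch: validate against a JSON schema and surface
        # failures as a ValueError with a caller-provided, user-facing prefix.
        try:
            jsonschema.validate(instance=instance, schema=schema)
        except jsonschema.ValidationError as e:
            raise ValueError(f"{error_message_prefix} {e.message}") from e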
kiln_ai/adapters/eval/eval_runner.py
@@ -1,4 +1,3 @@
-import asyncio
 import logging
 from dataclasses import dataclass
 from typing import AsyncGenerator, Dict, List, Literal, Set
@@ -10,6 +9,7 @@ from kiln_ai.datamodel.dataset_filters import dataset_filter_from_id
 from kiln_ai.datamodel.eval import EvalConfig, EvalRun, EvalScores
 from kiln_ai.datamodel.task import TaskRunConfig
 from kiln_ai.datamodel.task_run import TaskRun
+from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress
 
 logger = logging.getLogger(__name__)
 
@@ -23,13 +23,6 @@ class EvalJob:
     task_run_config: TaskRunConfig | None = None
 
 
-@dataclass
-class EvalProgress:
-    complete: int | None = None
-    total: int | None = None
-    errors: int | None = None
-
-
 class EvalRunner:
     """
     Runs an eval. Async execution is supported to make it faster when using remote/fast model providers.
@@ -161,67 +154,15 @@ class EvalRunner:
             if task_run.id not in already_run[eval_config.id][run_config.id]
         ]
 
-    async def run(self, concurrency: int = 25) -> AsyncGenerator[EvalProgress, None]:
+    async def run(self, concurrency: int = 25) -> AsyncGenerator[Progress, None]:
         """
         Runs the configured eval run with parallel workers and yields progress updates.
         """
         jobs = self.collect_tasks()
 
-        complete = 0
-        errors = 0
-        total = len(jobs)
-
-        # Send initial status
-        yield EvalProgress(complete=complete, total=total, errors=errors)
-
-        worker_queue: asyncio.Queue[EvalJob] = asyncio.Queue()
-        for job in jobs:
-            worker_queue.put_nowait(job)
-
-        # simple status queue to return progress. True=success, False=error
-        status_queue: asyncio.Queue[bool] = asyncio.Queue()
-
-        workers = []
-        for i in range(concurrency):
-            task = asyncio.create_task(self.run_worker(worker_queue, status_queue))
-            workers.append(task)
-
-        # Send status updates until workers are done, and they are all sent
-        while not status_queue.empty() or not all(worker.done() for worker in workers):
-            try:
-                # Use timeout to prevent hanging if all workers complete
-                # between our while condition check and get()
-                success = await asyncio.wait_for(status_queue.get(), timeout=0.1)
-                if success:
-                    complete += 1
-                else:
-                    errors += 1
-
-                yield EvalProgress(complete=complete, total=total, errors=errors)
-            except asyncio.TimeoutError:
-                # Timeout is expected, just continue to recheck worker status
-                # Don't love this but beats sentinels for reliability
-                continue
-
-        # These are redundant, but keeping them will catch async errors
-        await asyncio.gather(*workers)
-        await worker_queue.join()
-
-    async def run_worker(
-        self, worker_queue: asyncio.Queue[EvalJob], status_queue: asyncio.Queue[bool]
-    ):
-        while True:
-            try:
-                job = worker_queue.get_nowait()
-            except asyncio.QueueEmpty:
-                # worker can end when the queue is empty
-                break
-            try:
-                success = await self.run_job(job)
-                await status_queue.put(success)
-            finally:
-                # Always mark the dequeued task as done, even on exceptions
-                worker_queue.task_done()
+        runner = AsyncJobRunner(concurrency=concurrency)
+        async for progress in runner.run(jobs, self.run_job):
+            yield progress
 
     async def run_job(self, job: EvalJob) -> bool:
         try:
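The queue/worker machinery deleted above moves into the new kiln_ai/utils/async_job_runner.py (file 49, +106 lines, not shown in full in this diff). Below is a sketch of its likely shape, reconstructed from the deleted EvalRunner code and the call site (AsyncJobRunner(concurrency=...), runner.run(jobs, self.run_job) yielding Progress); internals beyond that are assumptions:

    import asyncio
    from dataclasses import dataclass
    from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar

    T = TypeVar("T")

    @dataclass
    class Progress:
        complete: int | None = None
        total: int | None = None
        errors: int | None = None

    class AsyncJobRunner:
        def __init__(self, concurrency: int = 1):
            self.concurrency = concurrency

        async def run(
            self, jobs: List[T], run_job: Callable[[T], Awaitable[bool]]
        ) -> AsyncGenerator[Progress, None]:
            complete = 0
            errors = 0
            total = len(jobs)

            # Send initial status before any job runs
            yield Progress(complete=complete, total=total, errors=errors)

            worker_queue: asyncio.Queue[T] = asyncio.Queue()
            for job in jobs:
                worker_queue.put_nowait(job)

            # True = success, False = error (same convention as the deleted code)
            status_queue: asyncio.Queue[bool] = asyncio.Queue()
            workers = [
                asyncio.create_task(self._worker(worker_queue, status_queue, run_job))
                for _ in range(self.concurrency)
            ]

            # Drain status updates until all workers finish and all are reported
            while not status_queue.empty() or not all(w.done() for w in workers):
                try:
                    success = await asyncio.wait_for(status_queue.get(), timeout=0.1)
                    if success:
                        complete += 1
                    else:
                        errors += 1
                    yield Progress(complete=complete, total=total, errors=errors)
                except asyncio.TimeoutError:
                    continue  # expected; recheck worker status

            # Redundant with the loop above, but surfaces worker exceptions
            await asyncio.gather(*workers)

        async def _worker(self, worker_queue, status_queue, run_job) -> None:
            while True:
                try:
                    job = worker_queue.get_nowait()
                except asyncio.QueueEmpty:
                    break  # queue drained; worker can exit
                try:
                    await status_queue.put(await run_job(job))
                finally:
                    worker_queue.task_done()

The payoff of the extraction is that the same progress-reporting loop can now back any batch job, not just evals; the new tests land in kiln_ai/utils/test_async_job_runner.py (+199 lines, file 51 above).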
kiln_ai/adapters/eval/g_eval.py
@@ -43,9 +43,9 @@ class GEvalTask(Task, parent_of={}):
 
         # Build the COT eval instructions
         cot_instructions = "First, think step by step about the model's performance following these evaluation steps:\n\n"
-        steps = eval_config.properties.get("eval_steps", None)
-        if not steps or not isinstance(steps, list):
-            raise ValueError("eval_steps must be a list")
+        steps = eval_config.properties.get("eval_steps", [])
+        if not isinstance(steps, list):
+            raise ValueError("eval_steps must be a list.")
         for i, step in enumerate(steps):
            cot_instructions += f"{i + 1}) {step}\n"
 
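The behavioral change here is subtle: a missing eval_steps key now defaults to an empty list (producing a prompt with no numbered steps), where 0.14.0 raised; a present-but-non-list value still raises. Illustration with assumed example values:

    # Illustrative only: a config whose properties lack "eval_steps".
    properties: dict = {}

    steps = properties.get("eval_steps", [])
    # 0.14.0: `if not steps or not isinstance(steps, list)` raised for a
    # missing or empty value; 0.16.0 only rejects non-list values.
    if not isinstance(steps, list):
        raise ValueError("eval_steps must be a list.")

    cot_instructions = "First, think step by step...\n\n"
    for i, step in enumerate(steps):  # zero iterations: no numbered steps
        cot_instructions += f"{i + 1}) {step}\n"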
kiln_ai/adapters/fine_tune/base_finetune.py
@@ -166,9 +166,12 @@ class BaseFinetuneAdapter(ABC):
 
             # Strict type checking for numeric types
             if expected_type is float and not isinstance(value, float):
-                raise ValueError(
-                    f"Parameter {parameter.name} must be a float, got {type(value)}"
-                )
+                if isinstance(value, int):
+                    value = float(value)
+                else:
+                    raise ValueError(
+                        f"Parameter {parameter.name} must be a float, got {type(value)}"
+                    )
             elif expected_type is int and not isinstance(value, int):
                 raise ValueError(
                     f"Parameter {parameter.name} must be an integer, got {type(value)}"
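Integer values are now coerced for float-typed hyperparameters (for example, learning_rate=1 becomes 1.0) instead of being rejected; other non-float values still raise. The rule in isolation, extracted into a standalone function for illustration:

    # Not the library's API: a hypothetical helper showing the coercion
    # branch added above.
    def coerce_float_param(name: str, value: object) -> float:
        if isinstance(value, float):
            return value
        if isinstance(value, int):
            return float(value)  # e.g. learning_rate=1 -> 1.0
        raise ValueError(f"Parameter {name} must be a float, got {type(value)}")

    assert coerce_float_param("learning_rate", 1) == 1.0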
kiln_ai/adapters/fine_tune/dataset_formatter.py
@@ -8,6 +8,7 @@ from uuid import uuid4
 
 from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
 from kiln_ai.datamodel import DatasetSplit, FinetuneDataStrategy, TaskRun
+from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES
 
 
 class DatasetFormat(str, Enum):
@@ -30,8 +31,8 @@ class DatasetFormat(str, Enum):
         "huggingface_chat_template_toolcall_jsonl"
     )
 
-    """Vertex Gemini 1.5 format (flash and pro)"""
-    VERTEX_GEMINI_1_5 = "vertex_gemini_1_5"
+    """Vertex Gemini format"""
+    VERTEX_GEMINI = "vertex_gemini"
 
 
 @dataclass
@@ -43,8 +44,12 @@ class ModelTrainingData:
     thinking_instructions: str | None = None
     thinking: str | None = None
     thinking_final_answer_prompt: str | None = None
+    thinking_r1_style: bool = False
 
     def supports_cot(self) -> bool:
+        if self.thinking_r1_style:
+            raise ValueError("R1 style does not support COT")
+
         return (
             self.thinking_instructions is not None
             and self.thinking is not None
@@ -64,7 +69,7 @@ class FormatGenerator(Protocol):
 def build_training_data(
     task_run: TaskRun,
     system_message: str,
-    include_cot: bool,
+    data_strategy: FinetuneDataStrategy,
     thinking_instructions: str | None = None,
 ) -> ModelTrainingData:
     """
@@ -80,27 +85,41 @@ def build_training_data(
 
     thinking = None
     thinking_final_answer_prompt = None
+    thinking_r1_style = False
    parent_task = task_run.parent_task()
 
-    if include_cot and task_run.has_thinking_training_data():
-        if not parent_task:
-            raise ValueError(
-                "TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
-            )
-
+    if data_strategy in THINKING_DATA_STRATEGIES:
         # Prefer reasoning to cot if both are present
-        intermediate_outputs = task_run.intermediate_outputs or {}
-        thinking = intermediate_outputs.get("reasoning") or intermediate_outputs.get(
-            "chain_of_thought"
-        )
+        thinking = task_run.thinking_training_data()
+
+        if data_strategy == FinetuneDataStrategy.final_and_intermediate_r1_compatible:
+            if not task_run.has_thinking_training_data() or not thinking:
+                raise ValueError(
+                    "Thinking data is required when fine-tuning thinking models (R1, QwQ, etc). Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
+                )
+            if thinking_instructions:
+                raise ValueError(
+                    "Thinking instructions are not supported when fine-tuning thinking models (R1, QwQ, etc). Please remove the thinking instructions."
+                )
+            thinking_r1_style = True
+        elif (
+            data_strategy == FinetuneDataStrategy.final_and_intermediate
+            and task_run.has_thinking_training_data()
+        ):
+            if not parent_task:
+                raise ValueError(
+                    "TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
+                )
 
-        thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT
+            thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT
 
-        # Always use the passed thinking instructions, but check they are present for COT
-        if not thinking_instructions:
-            raise ValueError(
-                "Thinking instructions are required when data_strategy is final_and_intermediate"
-            )
+            # Always use the passed thinking instructions, but check they are present for COT
+            if not thinking_instructions:
+                raise ValueError(
+                    "Thinking instructions are required when data_strategy is final_and_intermediate"
+                )
+        else:
+            raise ValueError(f"Unsupported data strategy: {data_strategy}")
 
     return ModelTrainingData(
         input=task_run.input,
@@ -109,9 +128,19 @@ def build_training_data(
         thinking=thinking,
         thinking_instructions=thinking_instructions,
         thinking_final_answer_prompt=thinking_final_answer_prompt,
+        thinking_r1_style=thinking_r1_style,
     )
 
 
+def serialize_r1_style_message(thinking: str | None, final_output: str):
+    if thinking is None or len(thinking.strip()) == 0:
+        raise ValueError(
+            "Thinking data is required when fine-tuning thinking models (R1, QwQ, etc). Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
+        )
+
+    return f"<think>\n{thinking}\n</think>\n\n{final_output}"
+
+
 def generate_chat_message_response(
     training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
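For example, serialize_r1_style_message folds a run's reasoning trace and final answer into the single <think>-tagged assistant turn that R1-style models expect (example values are illustrative):

    # Usage of the serializer added above.
    serialize_r1_style_message(
        thinking="The user asked for the capital of France, which is Paris.",
        final_output="Paris",
    )
    # -> "<think>\nThe user asked for the capital of France, which is Paris.\n</think>\n\nParis"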
@@ -122,7 +151,21 @@ def generate_chat_message_response(
         {"role": "user", "content": training_data.input},
     ]
 
-    if training_data.supports_cot():
+    if training_data.thinking_r1_style:
+        messages.extend(
+            [
+                {
+                    "role": "assistant",
+                    "content": serialize_r1_style_message(
+                        thinking=training_data.thinking,
+                        final_output=training_data.final_output,
+                    ),
+                }
+            ]
+        )
+
+        return {"messages": messages}
+    elif training_data.supports_cot():
         messages.extend(
             [
                 {"role": "user", "content": training_data.thinking_instructions},
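Under the R1-compatible strategy, a record in the OpenAI chat JSONL download therefore carries one combined assistant message instead of the separate thinking-instruction/thinking/final-answer exchange used by supports_cot(). Roughly, with placeholder content (the leading system/user messages follow the prefix built earlier in this function):

    {
        "messages": [
            {"role": "system", "content": "<system message>"},
            {"role": "user", "content": "<task input>"},
            {
                "role": "assistant",
                "content": "<think>\n<reasoning>\n</think>\n\n<final output>",
            },
        ]
    }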
@@ -157,7 +200,21 @@ def generate_json_schema_message(
         {"role": "user", "content": training_data.input},
     ]
 
-    if training_data.supports_cot():
+    if training_data.thinking_r1_style:
+        messages.extend(
+            [
+                {
+                    "role": "assistant",
+                    "content": serialize_r1_style_message(
+                        thinking=training_data.thinking,
+                        final_output=training_data.final_output,
+                    ),
+                }
+            ]
+        )
+
+        return {"messages": messages}
+    elif training_data.supports_cot():
         messages.extend(
             [
                 {"role": "user", "content": training_data.thinking_instructions},
@@ -188,7 +245,11 @@ def generate_chat_message_toolcall(
         {"role": "user", "content": training_data.input},
     ]
 
-    if training_data.supports_cot():
+    if training_data.thinking_r1_style:
+        raise ValueError(
+            "R1 style thinking is not supported for tool call downloads. Please use a different training strategy."
+        )
+    elif training_data.supports_cot():
         messages.extend(
             [
                 {"role": "user", "content": training_data.thinking_instructions},
@@ -231,12 +292,29 @@ def generate_huggingface_chat_template(
         {"role": "user", "content": training_data.input},
     ]
 
+    if training_data.thinking_r1_style:
+        conversations.extend(
+            [
+                {
+                    "role": "assistant",
+                    "content": serialize_r1_style_message(
+                        thinking=training_data.thinking,
+                        final_output=training_data.final_output,
+                    ),
+                }
+            ]
+        )
+        return {"conversations": conversations}
+
     if training_data.supports_cot():
         conversations.extend(
             [
                 {"role": "user", "content": training_data.thinking_instructions},
                 {"role": "assistant", "content": training_data.thinking},
-                {"role": "user", "content": training_data.thinking_final_answer_prompt},
+                {
+                    "role": "user",
+                    "content": training_data.thinking_final_answer_prompt,
+                },
             ]
         )
 
@@ -260,12 +338,19 @@ def generate_huggingface_chat_template_toolcall(
         {"role": "user", "content": training_data.input},
     ]
 
-    if training_data.supports_cot():
+    if training_data.thinking_r1_style:
+        raise ValueError(
+            "R1 style thinking is not supported for tool call downloads. Please use a different training strategy."
+        )
+    elif training_data.supports_cot():
         conversations.extend(
             [
                 {"role": "user", "content": training_data.thinking_instructions},
                 {"role": "assistant", "content": training_data.thinking},
-                {"role": "user", "content": training_data.thinking_final_answer_prompt},
+                {
+                    "role": "user",
+                    "content": training_data.thinking_final_answer_prompt,
+                },
             ]
         )
 
@@ -288,12 +373,20 @@ def generate_huggingface_chat_template_toolcall(
     return {"conversations": conversations}
 
 
-def generate_vertex_gemini_1_5(
+def generate_vertex_gemini(
     training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate Vertex Gemini 1.5 format (flash and pro)"""
     # See https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare
 
+    system_instruction = {
+        "role": "system",
+        "parts": [
+            {
+                "text": training_data.system_message,
+            }
+        ],
+    }
     contents = [
         {
             "role": "user",
@@ -305,7 +398,11 @@ def generate_vertex_gemini_1_5(
         }
     ]
 
-    if training_data.supports_cot():
+    if training_data.thinking_r1_style:
+        raise ValueError(
+            "R1 style thinking is not supported for Vertex Gemini. Please use a different training strategy."
+        )
+    elif training_data.supports_cot():
         contents.extend(
             [
                 {
@@ -328,14 +425,7 @@ def generate_vertex_gemini_1_5(
         )
 
     return {
-        "systemInstruction": {
-            "role": "system",
-            "parts": [
-                {
-                    "text": training_data.system_message,
-                }
-            ],
-        },
+        "systemInstruction": system_instruction,
         "contents": contents,
     }
 
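Hoisting system_instruction into a local changes nothing semantically; the returned record still matches Google's supervised-tuning JSONL layout. A non-COT record comes out roughly like this, with placeholder content (the user/model contents entries are abbreviated in the hunks above, so their exact part structure here is an assumption based on the linked Vertex docs):

    {
        "systemInstruction": {
            "role": "system",
            "parts": [{"text": "<system message>"}],
        },
        "contents": [
            {"role": "user", "parts": [{"text": "<task input>"}]},
            {"role": "model", "parts": [{"text": "<final output>"}]},
        ],
    }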
@@ -346,7 +436,7 @@ FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
     DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: generate_chat_message_toolcall,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: generate_huggingface_chat_template,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: generate_huggingface_chat_template_toolcall,
-    DatasetFormat.VERTEX_GEMINI_1_5: generate_vertex_gemini_1_5,
+    DatasetFormat.VERTEX_GEMINI: generate_vertex_gemini,
 }
 
 
@@ -397,7 +487,7 @@ class DatasetFormatter:
 
         generator = FORMAT_GENERATORS[format_type]
 
-        include_cot = data_strategy == FinetuneDataStrategy.final_and_intermediate
+        include_cot = data_strategy in THINKING_DATA_STRATEGIES
 
         # Write to a temp file if no path is provided
         output_path = (
@@ -421,7 +511,7 @@
                 training_data = build_training_data(
                     task_run=task_run,
                     system_message=self.system_message,
-                    include_cot=include_cot,
+                    data_strategy=data_strategy,
                     thinking_instructions=self.thinking_instructions,
                 )
                 example = generator(training_data)
kiln_ai/adapters/fine_tune/finetune_registry.py
@@ -4,10 +4,12 @@ from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter
 from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
 from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
 from kiln_ai.adapters.fine_tune.together_finetune import TogetherFinetune
+from kiln_ai.adapters.fine_tune.vertex_finetune import VertexFinetune
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 
 finetune_registry: dict[ModelProviderName, Type[BaseFinetuneAdapter]] = {
     ModelProviderName.openai: OpenAIFinetune,
     ModelProviderName.fireworks_ai: FireworksFinetune,
     ModelProviderName.together_ai: TogetherFinetune,
+    ModelProviderName.vertex: VertexFinetune,
 }
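Registering VertexFinetune (backed by the new kiln_ai/adapters/fine_tune/vertex_finetune.py, +217 lines) is all the routing Vertex needs: callers resolve an adapter class by provider and work through the shared BaseFinetuneAdapter interface. A usage sketch; the lookup pattern is implied by the dict type, not shown verbatim in this diff:

    from kiln_ai.adapters.fine_tune.finetune_registry import finetune_registry
    from kiln_ai.adapters.ml_model_list import ModelProviderName

    adapter_cls = finetune_registry[ModelProviderName.vertex]  # -> VertexFinetune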
kiln_ai/adapters/fine_tune/fireworks_finetune.py
@@ -198,7 +198,8 @@ class FireworksFinetune(BaseFinetuneAdapter):
         if not api_key or not account_id:
             raise ValueError("Fireworks API key or account ID not set")
         url = f"https://api.fireworks.ai/v1/accounts/{account_id}/datasets"
-        dataset_id = str(uuid4())
+        # First char can't be a digit: https://discord.com/channels/1137072072808472616/1363214412395184350/1363214412395184350
+        dataset_id = "kiln-" + str(uuid4())
         payload = {
             "datasetId": dataset_id,
             "dataset": {
kiln_ai/adapters/fine_tune/test_base_finetune.py
@@ -98,6 +98,13 @@ def test_validate_parameters_valid():
     }
     MockFinetune.validate_parameters(valid_params)  # Should not raise
 
+    # Test valid parameters (float as int)
+    valid_params = {
+        "learning_rate": 1,
+        "epochs": 10,
+    }
+    MockFinetune.validate_parameters(valid_params)  # Should not raise
+
 
 def test_validate_parameters_missing_required():
     # Test missing required parameter