kiln-ai 0.11.1__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai may be problematic.
Files changed (80)
  1. kiln_ai/adapters/__init__.py +4 -0
  2. kiln_ai/adapters/adapter_registry.py +163 -39
  3. kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
  4. kiln_ai/adapters/eval/__init__.py +28 -0
  5. kiln_ai/adapters/eval/base_eval.py +164 -0
  6. kiln_ai/adapters/eval/eval_runner.py +270 -0
  7. kiln_ai/adapters/eval/g_eval.py +368 -0
  8. kiln_ai/adapters/eval/registry.py +16 -0
  9. kiln_ai/adapters/eval/test_base_eval.py +325 -0
  10. kiln_ai/adapters/eval/test_eval_runner.py +641 -0
  11. kiln_ai/adapters/eval/test_g_eval.py +498 -0
  12. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
  14. kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
  15. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
  16. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
  17. kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
  18. kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
  19. kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
  20. kiln_ai/adapters/ml_model_list.py +758 -163
  21. kiln_ai/adapters/model_adapters/__init__.py +2 -4
  22. kiln_ai/adapters/model_adapters/base_adapter.py +61 -43
  23. kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
  24. kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
  25. kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
  26. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
  27. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
  28. kiln_ai/adapters/model_adapters/test_structured_output.py +59 -35
  29. kiln_ai/adapters/ollama_tools.py +3 -3
  30. kiln_ai/adapters/parsers/r1_parser.py +19 -14
  31. kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
  32. kiln_ai/adapters/prompt_builders.py +80 -42
  33. kiln_ai/adapters/provider_tools.py +50 -58
  34. kiln_ai/adapters/repair/repair_task.py +9 -21
  35. kiln_ai/adapters/repair/test_repair_task.py +6 -6
  36. kiln_ai/adapters/run_output.py +3 -0
  37. kiln_ai/adapters/test_adapter_registry.py +26 -29
  38. kiln_ai/adapters/test_generate_docs.py +4 -4
  39. kiln_ai/adapters/test_ollama_tools.py +0 -1
  40. kiln_ai/adapters/test_prompt_adaptors.py +47 -33
  41. kiln_ai/adapters/test_prompt_builders.py +91 -31
  42. kiln_ai/adapters/test_provider_tools.py +26 -81
  43. kiln_ai/datamodel/__init__.py +50 -952
  44. kiln_ai/datamodel/basemodel.py +2 -0
  45. kiln_ai/datamodel/datamodel_enums.py +60 -0
  46. kiln_ai/datamodel/dataset_filters.py +114 -0
  47. kiln_ai/datamodel/dataset_split.py +170 -0
  48. kiln_ai/datamodel/eval.py +298 -0
  49. kiln_ai/datamodel/finetune.py +105 -0
  50. kiln_ai/datamodel/json_schema.py +7 -1
  51. kiln_ai/datamodel/project.py +23 -0
  52. kiln_ai/datamodel/prompt.py +37 -0
  53. kiln_ai/datamodel/prompt_id.py +83 -0
  54. kiln_ai/datamodel/strict_mode.py +24 -0
  55. kiln_ai/datamodel/task.py +181 -0
  56. kiln_ai/datamodel/task_output.py +328 -0
  57. kiln_ai/datamodel/task_run.py +164 -0
  58. kiln_ai/datamodel/test_basemodel.py +19 -11
  59. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  60. kiln_ai/datamodel/test_dataset_split.py +32 -8
  61. kiln_ai/datamodel/test_datasource.py +22 -2
  62. kiln_ai/datamodel/test_eval_model.py +635 -0
  63. kiln_ai/datamodel/test_example_models.py +9 -13
  64. kiln_ai/datamodel/test_json_schema.py +23 -0
  65. kiln_ai/datamodel/test_models.py +2 -2
  66. kiln_ai/datamodel/test_prompt_id.py +129 -0
  67. kiln_ai/datamodel/test_task.py +159 -0
  68. kiln_ai/utils/config.py +43 -1
  69. kiln_ai/utils/dataset_import.py +232 -0
  70. kiln_ai/utils/test_dataset_import.py +596 -0
  71. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +86 -6
  72. kiln_ai-0.13.0.dist-info/RECORD +103 -0
  73. kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -302
  74. kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -11
  75. kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -246
  76. kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -350
  77. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -225
  78. kiln_ai-0.11.1.dist-info/RECORD +0 -76
  79. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
  80. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,531 @@
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+import together
+from together.types.finetune import FinetuneJobStatus as TogetherFinetuneJobStatus
+
+from kiln_ai.adapters.fine_tune.base_finetune import (
+    FineTuneParameter,
+    FineTuneStatusType,
+)
+from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
+from kiln_ai.adapters.fine_tune.together_finetune import (
+    TogetherFinetune,
+    _completed_statuses,
+    _failed_statuses,
+    _pending_statuses,
+    _running_statuses,
+)
+from kiln_ai.datamodel import (
+    DatasetSplit,
+    StructuredOutputMode,
+    Task,
+)
+from kiln_ai.datamodel import Finetune as FinetuneModel
+from kiln_ai.datamodel.dataset_split import Train80Test20SplitDefinition
+from kiln_ai.utils.config import Config
+
+
+def test_together_status_categorization():
+    """
+    Test that all statuses from TogetherFinetuneJobStatus are included in exactly one
+    of the status categorization arrays.
+    """
+    # Collect all status values from the TogetherFinetuneJobStatus class
+    all_statuses = list(TogetherFinetuneJobStatus)
+
+    # Collect all statuses from the categorization arrays
+    categorized_statuses = set()
+    categorized_statuses.update(_pending_statuses)
+    categorized_statuses.update(_running_statuses)
+    categorized_statuses.update(_completed_statuses)
+    categorized_statuses.update(_failed_statuses)
+
+    # Check if any status is missing from categorization
+    missing_statuses = set(all_statuses) - categorized_statuses
+    assert not missing_statuses, (
+        f"These statuses are not categorized: {missing_statuses}"
+    )
+
+    # Check if any status appears in multiple categories
+    all_categorization_lists = [
+        _pending_statuses,
+        _running_statuses,
+        _completed_statuses,
+        _failed_statuses,
+    ]
+
+    for status in all_statuses:
+        appearances = sum(status in category for category in all_categorization_lists)
+        assert appearances == 1, (
+            f"Status '{status}' appears in {appearances} categories (should be exactly 1)"
+        )
+
+
+@pytest.fixture
+def finetune(tmp_path):
+    tmp_file = tmp_path / "test-finetune.kiln"
+    datamodel = FinetuneModel(
+        name="test-finetune",
+        provider="together",
+        provider_id="together-123",
+        base_model_id="llama-v2-7b",
+        train_split_name="train",
+        dataset_split_id="dataset-123",
+        system_message="Test system message",
+        path=tmp_file,
+    )
+    return datamodel
+
+
+@pytest.fixture
+def together_finetune(finetune, mock_together_client, mock_api_key):
+    finetune = TogetherFinetune(datamodel=finetune)
+    return finetune
+
+
+@pytest.fixture
+def mock_together_client():
+    with patch(
+        "kiln_ai.adapters.fine_tune.together_finetune.Together"
+    ) as mock_together:
+        mock_client = MagicMock()
+        mock_together.return_value = mock_client
+        yield mock_client
+
+
+@pytest.fixture
+def mock_api_key():
+    with patch.object(Config, "shared") as mock_config:
+        mock_config.return_value.together_api_key = "test-api-key"
+        yield
+
+
+def test_init_missing_api_key(finetune):
+    with patch.object(Config, "shared") as mock_config:
+        mock_config.return_value.together_api_key = None
+        with pytest.raises(ValueError, match="Together.ai API key not set"):
+            TogetherFinetune(datamodel=finetune)
+
+
+def test_init_success(mock_api_key, mock_together_client, finetune):
+    # no exception should be raised
+    TogetherFinetune(datamodel=finetune)
+
+
+def test_available_parameters():
+    parameters = TogetherFinetune.available_parameters()
+    assert len(parameters) == 11
+    assert all(isinstance(p, FineTuneParameter) for p in parameters)
+
+    # Check specific parameters
+    param_names = [p.name for p in parameters]
+    assert "epochs" in param_names
+    assert "learning_rate" in param_names
+    assert "batch_size" in param_names
+    assert "num_checkpoints" in param_names
+    assert "min_lr_ratio" in param_names
+    assert "warmup_ratio" in param_names
+    assert "max_grad_norm" in param_names
+    assert "weight_decay" in param_names
+    assert "lora_rank" in param_names
+    assert "lora_dropout" in param_names
+    assert "lora_alpha" in param_names
+
+
+async def test_status_missing_provider_id(together_finetune, mock_api_key):
+    together_finetune.datamodel.provider_id = None
+
+    status = await together_finetune.status()
+    assert status.status == FineTuneStatusType.unknown
+    assert "Fine-tuning job ID not set" in status.message
+
+
+@pytest.mark.parametrize(
+    "together_status,expected_status,expected_message",
+    [
+        (
+            TogetherFinetuneJobStatus.STATUS_PENDING,
+            FineTuneStatusType.pending,
+            f"Fine-tuning job is pending [{TogetherFinetuneJobStatus.STATUS_PENDING}]",
+        ),
+        (
+            TogetherFinetuneJobStatus.STATUS_RUNNING,
+            FineTuneStatusType.running,
+            f"Fine-tuning job is running [{TogetherFinetuneJobStatus.STATUS_RUNNING}]",
+        ),
+        (
+            TogetherFinetuneJobStatus.STATUS_COMPLETED,
+            FineTuneStatusType.completed,
+            "Fine-tuning job completed",
+        ),
+        (
+            TogetherFinetuneJobStatus.STATUS_ERROR,
+            FineTuneStatusType.failed,
+            f"Fine-tuning job failed [{TogetherFinetuneJobStatus.STATUS_ERROR}]",
+        ),
+        (
+            "UNKNOWN_STATUS",
+            FineTuneStatusType.unknown,
+            "Unknown fine-tuning job status [UNKNOWN_STATUS]",
+        ),
+    ],
+)
+async def test_status_job_states(
+    mock_together_client,
+    together_finetune,
+    together_status,
+    expected_status,
+    expected_message,
+    mock_api_key,
+):
+    # Mock the retrieve method of the fine_tuning object
+    mock_job = MagicMock()
+    mock_job.status = together_status
+    mock_together_client.fine_tuning.retrieve.return_value = mock_job
+
+    status = await together_finetune.status()
+    assert status.status == expected_status
+
+    # Check that the status was updated in the datamodel
+    assert together_finetune.datamodel.latest_status == expected_status
+    assert status.status == expected_status
+    assert expected_message == status.message
+
+    # Verify the fine_tuning.retrieve method was called
+    mock_together_client.fine_tuning.retrieve.assert_called_once_with(
+        id=together_finetune.datamodel.provider_id
+    )
+
+
+async def test_status_exception(together_finetune, mock_together_client, mock_api_key):
+    # Mock the retrieve method to raise an exception
+    mock_together_client.fine_tuning.retrieve.side_effect = Exception("API error")
+
+    status = await together_finetune.status()
+    assert status.status == FineTuneStatusType.unknown
+    assert "Error retrieving fine-tuning job status: API error" == status.message
+
+
+@pytest.fixture
+def mock_dataset():
+    return DatasetSplit(
+        id="test-dataset-123",
+        name="Test Dataset",
+        splits=Train80Test20SplitDefinition,
+        split_contents={"train": [], "test": []},
+    )
+
+
+@pytest.fixture
+def mock_task():
+    return Task(
+        id="test-task-123",
+        name="Test Task",
+        output_json_schema=None,  # Can be modified in specific tests
+        instruction="Test instruction",
+    )
+
+
+async def test_generate_and_upload_jsonl_success(
+    together_finetune, mock_dataset, mock_task, mock_together_client, mock_api_key
+):
+    # Mock the formatter
+    mock_formatter = MagicMock(spec=DatasetFormatter)
+    mock_path = Path("mock_path.jsonl")
+    mock_formatter.dump_to_file.return_value = mock_path
+
+    # Mock the files.upload response
+    mock_file = MagicMock()
+    mock_file.id = "file-123"
+    mock_together_client.files.upload.return_value = mock_file
+
+    with patch(
+        "kiln_ai.adapters.fine_tune.together_finetune.DatasetFormatter",
+        return_value=mock_formatter,
+    ):
+        result = await together_finetune.generate_and_upload_jsonl(
+            mock_dataset, "train", mock_task, DatasetFormat.OPENAI_CHAT_JSONL
+        )
+
+    # Check the formatter was created with correct parameters
+    assert mock_formatter.dump_to_file.call_count == 1
+
+    # Check the file was uploaded
+    mock_together_client.files.upload.assert_called_once_with(
+        file=mock_path,
+        purpose=together.types.files.FilePurpose.FineTune,
+        check=True,
+    )
+
+    # Check the result is the file ID
+    assert result == "file-123"
+
+
+async def test_generate_and_upload_jsonl_error(
+    together_finetune, mock_dataset, mock_task, mock_together_client, mock_api_key
+):
+    # Mock the formatter
+    mock_formatter = MagicMock(spec=DatasetFormatter)
+    mock_path = Path("mock_path.jsonl")
+    mock_formatter.dump_to_file.return_value = mock_path
+
+    # Mock the files.upload to raise an exception
+    mock_together_client.files.upload.side_effect = Exception("Upload failed")
+
+    with (
+        patch(
+            "kiln_ai.adapters.fine_tune.together_finetune.DatasetFormatter",
+            return_value=mock_formatter,
+        ),
+        pytest.raises(ValueError, match="Failed to upload dataset: Upload failed"),
+    ):
+        await together_finetune.generate_and_upload_jsonl(
+            mock_dataset, "train", mock_task, DatasetFormat.OPENAI_CHAT_JSONL
+        )
+
+
+@pytest.mark.parametrize(
+    "output_schema,expected_mode,expected_format,validation_file",
+    [
+        (
+            '{"type": "object", "properties": {"key": {"type": "string"}}}',
+            StructuredOutputMode.json_custom_instructions,
+            DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL,
+            True,
+        ),
+        (None, None, DatasetFormat.OPENAI_CHAT_JSONL, False),
+    ],
+)
+async def test_start_success(
+    together_finetune,
+    mock_dataset,
+    mock_task,
+    mock_together_client,
+    mock_api_key,
+    output_schema,
+    expected_mode,
+    expected_format,
+    validation_file,
+):
+    # Set output schema on task
+    mock_task.output_json_schema = output_schema
+
+    # Set parent task on finetune
+    together_finetune.datamodel.parent = mock_task
+
+    # Mock file ID from generate_and_upload_jsonl
+    mock_file_id = "file-123"
+
+    # Mock fine-tuning job response
+    mock_job = MagicMock()
+    mock_job.id = "job-123"
+    mock_job.output_name = "model-123"
+    mock_together_client.fine_tuning.create.return_value = mock_job
+
+    with patch.object(
+        together_finetune,
+        "generate_and_upload_jsonl",
+        AsyncMock(return_value=mock_file_id),
+    ):
+        if validation_file:
+            together_finetune.datamodel.validation_split_name = "validation"
+
+        await together_finetune._start(mock_dataset)
+
+        # Check that generate_and_upload_jsonl was called with correct parameters
+        first_call = together_finetune.generate_and_upload_jsonl.call_args_list[0].args
+        assert first_call[0] == mock_dataset
+        assert first_call[1] == together_finetune.datamodel.train_split_name
+        assert first_call[2] == mock_task
+        assert first_call[3] == expected_format
+        if validation_file:
+            second_call = together_finetune.generate_and_upload_jsonl.call_args_list[
+                1
+            ].args
+            assert second_call[0] == mock_dataset
+            assert second_call[1] == "validation"
+            assert second_call[2] == mock_task
+            assert second_call[3] == expected_format
+
+    # Check that fine_tuning.create was called with correct parameters
+    mock_together_client.fine_tuning.create.assert_called_once_with(
+        training_file=mock_file_id,
+        validation_file=mock_file_id if validation_file else "",
+        model=together_finetune.datamodel.base_model_id,
+        lora=True,
+        suffix=f"kiln_ai_{together_finetune.datamodel.id}"[:40],
+    )
+
+    # Check that datamodel was updated correctly
+    assert together_finetune.datamodel.provider_id == "job-123"
+    assert together_finetune.datamodel.fine_tune_model_id == "model-123"
+    assert together_finetune.datamodel.structured_output_mode == expected_mode
+
+
+async def test_start_missing_task(together_finetune, mock_dataset, mock_api_key):
+    # Don't set parent task
+    together_finetune.datamodel.parent = None
+    together_finetune.datamodel.save_to_file()
+
+    with pytest.raises(ValueError, match="Task is required to start a fine-tune"):
+        await together_finetune._start(mock_dataset)
+
+
+async def test_deploy_always_succeeds(together_finetune, mock_api_key):
+    # Together automatically deploys, so _deploy should always return True
+    result = await together_finetune._deploy()
+    assert result is True
+
+
+def test_augment_system_message(mock_task):
+    system_message = "You are a helpful assistant."
+
+    # Plaintext == no change
+    augmented_system_message = TogetherFinetune.augment_system_message(
+        system_message, mock_task
+    )
+    assert augmented_system_message == "You are a helpful assistant."
+
+    # Now with JSON == append JSON instructions
+    mock_task.output_json_schema = (
+        '{"type": "object", "properties": {"key": {"type": "string"}}}'
+    )
+    augmented_system_message = TogetherFinetune.augment_system_message(
+        system_message, mock_task
+    )
+    assert (
+        augmented_system_message
+        == "You are a helpful assistant.\n\nReturn only JSON. Do not include any non JSON text.\n"
+    )
+
+
+@pytest.mark.parametrize(
+    "parameter_name,input_value,expected_key,expected_value,should_exist",
+    [
+        # learning_rate tests
+        ("learning_rate", 0.001, "learning_rate", 0.001, True),
+        ("learning_rate", "not_a_float", "learning_rate", None, False),
+        # epochs tests
+        ("epochs", 5, "n_epochs", 5, True),
+        ("epochs", "not_an_int", "n_epochs", None, False),
+        # num_checkpoints tests
+        ("num_checkpoints", 3, "n_checkpoints", 3, True),
+        ("num_checkpoints", "not_an_int", "n_checkpoints", None, False),
+        # batch_size tests
+        ("batch_size", 32, "batch_size", 32, True),
+        ("batch_size", "not_an_int", "batch_size", None, False),
+        # min_lr_ratio tests
+        ("min_lr_ratio", 0.1, "min_lr_ratio", 0.1, True),
+        ("min_lr_ratio", "not_a_float", "min_lr_ratio", None, False),
+        # warmup_ratio tests
+        ("warmup_ratio", 0.2, "warmup_ratio", 0.2, True),
+        ("warmup_ratio", "not_a_float", "warmup_ratio", None, False),
+        # max_grad_norm tests
+        ("max_grad_norm", 5.0, "max_grad_norm", 5.0, True),
+        ("max_grad_norm", "not_a_float", "max_grad_norm", None, False),
+        # weight_decay tests
+        ("weight_decay", 0.01, "weight_decay", 0.01, True),
+        ("weight_decay", "not_a_float", "weight_decay", None, False),
+        # lora_rank tests
+        ("lora_rank", 16, "lora_r", 16, True),
+        ("lora_rank", "not_an_int", "lora_r", None, False),
+        # lora_dropout tests
+        ("lora_dropout", 0.1, "lora_dropout", 0.1, True),
+        ("lora_dropout", "not_a_float", "lora_dropout", None, False),
+        # lora_alpha tests
+        ("lora_alpha", 32.0, "lora_alpha", 32.0, True),
+        ("lora_alpha", "not_a_float", "lora_alpha", None, False),
+    ],
+)
+def test_build_finetune_parameters(
+    together_finetune,
+    parameter_name,
+    input_value,
+    expected_key,
+    expected_value,
+    should_exist,
+):
+    """Test that _build_finetune_parameters correctly handles different parameters."""
+    # Set the parameter
+    together_finetune.datamodel.parameters = {parameter_name: input_value}
+    together_finetune.datamodel.id = "test-display-name"
+
+    # Call the method to build parameters
+    result = together_finetune._build_finetune_parameters()
+
+    # Check that required parameters are always present
+    assert result["lora"] is True
+    assert result["suffix"] == "kiln_ai_test-display-name"
+
+    # Check the specific parameter we're testing
+    if should_exist:
+        assert expected_key in result
+        assert result[expected_key] == expected_value
+    else:
+        assert expected_key not in result
+
+
+@pytest.mark.parametrize(
+    "parameters,expected_params",
+    [
+        # Test default values when parameters are empty
+        (
+            {},
+            {
+                "lora": True,
+                "suffix": "kiln_ai_1234",
+            },
+        ),
+        # Test multiple parameters together
+        (
+            {
+                "epochs": 3,
+                "learning_rate": 0.0005,
+                "batch_size": 16,
+                "num_checkpoints": 2,
+                "min_lr_ratio": 0.1,
+                "warmup_ratio": 0.2,
+                "max_grad_norm": 2.0,
+                "weight_decay": 0.01,
+            },
+            {
+                "lora": True,
+                "n_epochs": 3,
+                "learning_rate": 0.0005,
+                "batch_size": 16,
+                "n_checkpoints": 2,
+                "min_lr_ratio": 0.1,
+                "warmup_ratio": 0.2,
+                "max_grad_norm": 2.0,
+                "weight_decay": 0.01,
+                "suffix": "kiln_ai_1234",
+            },
+        ),
+        # Test mix of valid and invalid parameters
+        (
+            {
+                "epochs": "invalid",
+                "learning_rate": 0.001,
+                "batch_size": "invalid",
+                "num_checkpoints": "invalid",
+            },
+            {
+                "lora": True,
+                "learning_rate": 0.001,
+                "suffix": "kiln_ai_1234",
+            },
+        ),
+    ],
+)
+def test_build_finetune_parameters_combinations(
+    together_finetune, parameters, expected_params
+):
+    """Test combinations of parameters in _build_finetune_parameters."""
+    together_finetune.datamodel.parameters = parameters
+    together_finetune.datamodel.id = "1234"
+    result = together_finetune._build_finetune_parameters()
+
+    # Check that all expected keys are present with the correct values
+    assert result == expected_params