kiln-ai 0.8.1__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +77 -5
- kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +113 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +323 -94
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
- kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +126 -20
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +17 -6
- kiln_ai/adapters/repair/test_repair_task.py +4 -4
- kiln_ai/adapters/run_output.py +8 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_prompt_adaptors.py +8 -4
- kiln_ai/adapters/test_prompt_builders.py +190 -29
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +193 -12
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/json_schema.py +8 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/test_basemodel.py +81 -2
- kiln_ai/datamodel/test_dataset_split.py +100 -3
- kiln_ai/datamodel/test_example_models.py +25 -4
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +129 -0
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
- kiln_ai-0.11.1.dist-info/RECORD +76 -0
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -11,17 +11,21 @@ from kiln_ai.adapters.ollama_tools import OllamaConnection
|
|
|
11
11
|
from kiln_ai.adapters.provider_tools import (
|
|
12
12
|
builtin_model_from,
|
|
13
13
|
check_provider_warnings,
|
|
14
|
+
core_provider,
|
|
14
15
|
finetune_cache,
|
|
16
|
+
finetune_from_id,
|
|
15
17
|
finetune_provider_model,
|
|
16
18
|
get_model_and_provider,
|
|
17
19
|
kiln_model_provider_from,
|
|
20
|
+
openai_compatible_config,
|
|
18
21
|
openai_compatible_provider_model,
|
|
22
|
+
parse_custom_model_id,
|
|
19
23
|
provider_enabled,
|
|
20
24
|
provider_name_from_id,
|
|
21
25
|
provider_options_for_custom_model,
|
|
22
26
|
provider_warnings,
|
|
23
27
|
)
|
|
24
|
-
from kiln_ai.datamodel import Finetune, Task
|
|
28
|
+
from kiln_ai.datamodel import Finetune, StructuredOutputMode, Task
|
|
25
29
|
|
|
26
30
|
|
|
27
31
|
@pytest.fixture(autouse=True)
|
|
@@ -61,6 +65,7 @@ def mock_finetune():
|
|
|
61
65
|
finetune = Mock(spec=Finetune)
|
|
62
66
|
finetune.provider = ModelProviderName.openai
|
|
63
67
|
finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
|
|
68
|
+
finetune.structured_output_mode = StructuredOutputMode.json_schema
|
|
64
69
|
mock.return_value = finetune
|
|
65
70
|
yield mock
|
|
66
71
|
|
|
@@ -215,14 +220,14 @@ def test_get_model_and_provider_valid_model_wrong_provider():
|
|
|
215
220
|
def test_get_model_and_provider_multiple_providers():
|
|
216
221
|
# Test with a model that has multiple providers
|
|
217
222
|
model, provider = get_model_and_provider(
|
|
218
|
-
ModelName.
|
|
223
|
+
ModelName.llama_3_3_70b, ModelProviderName.groq
|
|
219
224
|
)
|
|
220
225
|
|
|
221
226
|
assert model is not None
|
|
222
227
|
assert provider is not None
|
|
223
|
-
assert model.name == ModelName.
|
|
228
|
+
assert model.name == ModelName.llama_3_3_70b
|
|
224
229
|
assert provider.name == ModelProviderName.groq
|
|
225
|
-
assert provider.provider_options["model"] == "llama-3.
|
|
230
|
+
assert provider.provider_options["model"] == "llama-3.3-70b-versatile"
|
|
226
231
|
|
|
227
232
|
|
|
228
233
|
@pytest.mark.asyncio
|
|
@@ -313,7 +318,7 @@ async def test_kiln_model_provider_from_custom_model_valid(mock_config):
|
|
|
313
318
|
# Mock config to pass provider warnings check
|
|
314
319
|
mock_config.return_value = "fake-api-key"
|
|
315
320
|
|
|
316
|
-
provider =
|
|
321
|
+
provider = kiln_model_provider_from("custom_model", ModelProviderName.openai)
|
|
317
322
|
|
|
318
323
|
assert provider.name == ModelProviderName.openai
|
|
319
324
|
assert provider.supports_structured_output is False
|
|
@@ -380,7 +385,7 @@ async def test_kiln_model_provider_from_custom_registry(mock_config):
|
|
|
380
385
|
mock_config.return_value = "fake-api-key"
|
|
381
386
|
|
|
382
387
|
# Test with a custom registry model ID in format "provider::model_name"
|
|
383
|
-
provider =
|
|
388
|
+
provider = kiln_model_provider_from(
|
|
384
389
|
"openai::gpt-4-turbo", ModelProviderName.kiln_custom_registry
|
|
385
390
|
)
|
|
386
391
|
|
|
@@ -394,7 +399,7 @@ async def test_kiln_model_provider_from_custom_registry(mock_config):
|
|
|
394
399
|
@pytest.mark.asyncio
|
|
395
400
|
async def test_builtin_model_from_invalid_model():
|
|
396
401
|
"""Test that an invalid model name returns None"""
|
|
397
|
-
result =
|
|
402
|
+
result = builtin_model_from("non_existent_model")
|
|
398
403
|
assert result is None
|
|
399
404
|
|
|
400
405
|
|
|
@@ -403,7 +408,7 @@ async def test_builtin_model_from_valid_model_default_provider(mock_config):
|
|
|
403
408
|
"""Test getting a valid model with default provider"""
|
|
404
409
|
mock_config.return_value = "fake-api-key"
|
|
405
410
|
|
|
406
|
-
provider =
|
|
411
|
+
provider = builtin_model_from(ModelName.phi_3_5)
|
|
407
412
|
|
|
408
413
|
assert provider is not None
|
|
409
414
|
assert provider.name == ModelProviderName.ollama
|
|
@@ -415,13 +420,13 @@ async def test_builtin_model_from_valid_model_specific_provider(mock_config):
|
|
|
415
420
|
"""Test getting a valid model with specific provider"""
|
|
416
421
|
mock_config.return_value = "fake-api-key"
|
|
417
422
|
|
|
418
|
-
provider =
|
|
419
|
-
ModelName.
|
|
423
|
+
provider = builtin_model_from(
|
|
424
|
+
ModelName.llama_3_3_70b, provider_name=ModelProviderName.groq
|
|
420
425
|
)
|
|
421
426
|
|
|
422
427
|
assert provider is not None
|
|
423
428
|
assert provider.name == ModelProviderName.groq
|
|
424
|
-
assert provider.provider_options["model"] == "llama-3.
|
|
429
|
+
assert provider.provider_options["model"] == "llama-3.3-70b-versatile"
|
|
425
430
|
|
|
426
431
|
|
|
427
432
|
@pytest.mark.asyncio
|
|
@@ -429,9 +434,7 @@ async def test_builtin_model_from_invalid_provider(mock_config):
|
|
|
429
434
|
"""Test that requesting an invalid provider returns None"""
|
|
430
435
|
mock_config.return_value = "fake-api-key"
|
|
431
436
|
|
|
432
|
-
provider =
|
|
433
|
-
ModelName.phi_3_5, provider_name="invalid_provider"
|
|
434
|
-
)
|
|
437
|
+
provider = builtin_model_from(ModelName.phi_3_5, provider_name="invalid_provider")
|
|
435
438
|
|
|
436
439
|
assert provider is None
|
|
437
440
|
|
|
@@ -462,7 +465,7 @@ async def test_builtin_model_from_provider_warning_check(mock_config):
|
|
|
462
465
|
mock_config.return_value = None
|
|
463
466
|
|
|
464
467
|
with pytest.raises(ValueError) as exc_info:
|
|
465
|
-
await builtin_model_from(ModelName.
|
|
468
|
+
await builtin_model_from(ModelName.llama_3_3_70b, ModelProviderName.groq)
|
|
466
469
|
|
|
467
470
|
assert provider_warnings[ModelProviderName.groq].message in str(exc_info.value)
|
|
468
471
|
|
|
@@ -475,10 +478,7 @@ def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune)
|
|
|
475
478
|
|
|
476
479
|
assert provider.name == ModelProviderName.openai
|
|
477
480
|
assert provider.provider_options == {"model": "ft:gpt-3.5-turbo:custom:model-123"}
|
|
478
|
-
|
|
479
|
-
# Test cache
|
|
480
|
-
cached_provider = finetune_provider_model(model_id)
|
|
481
|
-
assert cached_provider is provider
|
|
481
|
+
assert provider.structured_output_mode == StructuredOutputMode.json_schema
|
|
482
482
|
|
|
483
483
|
|
|
484
484
|
def test_finetune_provider_model_invalid_id():
|
|
@@ -533,22 +533,60 @@ def test_finetune_provider_model_incomplete_finetune(
|
|
|
533
533
|
)
|
|
534
534
|
|
|
535
535
|
|
|
536
|
-
|
|
537
|
-
|
|
536
|
+
@pytest.mark.parametrize(
|
|
537
|
+
"structured_output_mode, provider_name, expected_mode",
|
|
538
|
+
[
|
|
539
|
+
(
|
|
540
|
+
StructuredOutputMode.json_mode,
|
|
541
|
+
ModelProviderName.fireworks_ai,
|
|
542
|
+
StructuredOutputMode.json_mode,
|
|
543
|
+
),
|
|
544
|
+
(
|
|
545
|
+
StructuredOutputMode.json_schema,
|
|
546
|
+
ModelProviderName.openai,
|
|
547
|
+
StructuredOutputMode.json_schema,
|
|
548
|
+
),
|
|
549
|
+
(
|
|
550
|
+
StructuredOutputMode.function_calling,
|
|
551
|
+
ModelProviderName.openai,
|
|
552
|
+
StructuredOutputMode.function_calling,
|
|
553
|
+
),
|
|
554
|
+
(None, ModelProviderName.fireworks_ai, StructuredOutputMode.json_mode),
|
|
555
|
+
(None, ModelProviderName.openai, StructuredOutputMode.json_schema),
|
|
556
|
+
],
|
|
557
|
+
)
|
|
558
|
+
def test_finetune_provider_model_structured_mode(
|
|
559
|
+
mock_project,
|
|
560
|
+
mock_task,
|
|
561
|
+
mock_finetune,
|
|
562
|
+
structured_output_mode,
|
|
563
|
+
provider_name,
|
|
564
|
+
expected_mode,
|
|
538
565
|
):
|
|
539
|
-
"""Test creation of
|
|
566
|
+
"""Test creation of provider with different structured output modes"""
|
|
540
567
|
finetune = Mock(spec=Finetune)
|
|
541
|
-
finetune.provider =
|
|
568
|
+
finetune.provider = provider_name
|
|
542
569
|
finetune.fine_tune_model_id = "fireworks-model-123"
|
|
570
|
+
finetune.structured_output_mode = structured_output_mode
|
|
543
571
|
mock_finetune.return_value = finetune
|
|
544
572
|
|
|
545
573
|
provider = finetune_provider_model("project-123::task-456::finetune-789")
|
|
546
574
|
|
|
547
|
-
assert provider.name ==
|
|
575
|
+
assert provider.name == provider_name
|
|
548
576
|
assert provider.provider_options == {"model": "fireworks-model-123"}
|
|
549
|
-
assert provider.
|
|
550
|
-
|
|
551
|
-
|
|
577
|
+
assert provider.structured_output_mode == expected_mode
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
def test_openai_compatible_provider_config(mock_shared_config):
|
|
581
|
+
"""Test successful creation of an OpenAI compatible provider"""
|
|
582
|
+
model_id = "test_provider::gpt-4"
|
|
583
|
+
|
|
584
|
+
config = openai_compatible_config(model_id)
|
|
585
|
+
|
|
586
|
+
assert config.provider_name == ModelProviderName.openai_compatible
|
|
587
|
+
assert config.model_name == "gpt-4"
|
|
588
|
+
assert config.api_key == "test-key"
|
|
589
|
+
assert config.base_url == "https://api.test.com"
|
|
552
590
|
|
|
553
591
|
|
|
554
592
|
def test_openai_compatible_provider_model_success(mock_shared_config):
|
|
@@ -559,57 +597,53 @@ def test_openai_compatible_provider_model_success(mock_shared_config):
|
|
|
559
597
|
|
|
560
598
|
assert provider.name == ModelProviderName.openai_compatible
|
|
561
599
|
assert provider.provider_options == {
|
|
562
|
-
"model":
|
|
563
|
-
"api_key": "test-key",
|
|
564
|
-
"openai_api_base": "https://api.test.com",
|
|
600
|
+
"model": model_id,
|
|
565
601
|
}
|
|
566
602
|
assert provider.supports_structured_output is False
|
|
567
603
|
assert provider.supports_data_gen is False
|
|
568
604
|
assert provider.untested_model is True
|
|
569
605
|
|
|
570
606
|
|
|
571
|
-
def
|
|
607
|
+
def test_openai_compatible_config_no_api_key(mock_shared_config):
|
|
572
608
|
"""Test provider creation without API key (should work as some providers don't require it)"""
|
|
573
609
|
model_id = "no_key_provider::gpt-4"
|
|
574
610
|
|
|
575
|
-
|
|
611
|
+
config = openai_compatible_config(model_id)
|
|
576
612
|
|
|
577
|
-
assert
|
|
578
|
-
assert
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
"openai_api_base": "https://api.nokey.com",
|
|
582
|
-
}
|
|
613
|
+
assert config.provider_name == ModelProviderName.openai_compatible
|
|
614
|
+
assert config.model_name == "gpt-4"
|
|
615
|
+
assert config.api_key is None
|
|
616
|
+
assert config.base_url == "https://api.nokey.com"
|
|
583
617
|
|
|
584
618
|
|
|
585
|
-
def
|
|
619
|
+
def test_openai_compatible_config_invalid_id():
|
|
586
620
|
"""Test handling of invalid model ID format"""
|
|
587
621
|
with pytest.raises(ValueError) as exc_info:
|
|
588
|
-
|
|
622
|
+
openai_compatible_config("invalid-id-format")
|
|
589
623
|
assert (
|
|
590
624
|
str(exc_info.value) == "Invalid openai compatible model ID: invalid-id-format"
|
|
591
625
|
)
|
|
592
626
|
|
|
593
627
|
|
|
594
|
-
def
|
|
628
|
+
def test_openai_compatible_config_no_providers(mock_shared_config):
|
|
595
629
|
"""Test handling when no providers are configured"""
|
|
596
630
|
mock_shared_config.return_value.openai_compatible_providers = None
|
|
597
631
|
|
|
598
632
|
with pytest.raises(ValueError) as exc_info:
|
|
599
|
-
|
|
633
|
+
openai_compatible_config("test_provider::gpt-4")
|
|
600
634
|
assert str(exc_info.value) == "OpenAI compatible provider test_provider not found"
|
|
601
635
|
|
|
602
636
|
|
|
603
|
-
def
|
|
637
|
+
def test_openai_compatible_config_provider_not_found(mock_shared_config):
|
|
604
638
|
"""Test handling of non-existent provider"""
|
|
605
639
|
with pytest.raises(ValueError) as exc_info:
|
|
606
|
-
|
|
640
|
+
openai_compatible_config("unknown_provider::gpt-4")
|
|
607
641
|
assert (
|
|
608
642
|
str(exc_info.value) == "OpenAI compatible provider unknown_provider not found"
|
|
609
643
|
)
|
|
610
644
|
|
|
611
645
|
|
|
612
|
-
def
|
|
646
|
+
def test_openai_compatible_config_no_base_url(mock_shared_config):
|
|
613
647
|
"""Test handling of provider without base URL"""
|
|
614
648
|
mock_shared_config.return_value.openai_compatible_providers = [
|
|
615
649
|
{
|
|
@@ -619,8 +653,196 @@ def test_openai_compatible_provider_model_no_base_url(mock_shared_config):
|
|
|
619
653
|
]
|
|
620
654
|
|
|
621
655
|
with pytest.raises(ValueError) as exc_info:
|
|
622
|
-
|
|
656
|
+
openai_compatible_config("test_provider::gpt-4")
|
|
623
657
|
assert (
|
|
624
658
|
str(exc_info.value)
|
|
625
659
|
== "OpenAI compatible provider test_provider has no base URL"
|
|
626
660
|
)
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
def test_parse_custom_model_id_valid():
|
|
664
|
+
"""Test parsing a valid custom model ID"""
|
|
665
|
+
provider_name, model_name = parse_custom_model_id(
|
|
666
|
+
"openai::gpt-4-turbo-elite-enterprise-editon"
|
|
667
|
+
)
|
|
668
|
+
assert provider_name == ModelProviderName.openai
|
|
669
|
+
assert model_name == "gpt-4-turbo-elite-enterprise-editon"
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
def test_parse_custom_model_id_no_separator():
|
|
673
|
+
"""Test parsing an invalid model ID without separator"""
|
|
674
|
+
with pytest.raises(ValueError) as exc_info:
|
|
675
|
+
parse_custom_model_id("invalid-model-id")
|
|
676
|
+
assert str(exc_info.value) == "Invalid custom model ID: invalid-model-id"
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
def test_parse_custom_model_id_invalid_provider():
|
|
680
|
+
"""Test parsing model ID with invalid provider"""
|
|
681
|
+
with pytest.raises(ValueError) as exc_info:
|
|
682
|
+
parse_custom_model_id("invalid_provider::model")
|
|
683
|
+
assert str(exc_info.value) == "Invalid provider name: invalid_provider"
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
def test_parse_custom_model_id_empty_parts():
|
|
687
|
+
"""Test parsing model ID with empty provider or model name"""
|
|
688
|
+
with pytest.raises(ValueError) as exc_info:
|
|
689
|
+
parse_custom_model_id("::model")
|
|
690
|
+
assert str(exc_info.value) == "Invalid provider name: "
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def test_core_provider_basic_provider():
|
|
694
|
+
"""Test core_provider with a basic provider that doesn't need mapping"""
|
|
695
|
+
result = core_provider("gpt-4", ModelProviderName.openai)
|
|
696
|
+
assert result == ModelProviderName.openai
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
def test_core_provider_custom_registry():
|
|
700
|
+
"""Test core_provider with custom registry provider"""
|
|
701
|
+
result = core_provider("openai::gpt-4", ModelProviderName.kiln_custom_registry)
|
|
702
|
+
assert result == ModelProviderName.openai
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def test_core_provider_finetune():
|
|
706
|
+
"""Test core_provider with fine-tune provider"""
|
|
707
|
+
model_id = "project-123::task-456::finetune-789"
|
|
708
|
+
|
|
709
|
+
with patch(
|
|
710
|
+
"kiln_ai.adapters.provider_tools.finetune_from_id"
|
|
711
|
+
) as mock_finetune_from_id:
|
|
712
|
+
# Mock the finetune object
|
|
713
|
+
finetune = Mock(spec=Finetune)
|
|
714
|
+
finetune.provider = ModelProviderName.openai
|
|
715
|
+
mock_finetune_from_id.return_value = finetune
|
|
716
|
+
|
|
717
|
+
result = core_provider(model_id, ModelProviderName.kiln_fine_tune)
|
|
718
|
+
assert result == ModelProviderName.openai
|
|
719
|
+
mock_finetune_from_id.assert_called_once_with(model_id)
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def test_core_provider_finetune_invalid_provider():
|
|
723
|
+
"""Test core_provider with fine-tune having invalid provider"""
|
|
724
|
+
model_id = "project-123::task-456::finetune-789"
|
|
725
|
+
|
|
726
|
+
with patch(
|
|
727
|
+
"kiln_ai.adapters.provider_tools.finetune_from_id"
|
|
728
|
+
) as mock_finetune_from_id:
|
|
729
|
+
# Mock finetune with invalid provider
|
|
730
|
+
finetune = Mock(spec=Finetune)
|
|
731
|
+
finetune.provider = "invalid_provider"
|
|
732
|
+
mock_finetune_from_id.return_value = finetune
|
|
733
|
+
|
|
734
|
+
with pytest.raises(ValueError) as exc_info:
|
|
735
|
+
core_provider(model_id, ModelProviderName.kiln_fine_tune)
|
|
736
|
+
assert (
|
|
737
|
+
str(exc_info.value)
|
|
738
|
+
== f"Finetune {model_id} has no underlying provider invalid_provider"
|
|
739
|
+
)
|
|
740
|
+
mock_finetune_from_id.assert_called_once_with(model_id)
|
|
741
|
+
|
|
742
|
+
|
|
743
|
+
def test_finetune_from_id_success(mock_project, mock_task, mock_finetune):
|
|
744
|
+
"""Test successful retrieval of a finetune model"""
|
|
745
|
+
model_id = "project-123::task-456::finetune-789"
|
|
746
|
+
|
|
747
|
+
# First call should hit the database
|
|
748
|
+
finetune = finetune_from_id(model_id)
|
|
749
|
+
|
|
750
|
+
assert finetune.provider == ModelProviderName.openai
|
|
751
|
+
assert finetune.fine_tune_model_id == "ft:gpt-3.5-turbo:custom:model-123"
|
|
752
|
+
|
|
753
|
+
# Verify mocks were called correctly
|
|
754
|
+
mock_project.assert_called_once_with("project-123")
|
|
755
|
+
mock_task.assert_called_once_with("task-456", "/fake/path")
|
|
756
|
+
mock_finetune.assert_called_once_with("finetune-789", "/fake/path/task")
|
|
757
|
+
|
|
758
|
+
# Second call should use cache
|
|
759
|
+
cached_finetune = finetune_from_id(model_id)
|
|
760
|
+
assert cached_finetune is finetune
|
|
761
|
+
|
|
762
|
+
# Verify no additional disk calls were made
|
|
763
|
+
mock_project.assert_called_once()
|
|
764
|
+
mock_task.assert_called_once()
|
|
765
|
+
mock_finetune.assert_called_once()
|
|
766
|
+
|
|
767
|
+
|
|
768
|
+
def test_finetune_from_id_invalid_id():
|
|
769
|
+
"""Test handling of invalid model ID format"""
|
|
770
|
+
with pytest.raises(ValueError) as exc_info:
|
|
771
|
+
finetune_from_id("invalid-id-format")
|
|
772
|
+
assert str(exc_info.value) == "Invalid fine tune ID: invalid-id-format"
|
|
773
|
+
|
|
774
|
+
|
|
775
|
+
def test_finetune_from_id_project_not_found(mock_project):
|
|
776
|
+
"""Test handling of non-existent project"""
|
|
777
|
+
mock_project.return_value = None
|
|
778
|
+
model_id = "project-123::task-456::finetune-789"
|
|
779
|
+
|
|
780
|
+
with pytest.raises(ValueError) as exc_info:
|
|
781
|
+
finetune_from_id(model_id)
|
|
782
|
+
assert str(exc_info.value) == "Project project-123 not found"
|
|
783
|
+
|
|
784
|
+
# Verify cache was not populated
|
|
785
|
+
assert model_id not in finetune_cache
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
def test_finetune_from_id_task_not_found(mock_project, mock_task):
|
|
789
|
+
"""Test handling of non-existent task"""
|
|
790
|
+
mock_task.return_value = None
|
|
791
|
+
model_id = "project-123::task-456::finetune-789"
|
|
792
|
+
|
|
793
|
+
with pytest.raises(ValueError) as exc_info:
|
|
794
|
+
finetune_from_id(model_id)
|
|
795
|
+
assert str(exc_info.value) == "Task task-456 not found"
|
|
796
|
+
|
|
797
|
+
# Verify cache was not populated
|
|
798
|
+
assert model_id not in finetune_cache
|
|
799
|
+
|
|
800
|
+
|
|
801
|
+
def test_finetune_from_id_finetune_not_found(mock_project, mock_task, mock_finetune):
|
|
802
|
+
"""Test handling of non-existent finetune"""
|
|
803
|
+
mock_finetune.return_value = None
|
|
804
|
+
model_id = "project-123::task-456::finetune-789"
|
|
805
|
+
|
|
806
|
+
with pytest.raises(ValueError) as exc_info:
|
|
807
|
+
finetune_from_id(model_id)
|
|
808
|
+
assert str(exc_info.value) == "Fine tune finetune-789 not found"
|
|
809
|
+
|
|
810
|
+
# Verify cache was not populated
|
|
811
|
+
assert model_id not in finetune_cache
|
|
812
|
+
|
|
813
|
+
|
|
814
|
+
def test_finetune_from_id_incomplete_finetune(mock_project, mock_task, mock_finetune):
|
|
815
|
+
"""Test handling of incomplete finetune"""
|
|
816
|
+
finetune = Mock(spec=Finetune)
|
|
817
|
+
finetune.fine_tune_model_id = None
|
|
818
|
+
mock_finetune.return_value = finetune
|
|
819
|
+
model_id = "project-123::task-456::finetune-789"
|
|
820
|
+
|
|
821
|
+
with pytest.raises(ValueError) as exc_info:
|
|
822
|
+
finetune_from_id(model_id)
|
|
823
|
+
assert (
|
|
824
|
+
str(exc_info.value)
|
|
825
|
+
== "Fine tune finetune-789 not completed. Refresh it's status in the fine-tune tab."
|
|
826
|
+
)
|
|
827
|
+
|
|
828
|
+
# Verify cache was not populated with incomplete finetune
|
|
829
|
+
assert model_id not in finetune_cache
|
|
830
|
+
|
|
831
|
+
|
|
832
|
+
def test_finetune_from_id_cache_hit(mock_project, mock_task, mock_finetune):
|
|
833
|
+
"""Test that cached finetune is returned without database calls"""
|
|
834
|
+
model_id = "project-123::task-456::finetune-789"
|
|
835
|
+
|
|
836
|
+
# Pre-populate cache
|
|
837
|
+
finetune = Mock(spec=Finetune)
|
|
838
|
+
finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
|
|
839
|
+
finetune_cache[model_id] = finetune
|
|
840
|
+
|
|
841
|
+
# Get finetune from cache
|
|
842
|
+
result = finetune_from_id(model_id)
|
|
843
|
+
|
|
844
|
+
assert result == finetune
|
|
845
|
+
# Verify no database calls were made
|
|
846
|
+
mock_project.assert_not_called()
|
|
847
|
+
mock_task.assert_not_called()
|
|
848
|
+
mock_finetune.assert_not_called()
|