kiln-ai 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/adapter_registry.py +2 -0
- kiln_ai/adapters/base_adapter.py +6 -1
- kiln_ai/adapters/langchain_adapters.py +5 -1
- kiln_ai/adapters/ml_model_list.py +43 -12
- kiln_ai/adapters/ollama_tools.py +4 -3
- kiln_ai/adapters/provider_tools.py +63 -2
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/test_langchain_adapter.py +183 -0
- kiln_ai/adapters/test_provider_tools.py +315 -1
- kiln_ai/datamodel/__init__.py +162 -19
- kiln_ai/datamodel/basemodel.py +90 -42
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/test_basemodel.py +138 -3
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +173 -0
- kiln_ai/datamodel/test_output_rating.py +377 -10
- kiln_ai/utils/config.py +33 -10
- kiln_ai/utils/test_config.py +48 -0
- kiln_ai-0.8.0.dist-info/METADATA +237 -0
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.8.0.dist-info}/RECORD +23 -21
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.8.0.dist-info}/WHEEL +1 -1
- kiln_ai-0.7.0.dist-info/METADATA +0 -90
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.8.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/test_provider_tools.py
CHANGED

@@ -1,21 +1,34 @@
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, Mock, patch

 import pytest

 from kiln_ai.adapters.ml_model_list import (
+    KilnModel,
     ModelName,
     ModelProviderName,
 )
 from kiln_ai.adapters.ollama_tools import OllamaConnection
 from kiln_ai.adapters.provider_tools import (
+    builtin_model_from,
     check_provider_warnings,
+    finetune_cache,
+    finetune_provider_model,
     get_model_and_provider,
     kiln_model_provider_from,
+    openai_compatible_provider_model,
     provider_enabled,
     provider_name_from_id,
     provider_options_for_custom_model,
     provider_warnings,
 )
+from kiln_ai.datamodel import Finetune, Task
+
+
+@pytest.fixture(autouse=True)
+def clear_finetune_cache():
+    """Clear the finetune provider model cache before each test"""
+    finetune_cache.clear()
+    yield


 @pytest.fixture
@@ -24,6 +37,53 @@ def mock_config():
         yield mock


+@pytest.fixture
+def mock_project():
+    with patch("kiln_ai.adapters.provider_tools.project_from_id") as mock:
+        project = Mock()
+        project.path = "/fake/path"
+        mock.return_value = project
+        yield mock
+
+
+@pytest.fixture
+def mock_task():
+    with patch("kiln_ai.datamodel.Task.from_id_and_parent_path") as mock:
+        task = Mock(spec=Task)
+        task.path = "/fake/path/task"
+        mock.return_value = task
+        yield mock
+
+
+@pytest.fixture
+def mock_finetune():
+    with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock:
+        finetune = Mock(spec=Finetune)
+        finetune.provider = ModelProviderName.openai
+        finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
+        mock.return_value = finetune
+        yield mock
+
+
+@pytest.fixture
+def mock_shared_config():
+    with patch("kiln_ai.adapters.provider_tools.Config.shared") as mock:
+        config = Mock()
+        config.openai_compatible_providers = [
+            {
+                "name": "test_provider",
+                "base_url": "https://api.test.com",
+                "api_key": "test-key",
+            },
+            {
+                "name": "no_key_provider",
+                "base_url": "https://api.nokey.com",
+            },
+        ]
+        mock.return_value = config
+        yield mock
+
+
 def test_check_provider_warnings_no_warning(mock_config):
     mock_config.return_value = "some_value"

@@ -103,6 +163,8 @@ def test_provider_name_from_id_case_sensitivity():
         (ModelProviderName.ollama, "Ollama"),
         (ModelProviderName.openai, "OpenAI"),
         (ModelProviderName.fireworks_ai, "Fireworks AI"),
+        (ModelProviderName.kiln_fine_tune, "Fine Tuned Models"),
+        (ModelProviderName.kiln_custom_registry, "Custom Models"),
     ],
 )
 def test_provider_name_from_id_parametrized(provider_id, expected_name):
@@ -310,3 +372,255 @@ def test_provider_options_for_custom_model_invalid_enum():
     """Test handling of invalid enum value"""
     with pytest.raises(ValueError):
         provider_options_for_custom_model("model_name", "invalid_enum_value")
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_custom_registry(mock_config):
+    # Mock config to pass provider warnings check
+    mock_config.return_value = "fake-api-key"
+
+    # Test with a custom registry model ID in format "provider::model_name"
+    provider = await kiln_model_provider_from(
+        "openai::gpt-4-turbo", ModelProviderName.kiln_custom_registry
+    )
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.supports_structured_output is False
+    assert provider.supports_data_gen is False
+    assert provider.untested_model is True
+    assert provider.provider_options == {"model": "gpt-4-turbo"}
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_invalid_model():
+    """Test that an invalid model name returns None"""
+    result = await builtin_model_from("non_existent_model")
+    assert result is None
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_valid_model_default_provider(mock_config):
+    """Test getting a valid model with default provider"""
+    mock_config.return_value = "fake-api-key"
+
+    provider = await builtin_model_from(ModelName.phi_3_5)
+
+    assert provider is not None
+    assert provider.name == ModelProviderName.ollama
+    assert provider.provider_options["model"] == "phi3.5"
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_valid_model_specific_provider(mock_config):
+    """Test getting a valid model with specific provider"""
+    mock_config.return_value = "fake-api-key"
+
+    provider = await builtin_model_from(
+        ModelName.llama_3_1_70b, provider_name=ModelProviderName.groq
+    )
+
+    assert provider is not None
+    assert provider.name == ModelProviderName.groq
+    assert provider.provider_options["model"] == "llama-3.1-70b-versatile"
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_invalid_provider(mock_config):
+    """Test that requesting an invalid provider returns None"""
+    mock_config.return_value = "fake-api-key"
+
+    provider = await builtin_model_from(
+        ModelName.phi_3_5, provider_name="invalid_provider"
+    )
+
+    assert provider is None
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_model_no_providers():
+    """Test handling of a model with no providers"""
+    with patch("kiln_ai.adapters.provider_tools.built_in_models") as mock_models:
+        # Create a mock model with no providers
+        mock_model = KilnModel(
+            name=ModelName.phi_3_5,
+            friendly_name="Test Model",
+            providers=[],
+            family="test_family",
+        )
+        mock_models.__iter__.return_value = [mock_model]
+
+        with pytest.raises(ValueError) as exc_info:
+            await builtin_model_from(ModelName.phi_3_5)
+
+        assert str(exc_info.value) == f"Model {ModelName.phi_3_5} has no providers"
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_provider_warning_check(mock_config):
+    """Test that provider warnings are checked"""
+    # Make the config check fail
+    mock_config.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        await builtin_model_from(ModelName.llama_3_1_70b, ModelProviderName.groq)
+
+    assert provider_warnings[ModelProviderName.groq].message in str(exc_info.value)
+
+
+def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune):
+    """Test successful creation of a fine-tuned model provider"""
+    model_id = "project-123::task-456::finetune-789"
+
+    provider = finetune_provider_model(model_id)
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.provider_options == {"model": "ft:gpt-3.5-turbo:custom:model-123"}
+
+    # Test cache
+    cached_provider = finetune_provider_model(model_id)
+    assert cached_provider is provider
+
+
+def test_finetune_provider_model_invalid_id():
+    """Test handling of invalid model ID format"""
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("invalid-id-format")
+    assert str(exc_info.value) == "Invalid fine tune ID: invalid-id-format"
+
+
+def test_finetune_provider_model_project_not_found(mock_project):
+    """Test handling of non-existent project"""
+    mock_project.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert str(exc_info.value) == "Project project-123 not found"
+
+
+def test_finetune_provider_model_task_not_found(mock_project, mock_task):
+    """Test handling of non-existent task"""
+    mock_task.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert str(exc_info.value) == "Task task-456 not found"
+
+
+def test_finetune_provider_model_finetune_not_found(
+    mock_project, mock_task, mock_finetune
+):
+    """Test handling of non-existent fine-tune"""
+    mock_finetune.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert str(exc_info.value) == "Fine tune finetune-789 not found"
+
+
+def test_finetune_provider_model_incomplete_finetune(
+    mock_project, mock_task, mock_finetune
+):
+    """Test handling of incomplete fine-tune"""
+    finetune = Mock(spec=Finetune)
+    finetune.fine_tune_model_id = None
+    mock_finetune.return_value = finetune
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert (
+        str(exc_info.value)
+        == "Fine tune finetune-789 not completed. Refresh it's status in the fine-tune tab."
+    )
+
+
+def test_finetune_provider_model_fireworks_provider(
+    mock_project, mock_task, mock_finetune
+):
+    """Test creation of Fireworks AI provider with specific adapter options"""
+    finetune = Mock(spec=Finetune)
+    finetune.provider = ModelProviderName.fireworks_ai
+    finetune.fine_tune_model_id = "fireworks-model-123"
+    mock_finetune.return_value = finetune
+
+    provider = finetune_provider_model("project-123::task-456::finetune-789")
+
+    assert provider.name == ModelProviderName.fireworks_ai
+    assert provider.provider_options == {"model": "fireworks-model-123"}
+    assert provider.adapter_options == {
+        "langchain": {"with_structured_output_options": {"method": "json_mode"}}
+    }
+
+
+def test_openai_compatible_provider_model_success(mock_shared_config):
+    """Test successful creation of an OpenAI compatible provider"""
+    model_id = "test_provider::gpt-4"
+
+    provider = openai_compatible_provider_model(model_id)
+
+    assert provider.name == ModelProviderName.openai_compatible
+    assert provider.provider_options == {
+        "model": "gpt-4",
+        "api_key": "test-key",
+        "openai_api_base": "https://api.test.com",
+    }
+    assert provider.supports_structured_output is False
+    assert provider.supports_data_gen is False
+    assert provider.untested_model is True
+
+
+def test_openai_compatible_provider_model_no_api_key(mock_shared_config):
+    """Test provider creation without API key (should work as some providers don't require it)"""
+    model_id = "no_key_provider::gpt-4"
+
+    provider = openai_compatible_provider_model(model_id)
+
+    assert provider.name == ModelProviderName.openai_compatible
+    assert provider.provider_options == {
+        "model": "gpt-4",
+        "api_key": None,
+        "openai_api_base": "https://api.nokey.com",
+    }
+
+
+def test_openai_compatible_provider_model_invalid_id():
+    """Test handling of invalid model ID format"""
+    with pytest.raises(ValueError) as exc_info:
+        openai_compatible_provider_model("invalid-id-format")
+    assert (
+        str(exc_info.value) == "Invalid openai compatible model ID: invalid-id-format"
+    )
+
+
+def test_openai_compatible_provider_model_no_providers(mock_shared_config):
+    """Test handling when no providers are configured"""
+    mock_shared_config.return_value.openai_compatible_providers = None
+
+    with pytest.raises(ValueError) as exc_info:
+        openai_compatible_provider_model("test_provider::gpt-4")
+    assert str(exc_info.value) == "OpenAI compatible provider test_provider not found"
+
+
+def test_openai_compatible_provider_model_provider_not_found(mock_shared_config):
+    """Test handling of non-existent provider"""
+    with pytest.raises(ValueError) as exc_info:
+        openai_compatible_provider_model("unknown_provider::gpt-4")
+    assert (
+        str(exc_info.value) == "OpenAI compatible provider unknown_provider not found"
+    )
+
+
+def test_openai_compatible_provider_model_no_base_url(mock_shared_config):
+    """Test handling of provider without base URL"""
+    mock_shared_config.return_value.openai_compatible_providers = [
+        {
+            "name": "test_provider",
+            "api_key": "test-key",
+        }
+    ]
+
+    with pytest.raises(ValueError) as exc_info:
+        openai_compatible_provider_model("test_provider::gpt-4")
+    assert (
+        str(exc_info.value)
+        == "OpenAI compatible provider test_provider has no base URL"
+    )
kiln_ai/datamodel/__init__.py
CHANGED
@@ -1,3 +1,7 @@
+"""
+See our docs for details about our datamodel: https://kiln-ai.github.io/Kiln/kiln_core_docs/kiln_ai.html
+"""
+
 from __future__ import annotations

 import json
@@ -8,7 +12,12 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Type, Union

 import jsonschema
 import jsonschema.exceptions
-from pydantic import
+from pydantic import (
+    BaseModel,
+    Field,
+    ValidationInfo,
+    model_validator,
+)
 from typing_extensions import Self

 from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
@@ -43,9 +52,25 @@ __all__ = [
     "TaskOutputRatingType",
     "TaskRequirement",
     "TaskDeterminism",
+    "strict_mode",
+    "set_strict_mode",
 ]


+# We want to be hard on ourselves for data completeness generated by the Kiln App, but don't want to make it hard for users to use the datamodel/library.
+# Strict mode enables extra validations that we want to enforce in Kiln App (and any other client that wants best practices), but not in the library (unless they opt in)
+_strict_mode: bool = False
+
+
+def strict_mode() -> bool:
+    return _strict_mode
+
+
+def set_strict_mode(value: bool) -> None:
+    global _strict_mode
+    _strict_mode = value
+
+
 class Priority(IntEnum):
     """Defines priority levels for tasks and requirements, where P0 is highest priority."""

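The strict-mode toggle added above is module-level state. A minimal sketch of opting in from library code (the Kiln App enables this itself; values shown are illustrative):

    from kiln_ai.datamodel import set_strict_mode, strict_mode

    set_strict_mode(True)  # enforce the extra validations described in the comments above
    assert strict_mode() is True
    # With strict mode on, newly created TaskRun/TaskOutput objects must set input_source/source
    # (see the validators added later in this diff); data loaded from file is still accepted as-is.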
@@ -60,30 +85,71 @@ class TaskOutputRatingType(str, Enum):
     """Defines the types of rating systems available for task outputs."""

     five_star = "five_star"
+    pass_fail = "pass_fail"
+    pass_fail_critical = "pass_fail_critical"
     custom = "custom"


+class RequirementRating(BaseModel):
+    """Rating for a specific requirement within a task output."""
+
+    value: float = Field(
+        description="The rating value. Interpretation depends on rating type"
+    )
+    type: TaskOutputRatingType = Field(description="The type of rating")
+
+
 class TaskOutputRating(KilnBaseModel):
     """
     A rating for a task output, including an overall rating and ratings for each requirement.

-
+    Supports:
+    - five_star: 1-5 star ratings
+    - pass_fail: boolean pass/fail (1.0 = pass, 0.0 = fail)
+    - pass_fail_critical: tri-state (1.0 = pass, 0.0 = fail, -1.0 = critical fail)
     """

     type: TaskOutputRatingType = Field(default=TaskOutputRatingType.five_star)
     value: float | None = Field(
-        description="The
+        description="The rating value. Interpretation depends on rating type:\n- five_star: 1-5 stars\n- pass_fail: 1.0 (pass) or 0.0 (fail)\n- pass_fail_critical: 1.0 (pass), 0.0 (fail), or -1.0 (critical fail)",
         default=None,
     )
-    requirement_ratings: Dict[ID_TYPE,
+    requirement_ratings: Dict[ID_TYPE, RequirementRating] = Field(
         default={},
-        description="The ratings of the requirements of the task.
+        description="The ratings of the requirements of the task.",
     )

+    # Previously we stored rating values as a dict of floats, but now we store them as RequirementRating objects.
+    @model_validator(mode="before")
+    def upgrade_old_format(cls, data: dict) -> dict:
+        if not isinstance(data, dict):
+            return data
+
+        # Check if we have the old format (dict of floats)
+        req_ratings = data.get("requirement_ratings", {})
+        if req_ratings and all(
+            isinstance(v, (int, float)) for v in req_ratings.values()
+        ):
+            # Convert each float to a RequirementRating object
+            # all ratings are five star at the point we used this format
+            data["requirement_ratings"] = {
+                k: {"value": v, "type": TaskOutputRatingType.five_star}
+                for k, v in req_ratings.items()
+            }
+
+        return data
+
     # Used to select high quality outputs for example selection (MultiShotPromptBuilder, etc)
     def is_high_quality(self) -> bool:
+        if self.value is None:
+            return False
+
         if self.type == TaskOutputRatingType.five_star:
-            return self.value
+            return self.value >= 4
+        elif self.type == TaskOutputRatingType.pass_fail:
+            return self.value == 1.0
+        elif self.type == TaskOutputRatingType.pass_fail_critical:
+            return self.value == 1.0
         return False

     @model_validator(mode="after")
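The upgrade_old_format validator keeps older saved data loadable. The shape change it handles, shown with an illustrative requirement ID ("req-1" is made up):

    # Pre-0.8 format: requirement_ratings was a dict of floats (all ratings were five_star then).
    old = {"requirement_ratings": {"req-1": 4.0}}
    # 0.8 format: each entry is a RequirementRating with an explicit value and type.
    new = {"requirement_ratings": {"req-1": {"value": 4.0, "type": "five_star"}}}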
@@ -91,24 +157,61 @@ class TaskOutputRating(KilnBaseModel):
         if self.type not in TaskOutputRatingType:
             raise ValueError(f"Invalid rating type: {self.type}")

-
-
-
-
-
+        # Overall rating is optional
+        if self.value is not None:
+            self._validate_rating(self.type, self.value, "overall rating")
+
+        for req_id, req_rating in self.requirement_ratings.items():
+            self._validate_rating(
+                req_rating.type,
+                req_rating.value,
+                f"requirement rating for req ID: {req_id}",
+            )

         return self

-    def
-
+    def _validate_rating(
+        self, type: TaskOutputRatingType, rating: float | None, rating_name: str
+    ) -> None:
+        if type == TaskOutputRatingType.five_star:
+            self._validate_five_star(rating, rating_name)
+        elif type == TaskOutputRatingType.pass_fail:
+            self._validate_pass_fail(rating, rating_name)
+        elif type == TaskOutputRatingType.pass_fail_critical:
+            self._validate_pass_fail_critical(rating, rating_name)
+
+    def _validate_five_star(self, rating: float | None, rating_name: str) -> None:
+        if rating is None or not isinstance(rating, float) or not rating.is_integer():
             raise ValueError(
-                f"{rating_name.capitalize()} of type five_star must be an integer value (1
+                f"{rating_name.capitalize()} of type five_star must be an integer value (1-5)"
             )
         if rating < 1 or rating > 5:
             raise ValueError(
                 f"{rating_name.capitalize()} of type five_star must be between 1 and 5 stars"
             )

+    def _validate_pass_fail(self, rating: float | None, rating_name: str) -> None:
+        if rating is None or not isinstance(rating, float) or not rating.is_integer():
+            raise ValueError(
+                f"{rating_name.capitalize()} of type pass_fail must be an integer value (0 or 1)"
+            )
+        if rating not in [0, 1]:
+            raise ValueError(
+                f"{rating_name.capitalize()} of type pass_fail must be 0 (fail) or 1 (pass)"
+            )
+
+    def _validate_pass_fail_critical(
+        self, rating: float | None, rating_name: str
+    ) -> None:
+        if rating is None or not isinstance(rating, float) or not rating.is_integer():
+            raise ValueError(
+                f"{rating_name.capitalize()} of type pass_fail_critical must be an integer value (-1, 0, or 1)"
+            )
+        if rating not in [-1, 0, 1]:
+            raise ValueError(
+                f"{rating_name.capitalize()} of type pass_fail_critical must be -1 (critical fail), 0 (fail), or 1 (pass)"
+            )
+

 class TaskOutput(KilnBaseModel):
     """
@@ -121,8 +224,9 @@ class TaskOutput(KilnBaseModel):
     output: str = Field(
         description="The output of the task. JSON formatted for structured output, plaintext for unstructured output."
     )
-    source: DataSource = Field(
-        description="The source of the output: human or synthetic."
+    source: DataSource | None = Field(
+        description="The source of the output: human or synthetic.",
+        default=None,
     )
     rating: TaskOutputRating | None = Field(
         default=None, description="The rating of the output"
@@ -139,6 +243,18 @@ class TaskOutput(KilnBaseModel):
             raise ValueError(f"Output does not match task output schema: {e}")
         return self

+    @model_validator(mode="after")
+    def validate_output_source(self, info: ValidationInfo) -> Self:
+        # On strict mode and not loaded from file, we validate output_source is not None.
+        # We want to be able to load any data, even if it's not perfect. But we want to create perfect data when adding new data.
+        if not strict_mode():
+            return self
+        if self.loaded_from_file(info):
+            return self
+        if self.source is None:
+            raise ValueError("Output source is required when strict mode is enabled")
+        return self
+

 class FineTuneStatusType(str, Enum):
     """
@@ -326,8 +442,8 @@ class TaskRun(KilnParentedModel):
     input: str = Field(
         description="The inputs to the task. JSON formatted for structured input, plaintext for unstructured input."
     )
-    input_source: DataSource = Field(
-        description="The source of the input: human or synthetic."
+    input_source: DataSource | None = Field(
+        default=None, description="The source of the input: human or synthetic."
     )

     output: TaskOutput = Field(description="The output of the task run.")
@@ -343,6 +459,10 @@ class TaskRun(KilnParentedModel):
         default=None,
         description="Intermediate outputs from the task run. Keys are the names of the intermediate output steps (cot=chain of thought, etc), values are the output data.",
     )
+    tags: List[str] = Field(
+        default=[],
+        description="Tags for the task run. Tags are used to categorize task runs for filtering and reporting.",
+    )

     def parent_task(self) -> Task | None:
         if not isinstance(self.parent, Task):
@@ -392,6 +512,28 @@ class TaskRun(KilnParentedModel):
         )
         return self

+    @model_validator(mode="after")
+    def validate_input_source(self, info: ValidationInfo) -> Self:
+        # On strict mode and not loaded from file, we validate input_source is not None.
+        # We want to be able to load any data, even if it's not perfect. But we want to create perfect data when adding new data.
+        if not strict_mode():
+            return self
+        if self.loaded_from_file(info):
+            return self
+        if self.input_source is None:
+            raise ValueError("input_source is required when strict mode is enabled")
+        return self
+
+    @model_validator(mode="after")
+    def validate_tags(self) -> Self:
+        for tag in self.tags:
+            if not tag:
+                raise ValueError("Tags cannot be empty strings")
+            if " " in tag:
+                raise ValueError("Tags cannot contain spaces. Try underscores.")
+
+        return self
+

 # Define the type alias for clarity
 DatasetFilter = Callable[[TaskRun], bool]
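The validate_tags validator above rejects malformed tags at construction time. Illustrative values (the tag names are made up):

    tags = ["customer_support", "needs_review"]  # accepted
    # tags = ["has space"]  -> ValueError: Tags cannot contain spaces. Try underscores.
    # tags = [""]           -> ValueError: Tags cannot be empty strings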
@@ -552,7 +694,7 @@ class TaskRequirement(BaseModel):
     Defines a specific requirement that should be met by task outputs.

     Includes an identifier, name, description, instruction for meeting the requirement,
-    and
+    priority level, and rating type (five_star, pass_fail, pass_fail_critical, custom).
     """

     id: ID_TYPE = ID_FIELD
@@ -560,6 +702,7 @@ class TaskRequirement(BaseModel):
     description: str | None = Field(default=None)
     instruction: str = Field(min_length=1)
     priority: Priority = Field(default=Priority.p2)
+    type: TaskOutputRatingType = Field(default=TaskOutputRatingType.five_star)


 class TaskDeterminism(str, Enum):