kiln-ai 0.17.0__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/adapter_registry.py +28 -0
- kiln_ai/adapters/chat/chat_formatter.py +0 -1
- kiln_ai/adapters/data_gen/data_gen_prompts.py +121 -36
- kiln_ai/adapters/data_gen/data_gen_task.py +51 -38
- kiln_ai/adapters/data_gen/test_data_gen_task.py +318 -37
- kiln_ai/adapters/eval/base_eval.py +6 -7
- kiln_ai/adapters/eval/eval_runner.py +5 -1
- kiln_ai/adapters/eval/g_eval.py +17 -12
- kiln_ai/adapters/eval/test_base_eval.py +8 -2
- kiln_ai/adapters/eval/test_eval_runner.py +6 -12
- kiln_ai/adapters/eval/test_g_eval.py +115 -5
- kiln_ai/adapters/eval/test_g_eval_data.py +1 -1
- kiln_ai/adapters/fine_tune/base_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/dataset_formatter.py +1 -5
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +32 -20
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +1 -1
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +30 -21
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +2 -7
- kiln_ai/adapters/fine_tune/together_finetune.py +1 -1
- kiln_ai/adapters/ml_model_list.py +926 -125
- kiln_ai/adapters/model_adapters/base_adapter.py +11 -7
- kiln_ai/adapters/model_adapters/litellm_adapter.py +23 -1
- kiln_ai/adapters/model_adapters/test_base_adapter.py +1 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +70 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +13 -13
- kiln_ai/adapters/parsers/parser_registry.py +0 -2
- kiln_ai/adapters/parsers/r1_parser.py +0 -1
- kiln_ai/adapters/parsers/test_r1_parser.py +1 -1
- kiln_ai/adapters/provider_tools.py +20 -19
- kiln_ai/adapters/remote_config.py +113 -0
- kiln_ai/adapters/repair/repair_task.py +2 -7
- kiln_ai/adapters/test_adapter_registry.py +30 -2
- kiln_ai/adapters/test_ml_model_list.py +30 -0
- kiln_ai/adapters/test_prompt_adaptors.py +0 -4
- kiln_ai/adapters/test_provider_tools.py +18 -12
- kiln_ai/adapters/test_remote_config.py +456 -0
- kiln_ai/datamodel/basemodel.py +54 -28
- kiln_ai/datamodel/datamodel_enums.py +2 -0
- kiln_ai/datamodel/dataset_split.py +5 -3
- kiln_ai/datamodel/eval.py +35 -3
- kiln_ai/datamodel/finetune.py +2 -3
- kiln_ai/datamodel/project.py +3 -3
- kiln_ai/datamodel/prompt.py +2 -2
- kiln_ai/datamodel/prompt_id.py +4 -4
- kiln_ai/datamodel/task.py +6 -6
- kiln_ai/datamodel/task_output.py +1 -3
- kiln_ai/datamodel/task_run.py +0 -2
- kiln_ai/datamodel/test_basemodel.py +210 -18
- kiln_ai/datamodel/test_eval_model.py +152 -10
- kiln_ai/datamodel/test_model_perf.py +1 -1
- kiln_ai/datamodel/test_prompt_id.py +5 -1
- kiln_ai/datamodel/test_task.py +5 -0
- kiln_ai/utils/config.py +10 -0
- kiln_ai/utils/logging.py +4 -3
- {kiln_ai-0.17.0.dist-info → kiln_ai-0.19.0.dist-info}/METADATA +33 -3
- {kiln_ai-0.17.0.dist-info → kiln_ai-0.19.0.dist-info}/RECORD +58 -56
- {kiln_ai-0.17.0.dist-info → kiln_ai-0.19.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.17.0.dist-info → kiln_ai-0.19.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -25,11 +25,7 @@ from kiln_ai.adapters.provider_tools import (
|
|
|
25
25
|
provider_name_from_id,
|
|
26
26
|
provider_warnings,
|
|
27
27
|
)
|
|
28
|
-
from kiln_ai.datamodel import
|
|
29
|
-
Finetune,
|
|
30
|
-
StructuredOutputMode,
|
|
31
|
-
Task,
|
|
32
|
-
)
|
|
28
|
+
from kiln_ai.datamodel import Finetune, StructuredOutputMode, Task
|
|
33
29
|
from kiln_ai.datamodel.datamodel_enums import ChatStrategy
|
|
34
30
|
from kiln_ai.datamodel.task import RunConfigProperties
|
|
35
31
|
|
|
@@ -199,6 +195,7 @@ def test_provider_name_from_id_case_sensitivity():
|
|
|
199
195
|
(ModelProviderName.ollama, "Ollama"),
|
|
200
196
|
(ModelProviderName.openai, "OpenAI"),
|
|
201
197
|
(ModelProviderName.fireworks_ai, "Fireworks AI"),
|
|
198
|
+
(ModelProviderName.siliconflow_cn, "SiliconFlow"),
|
|
202
199
|
(ModelProviderName.kiln_fine_tune, "Fine Tuned Models"),
|
|
203
200
|
(ModelProviderName.kiln_custom_registry, "Custom Models"),
|
|
204
201
|
],
|
|
@@ -420,6 +417,17 @@ async def test_builtin_model_from_invalid_provider(mock_config):
|
|
|
420
417
|
assert provider is None
|
|
421
418
|
|
|
422
419
|
|
|
420
|
+
@pytest.mark.asyncio
|
|
421
|
+
async def test_builtin_model_future_proof():
|
|
422
|
+
"""Test handling of a model that doesn't exist yet but could be added over the air"""
|
|
423
|
+
with patch("kiln_ai.adapters.provider_tools.built_in_models") as mock_models:
|
|
424
|
+
mock_models.__iter__.return_value = []
|
|
425
|
+
|
|
426
|
+
# should not find it, but should not raise an error
|
|
427
|
+
result = builtin_model_from("gpt_99")
|
|
428
|
+
assert result is None
|
|
429
|
+
|
|
430
|
+
|
|
423
431
|
@pytest.mark.asyncio
|
|
424
432
|
async def test_builtin_model_from_model_no_providers():
|
|
425
433
|
"""Test handling of a model with no providers"""
|
|
@@ -433,10 +441,8 @@ async def test_builtin_model_from_model_no_providers():
|
|
|
433
441
|
)
|
|
434
442
|
mock_models.__iter__.return_value = [mock_model]
|
|
435
443
|
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
assert str(exc_info.value) == f"Model {ModelName.phi_3_5} has no providers"
|
|
444
|
+
result = builtin_model_from(ModelName.phi_3_5)
|
|
445
|
+
assert result is None
|
|
440
446
|
|
|
441
447
|
|
|
442
448
|
@pytest.mark.asyncio
|
|
@@ -461,7 +467,7 @@ def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune)
|
|
|
461
467
|
assert provider.model_id == "ft:gpt-3.5-turbo:custom:model-123"
|
|
462
468
|
assert provider.structured_output_mode == StructuredOutputMode.json_schema
|
|
463
469
|
assert provider.reasoning_capable is False
|
|
464
|
-
assert provider.parser
|
|
470
|
+
assert provider.parser is None
|
|
465
471
|
|
|
466
472
|
|
|
467
473
|
def test_finetune_provider_model_success_final_and_intermediate(
|
|
@@ -476,7 +482,7 @@ def test_finetune_provider_model_success_final_and_intermediate(
|
|
|
476
482
|
assert provider.model_id == "ft:gpt-3.5-turbo:custom:model-123"
|
|
477
483
|
assert provider.structured_output_mode == StructuredOutputMode.json_schema
|
|
478
484
|
assert provider.reasoning_capable is False
|
|
479
|
-
assert provider.parser
|
|
485
|
+
assert provider.parser is None
|
|
480
486
|
|
|
481
487
|
|
|
482
488
|
def test_finetune_provider_model_success_r1_compatible(
|
|
@@ -590,7 +596,7 @@ def test_finetune_provider_model_structured_mode(
|
|
|
590
596
|
assert provider.model_id == "fireworks-model-123"
|
|
591
597
|
assert provider.structured_output_mode == expected_mode
|
|
592
598
|
assert provider.reasoning_capable is False
|
|
593
|
-
assert provider.parser
|
|
599
|
+
assert provider.parser is None
|
|
594
600
|
|
|
595
601
|
|
|
596
602
|
def test_openai_compatible_provider_config(mock_shared_config):
|
|
@@ -0,0 +1,456 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from unittest.mock import patch
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
|
|
9
|
+
from kiln_ai.adapters.ml_model_list import (
|
|
10
|
+
KilnModel,
|
|
11
|
+
KilnModelProvider,
|
|
12
|
+
ModelFamily,
|
|
13
|
+
ModelName,
|
|
14
|
+
ModelProviderName,
|
|
15
|
+
StructuredOutputMode,
|
|
16
|
+
built_in_models,
|
|
17
|
+
)
|
|
18
|
+
from kiln_ai.adapters.remote_config import (
|
|
19
|
+
deserialize_config_at_path,
|
|
20
|
+
dump_builtin_config,
|
|
21
|
+
load_from_url,
|
|
22
|
+
load_remote_models,
|
|
23
|
+
serialize_config,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@pytest.fixture
|
|
28
|
+
def mock_model() -> KilnModel:
|
|
29
|
+
return KilnModel(
|
|
30
|
+
family=ModelFamily.gpt,
|
|
31
|
+
name=ModelName.gpt_4_1,
|
|
32
|
+
friendly_name="GPT 4.1",
|
|
33
|
+
providers=[
|
|
34
|
+
KilnModelProvider(
|
|
35
|
+
name=ModelProviderName.openai,
|
|
36
|
+
model_id="gpt-4.1",
|
|
37
|
+
provider_finetune_id="gpt-4.1-2025-04-14",
|
|
38
|
+
structured_output_mode=StructuredOutputMode.json_schema,
|
|
39
|
+
supports_logprobs=True,
|
|
40
|
+
suggested_for_evals=True,
|
|
41
|
+
),
|
|
42
|
+
KilnModelProvider(
|
|
43
|
+
name=ModelProviderName.openrouter,
|
|
44
|
+
model_id="openai/gpt-4.1",
|
|
45
|
+
structured_output_mode=StructuredOutputMode.json_schema,
|
|
46
|
+
supports_logprobs=True,
|
|
47
|
+
suggested_for_evals=True,
|
|
48
|
+
),
|
|
49
|
+
KilnModelProvider(
|
|
50
|
+
name=ModelProviderName.azure_openai,
|
|
51
|
+
model_id="gpt-4.1",
|
|
52
|
+
suggested_for_evals=True,
|
|
53
|
+
),
|
|
54
|
+
],
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def test_round_trip(tmp_path):
|
|
59
|
+
path = tmp_path / "models.json"
|
|
60
|
+
serialize_config(built_in_models, path)
|
|
61
|
+
loaded = deserialize_config_at_path(path)
|
|
62
|
+
assert [m.model_dump(mode="json") for m in loaded] == [
|
|
63
|
+
m.model_dump(mode="json") for m in built_in_models
|
|
64
|
+
]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def test_load_from_url(mock_model):
|
|
68
|
+
sample_model = mock_model
|
|
69
|
+
sample = [sample_model.model_dump(mode="json")]
|
|
70
|
+
|
|
71
|
+
class FakeResponse:
|
|
72
|
+
def raise_for_status(self):
|
|
73
|
+
pass
|
|
74
|
+
|
|
75
|
+
def json(self):
|
|
76
|
+
return {"model_list": sample}
|
|
77
|
+
|
|
78
|
+
with patch(
|
|
79
|
+
"kiln_ai.adapters.remote_config.requests.get", return_value=FakeResponse()
|
|
80
|
+
):
|
|
81
|
+
models = load_from_url("http://example.com/models.json")
|
|
82
|
+
|
|
83
|
+
assert len(models) == 1
|
|
84
|
+
assert sample_model == models[0]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def test_load_from_url_calls_deserialize_config_data(mock_model):
|
|
88
|
+
"""Test that load_from_url calls deserialize_config_data with the model_list from the response."""
|
|
89
|
+
sample_model_data = [mock_model.model_dump(mode="json")]
|
|
90
|
+
response_data = {"model_list": sample_model_data}
|
|
91
|
+
|
|
92
|
+
class FakeResponse:
|
|
93
|
+
def raise_for_status(self):
|
|
94
|
+
pass
|
|
95
|
+
|
|
96
|
+
def json(self):
|
|
97
|
+
return response_data
|
|
98
|
+
|
|
99
|
+
with (
|
|
100
|
+
patch(
|
|
101
|
+
"kiln_ai.adapters.remote_config.requests.get", return_value=FakeResponse()
|
|
102
|
+
) as mock_get,
|
|
103
|
+
patch(
|
|
104
|
+
"kiln_ai.adapters.remote_config.deserialize_config_data"
|
|
105
|
+
) as mock_deserialize,
|
|
106
|
+
):
|
|
107
|
+
mock_deserialize.return_value = [mock_model]
|
|
108
|
+
|
|
109
|
+
result = load_from_url("http://example.com/models.json")
|
|
110
|
+
|
|
111
|
+
# Verify requests.get was called with correct URL
|
|
112
|
+
mock_get.assert_called_once_with("http://example.com/models.json", timeout=10)
|
|
113
|
+
|
|
114
|
+
# Verify deserialize_config_data was called with the model_list data
|
|
115
|
+
mock_deserialize.assert_called_once_with(response_data)
|
|
116
|
+
|
|
117
|
+
# Verify the result is what deserialize_config_data returned
|
|
118
|
+
assert result == [mock_model]
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def test_dump_builtin_config(tmp_path):
|
|
122
|
+
path = tmp_path / "out.json"
|
|
123
|
+
dump_builtin_config(path)
|
|
124
|
+
loaded = deserialize_config_at_path(path)
|
|
125
|
+
assert [m.model_dump(mode="json") for m in loaded] == [
|
|
126
|
+
m.model_dump(mode="json") for m in built_in_models
|
|
127
|
+
]
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@pytest.mark.asyncio
async def test_load_remote_models_success(monkeypatch, mock_model):
    """Models fetched by load_remote_models replace the contents of built_in_models."""
    # Remove the opt-out flag so the remote fetch path runs. monkeypatch.delenv
    # tolerates the variable being unset and restores it at teardown, unlike
    # `del os.environ[...]`, which raises KeyError when it is missing and leaks
    # the change into every later test.
    monkeypatch.delenv("KILN_SKIP_REMOTE_MODEL_LIST", raising=False)
    original = built_in_models.copy()
    sample_models = [mock_model]

    def fake_fetch(url):
        return sample_models

    monkeypatch.setattr("kiln_ai.adapters.remote_config.load_from_url", fake_fetch)

    try:
        load_remote_models("http://example.com/models.json")
        # NOTE(review): the load appears to complete asynchronously; this short
        # sleep yields to it and may be timing-sensitive — confirm against
        # remote_config.load_remote_models.
        await asyncio.sleep(0.01)
        assert built_in_models == sample_models
    finally:
        # Restore the shared module-level list even if the assertion fails, so
        # later tests do not run against the mocked model list.
        built_in_models[:] = original
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@pytest.mark.asyncio
async def test_load_remote_models_failure(monkeypatch):
    """A failing remote fetch leaves built_in_models unchanged."""
    # Make sure the remote path is exercised regardless of test order. Presumably
    # load_remote_models skips fetching when this flag is set (the success test
    # clears it for that reason), which would make this test pass vacuously.
    monkeypatch.delenv("KILN_SKIP_REMOTE_MODEL_LIST", raising=False)
    original = built_in_models.copy()

    def fake_fetch(url):
        raise RuntimeError("fail")

    monkeypatch.setattr("kiln_ai.adapters.remote_config.load_from_url", fake_fetch)

    try:
        load_remote_models("http://example.com/models.json")
        # NOTE(review): yields to the (apparently background) load; timing-sensitive.
        await asyncio.sleep(0.01)
        assert built_in_models == original
    finally:
        # Guard the shared module-level list against pollution if the failure
        # path ever mutates it unexpectedly.
        built_in_models[:] = original
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def test_deserialize_config_with_extra_keys(tmp_path, mock_model):
|
|
162
|
+
# Take a valid model and add an extra key, ensure it is ignored and still loads
|
|
163
|
+
model_dict = mock_model.model_dump(mode="json")
|
|
164
|
+
model_dict["extra_key"] = "should be ignored or error"
|
|
165
|
+
model_dict["providers"][0]["extra_key"] = "should be ignored or error"
|
|
166
|
+
data = {"model_list": [model_dict]}
|
|
167
|
+
path = tmp_path / "extra.json"
|
|
168
|
+
path.write_text(json.dumps(data))
|
|
169
|
+
# Should NOT raise, and extra key should be ignored
|
|
170
|
+
models = deserialize_config_at_path(path)
|
|
171
|
+
assert hasattr(models[0], "family")
|
|
172
|
+
assert not hasattr(models[0], "extra_key")
|
|
173
|
+
assert hasattr(models[0], "providers")
|
|
174
|
+
assert not hasattr(models[0].providers[0], "extra_key")
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def test_deserialize_config_with_invalid_models(tmp_path, caplog, mock_model):
|
|
178
|
+
"""Test comprehensive handling of invalid models and providers during deserialization."""
|
|
179
|
+
|
|
180
|
+
# Create a fully valid model as baseline
|
|
181
|
+
valid_model = mock_model.model_dump(mode="json")
|
|
182
|
+
|
|
183
|
+
# Case 1: Invalid model - missing required field 'family'
|
|
184
|
+
invalid_model_missing_family = mock_model.model_dump(mode="json")
|
|
185
|
+
del invalid_model_missing_family["family"]
|
|
186
|
+
|
|
187
|
+
# Case 2: Invalid model - invalid data type for required field
|
|
188
|
+
invalid_model_wrong_type = mock_model.model_dump(mode="json")
|
|
189
|
+
invalid_model_wrong_type["name"] = None # name should be a string, not None
|
|
190
|
+
|
|
191
|
+
# Case 3: Invalid model - completely malformed
|
|
192
|
+
invalid_model_malformed = {"not_a_valid_model": "at_all"}
|
|
193
|
+
|
|
194
|
+
# Case 4: Valid model with one invalid provider (should keep model, skip invalid provider)
|
|
195
|
+
valid_model_invalid_provider = mock_model.model_dump(mode="json")
|
|
196
|
+
valid_model_invalid_provider["name"] = "test_model_invalid_provider" # Unique name
|
|
197
|
+
valid_model_invalid_provider["providers"][0]["name"] = "unknown-provider-123"
|
|
198
|
+
|
|
199
|
+
# Case 5: Valid model with mixed valid/invalid providers (should keep model and valid providers)
|
|
200
|
+
valid_model_mixed_providers = mock_model.model_dump(mode="json")
|
|
201
|
+
valid_model_mixed_providers["name"] = "test_model_mixed_providers" # Unique name
|
|
202
|
+
# Add a second provider that's valid
|
|
203
|
+
valid_provider = valid_model_mixed_providers["providers"][0].copy()
|
|
204
|
+
valid_provider["name"] = "azure_openai"
|
|
205
|
+
# Make first provider invalid
|
|
206
|
+
valid_model_mixed_providers["providers"][0]["name"] = "invalid-provider-1"
|
|
207
|
+
# Add invalid provider with missing required field
|
|
208
|
+
invalid_provider = valid_model_mixed_providers["providers"][0].copy()
|
|
209
|
+
del invalid_provider["name"]
|
|
210
|
+
# Add another invalid provider with wrong type
|
|
211
|
+
invalid_provider_2 = valid_model_mixed_providers["providers"][0].copy()
|
|
212
|
+
invalid_provider_2["supports_structured_output"] = "not_a_boolean"
|
|
213
|
+
|
|
214
|
+
valid_model_mixed_providers["providers"] = [
|
|
215
|
+
valid_model_mixed_providers["providers"][0], # invalid name
|
|
216
|
+
valid_provider, # valid
|
|
217
|
+
invalid_provider, # missing name
|
|
218
|
+
invalid_provider_2, # wrong type
|
|
219
|
+
]
|
|
220
|
+
|
|
221
|
+
# Case 6: Valid model with all invalid providers (should keep model with empty providers)
|
|
222
|
+
valid_model_all_invalid_providers = mock_model.model_dump(mode="json")
|
|
223
|
+
valid_model_all_invalid_providers["name"] = (
|
|
224
|
+
"test_model_all_invalid_providers" # Unique name
|
|
225
|
+
)
|
|
226
|
+
valid_model_all_invalid_providers["providers"][0]["name"] = "unknown-provider-456"
|
|
227
|
+
if len(valid_model_all_invalid_providers["providers"]) > 1:
|
|
228
|
+
valid_model_all_invalid_providers["providers"][1]["name"] = (
|
|
229
|
+
"another-unknown-provider"
|
|
230
|
+
)
|
|
231
|
+
if len(valid_model_all_invalid_providers["providers"]) > 2:
|
|
232
|
+
valid_model_all_invalid_providers["providers"][2]["name"] = (
|
|
233
|
+
"yet-another-unknown-provider"
|
|
234
|
+
)
|
|
235
|
+
|
|
236
|
+
data = {
|
|
237
|
+
"model_list": [
|
|
238
|
+
valid_model, # Should be kept
|
|
239
|
+
invalid_model_missing_family, # Should be skipped
|
|
240
|
+
invalid_model_wrong_type, # Should be skipped
|
|
241
|
+
invalid_model_malformed, # Should be skipped
|
|
242
|
+
valid_model_invalid_provider, # Should be kept with empty providers
|
|
243
|
+
valid_model_mixed_providers, # Should be kept with 1 valid provider
|
|
244
|
+
valid_model_all_invalid_providers, # Should be kept with empty providers
|
|
245
|
+
]
|
|
246
|
+
}
|
|
247
|
+
path = tmp_path / "mixed_models.json"
|
|
248
|
+
path.write_text(json.dumps(data))
|
|
249
|
+
|
|
250
|
+
# Enable logging to capture warnings
|
|
251
|
+
with caplog.at_level(logging.WARNING):
|
|
252
|
+
models = deserialize_config_at_path(path)
|
|
253
|
+
|
|
254
|
+
# Should have 4 valid models (original + 3 with provider issues but valid model structure)
|
|
255
|
+
assert len(models) == 4
|
|
256
|
+
|
|
257
|
+
# Check the first model is fully intact
|
|
258
|
+
assert models[0].name == mock_model.name
|
|
259
|
+
assert models[0].family == mock_model.family
|
|
260
|
+
assert len(models[0].providers) == 3 # mock_model has 3 providers
|
|
261
|
+
|
|
262
|
+
# Check model with invalid provider has remaining valid providers
|
|
263
|
+
model_with_invalid_provider = next(
|
|
264
|
+
m for m in models if m.name == valid_model_invalid_provider["name"]
|
|
265
|
+
)
|
|
266
|
+
# Should keep the valid providers from the original model (openrouter, azure_openai)
|
|
267
|
+
assert len(model_with_invalid_provider.providers) == 2
|
|
268
|
+
provider_names = {p.name.value for p in model_with_invalid_provider.providers}
|
|
269
|
+
assert provider_names == {"openrouter", "azure_openai"}
|
|
270
|
+
|
|
271
|
+
# Check model with mixed providers has only the valid one
|
|
272
|
+
model_with_mixed_providers = next(
|
|
273
|
+
m for m in models if m.name == valid_model_mixed_providers["name"]
|
|
274
|
+
)
|
|
275
|
+
assert len(model_with_mixed_providers.providers) == 1
|
|
276
|
+
assert model_with_mixed_providers.providers[0].name.value == "azure_openai"
|
|
277
|
+
|
|
278
|
+
# Check model with all invalid providers has empty providers
|
|
279
|
+
model_with_all_invalid_providers = next(
|
|
280
|
+
m for m in models if m.name == valid_model_all_invalid_providers["name"]
|
|
281
|
+
)
|
|
282
|
+
assert len(model_with_all_invalid_providers.providers) == 0
|
|
283
|
+
|
|
284
|
+
# Check warning logs
|
|
285
|
+
warning_logs = [
|
|
286
|
+
record for record in caplog.records if record.levelno == logging.WARNING
|
|
287
|
+
]
|
|
288
|
+
|
|
289
|
+
# Should have warnings for:
|
|
290
|
+
# - 3 invalid models (missing family, wrong type, malformed)
|
|
291
|
+
# - 1 invalid provider in case 4 (unknown-provider-123)
|
|
292
|
+
# - 3 invalid providers in case 5 (invalid-provider-1, missing name, wrong type boolean)
|
|
293
|
+
# - 3 invalid providers in case 6 (unknown-provider-456, another-unknown-provider, yet-another-unknown-provider)
|
|
294
|
+
assert len(warning_logs) >= 10
|
|
295
|
+
|
|
296
|
+
# Check that warning messages contain expected content
|
|
297
|
+
model_warnings = [
|
|
298
|
+
log for log in warning_logs if "Failed to validate a model from" in log.message
|
|
299
|
+
]
|
|
300
|
+
provider_warnings = [
|
|
301
|
+
log
|
|
302
|
+
for log in warning_logs
|
|
303
|
+
if "Failed to validate a model provider" in log.message
|
|
304
|
+
]
|
|
305
|
+
|
|
306
|
+
assert len(model_warnings) == 3 # 3 completely invalid models
|
|
307
|
+
assert (
|
|
308
|
+
len(provider_warnings) == 7
|
|
309
|
+
) # Exactly 7 invalid providers across different models
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def test_deserialize_config_empty_provider_list(tmp_path, mock_model):
|
|
313
|
+
"""Test that models with empty provider lists are handled correctly."""
|
|
314
|
+
model_with_empty_providers = mock_model.model_dump(mode="json")
|
|
315
|
+
model_with_empty_providers["providers"] = []
|
|
316
|
+
|
|
317
|
+
data = {"model_list": [model_with_empty_providers]}
|
|
318
|
+
path = tmp_path / "empty_providers.json"
|
|
319
|
+
path.write_text(json.dumps(data))
|
|
320
|
+
|
|
321
|
+
models = deserialize_config_at_path(path)
|
|
322
|
+
assert len(models) == 1
|
|
323
|
+
assert len(models[0].providers) == 0
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def test_deserialize_config_missing_provider_field(tmp_path, caplog, mock_model):
|
|
327
|
+
"""Test that models missing the providers field are handled correctly."""
|
|
328
|
+
model_without_providers = mock_model.model_dump(mode="json")
|
|
329
|
+
del model_without_providers["providers"]
|
|
330
|
+
|
|
331
|
+
data = {"model_list": [model_without_providers]}
|
|
332
|
+
path = tmp_path / "no_providers.json"
|
|
333
|
+
path.write_text(json.dumps(data))
|
|
334
|
+
|
|
335
|
+
with caplog.at_level(logging.WARNING):
|
|
336
|
+
models = deserialize_config_at_path(path)
|
|
337
|
+
|
|
338
|
+
# Model should be kept with empty providers (deserialize_config handles missing providers gracefully)
|
|
339
|
+
assert len(models) == 1
|
|
340
|
+
assert len(models[0].providers) == 0
|
|
341
|
+
assert models[0].name == mock_model.name
|
|
342
|
+
|
|
343
|
+
# Should not have any warnings since the function handles missing providers gracefully
|
|
344
|
+
warning_logs = [
|
|
345
|
+
record for record in caplog.records if record.levelno == logging.WARNING
|
|
346
|
+
]
|
|
347
|
+
assert len(warning_logs) == 0
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def test_deserialize_config_provider_with_extra_fields(tmp_path, mock_model):
|
|
351
|
+
"""Test that providers with extra unknown fields are handled gracefully."""
|
|
352
|
+
model_with_extra_provider_fields = mock_model.model_dump(mode="json")
|
|
353
|
+
model_with_extra_provider_fields["providers"][0]["unknown_field"] = (
|
|
354
|
+
"should_be_ignored"
|
|
355
|
+
)
|
|
356
|
+
model_with_extra_provider_fields["providers"][0]["another_extra"] = {
|
|
357
|
+
"nested": "data"
|
|
358
|
+
}
|
|
359
|
+
|
|
360
|
+
data = {"model_list": [model_with_extra_provider_fields]}
|
|
361
|
+
path = tmp_path / "extra_provider_fields.json"
|
|
362
|
+
path.write_text(json.dumps(data))
|
|
363
|
+
|
|
364
|
+
models = deserialize_config_at_path(path)
|
|
365
|
+
assert len(models) == 1
|
|
366
|
+
assert len(models[0].providers) == 3 # mock_model has 3 providers
|
|
367
|
+
# Extra fields should be ignored, not present in the final object
|
|
368
|
+
assert not hasattr(models[0].providers[0], "unknown_field")
|
|
369
|
+
assert not hasattr(models[0].providers[0], "another_extra")
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def test_deserialize_config_model_with_extra_fields(tmp_path, mock_model):
|
|
373
|
+
"""Test that models with extra unknown fields are handled gracefully."""
|
|
374
|
+
model_with_extra_fields = mock_model.model_dump(mode="json")
|
|
375
|
+
model_with_extra_fields["future_field"] = "should_be_ignored"
|
|
376
|
+
model_with_extra_fields["complex_extra"] = {"nested": {"data": [1, 2, 3]}}
|
|
377
|
+
|
|
378
|
+
data = {"model_list": [model_with_extra_fields]}
|
|
379
|
+
path = tmp_path / "extra_model_fields.json"
|
|
380
|
+
path.write_text(json.dumps(data))
|
|
381
|
+
|
|
382
|
+
models = deserialize_config_at_path(path)
|
|
383
|
+
assert len(models) == 1
|
|
384
|
+
assert models[0].name == mock_model.name
|
|
385
|
+
# Extra fields should be ignored, not present in the final object
|
|
386
|
+
assert not hasattr(models[0], "future_field")
|
|
387
|
+
assert not hasattr(models[0], "complex_extra")
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def test_deserialize_config_mixed_valid_invalid_providers_single_model(
|
|
391
|
+
tmp_path, caplog, mock_model
|
|
392
|
+
):
|
|
393
|
+
"""Test a single model with a mix of valid and invalid providers in detail."""
|
|
394
|
+
model = mock_model.model_dump(mode="json")
|
|
395
|
+
|
|
396
|
+
# Create a mix of provider scenarios
|
|
397
|
+
valid_provider_1 = model["providers"][0].copy()
|
|
398
|
+
valid_provider_1["name"] = "openai"
|
|
399
|
+
|
|
400
|
+
valid_provider_2 = model["providers"][0].copy()
|
|
401
|
+
valid_provider_2["name"] = "azure_openai"
|
|
402
|
+
|
|
403
|
+
invalid_provider_unknown_name = model["providers"][0].copy()
|
|
404
|
+
invalid_provider_unknown_name["name"] = "nonexistent_provider"
|
|
405
|
+
|
|
406
|
+
invalid_provider_missing_name = model["providers"][0].copy()
|
|
407
|
+
del invalid_provider_missing_name["name"]
|
|
408
|
+
|
|
409
|
+
invalid_provider_wrong_type = model["providers"][0].copy()
|
|
410
|
+
invalid_provider_wrong_type["supports_structured_output"] = "not_a_boolean"
|
|
411
|
+
|
|
412
|
+
model["providers"] = [
|
|
413
|
+
valid_provider_1,
|
|
414
|
+
invalid_provider_unknown_name,
|
|
415
|
+
valid_provider_2,
|
|
416
|
+
invalid_provider_missing_name,
|
|
417
|
+
invalid_provider_wrong_type,
|
|
418
|
+
]
|
|
419
|
+
|
|
420
|
+
data = {"model_list": [model]}
|
|
421
|
+
path = tmp_path / "mixed_providers_single.json"
|
|
422
|
+
path.write_text(json.dumps(data))
|
|
423
|
+
|
|
424
|
+
with caplog.at_level(logging.WARNING):
|
|
425
|
+
models = deserialize_config_at_path(path)
|
|
426
|
+
|
|
427
|
+
# Should have 1 model with 2 valid providers
|
|
428
|
+
assert len(models) == 1
|
|
429
|
+
assert len(models[0].providers) == 2
|
|
430
|
+
assert models[0].providers[0].name.value == "openai"
|
|
431
|
+
assert models[0].providers[1].name.value == "azure_openai"
|
|
432
|
+
|
|
433
|
+
# Should have logged 3 provider validation warnings
|
|
434
|
+
provider_warnings = [
|
|
435
|
+
log
|
|
436
|
+
for log in caplog.records
|
|
437
|
+
if log.levelno == logging.WARNING
|
|
438
|
+
and "Failed to validate a model provider" in log.message
|
|
439
|
+
]
|
|
440
|
+
assert len(provider_warnings) == 3
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def test_deserialize_config_empty_json_structures(tmp_path):
|
|
444
|
+
"""Test various empty JSON structures."""
|
|
445
|
+
# Test empty model_list
|
|
446
|
+
data = {"model_list": []}
|
|
447
|
+
path = tmp_path / "empty_model_list.json"
|
|
448
|
+
path.write_text(json.dumps(data))
|
|
449
|
+
models = deserialize_config_at_path(path)
|
|
450
|
+
assert len(models) == 0
|
|
451
|
+
|
|
452
|
+
# Test empty object with no model_list key
|
|
453
|
+
path = tmp_path / "empty_object.json"
|
|
454
|
+
path.write_text(json.dumps({}))
|
|
455
|
+
with pytest.raises(ValueError):
|
|
456
|
+
deserialize_config_at_path(path)
|
kiln_ai/datamodel/basemodel.py
CHANGED
|
@@ -2,22 +2,17 @@ import json
|
|
|
2
2
|
import os
|
|
3
3
|
import re
|
|
4
4
|
import shutil
|
|
5
|
+
import unicodedata
|
|
5
6
|
import uuid
|
|
6
7
|
from abc import ABCMeta
|
|
7
8
|
from builtins import classmethod
|
|
8
9
|
from datetime import datetime
|
|
9
10
|
from pathlib import Path
|
|
10
|
-
from typing import
|
|
11
|
-
Any,
|
|
12
|
-
Dict,
|
|
13
|
-
List,
|
|
14
|
-
Optional,
|
|
15
|
-
Type,
|
|
16
|
-
TypeVar,
|
|
17
|
-
)
|
|
11
|
+
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar
|
|
18
12
|
|
|
19
13
|
from pydantic import (
|
|
20
14
|
BaseModel,
|
|
15
|
+
BeforeValidator,
|
|
21
16
|
ConfigDict,
|
|
22
17
|
Field,
|
|
23
18
|
ValidationError,
|
|
@@ -26,7 +21,7 @@ from pydantic import (
|
|
|
26
21
|
model_validator,
|
|
27
22
|
)
|
|
28
23
|
from pydantic_core import ErrorDetails
|
|
29
|
-
from typing_extensions import Self
|
|
24
|
+
from typing_extensions import Annotated, Self
|
|
30
25
|
|
|
31
26
|
from kiln_ai.datamodel.model_cache import ModelCache
|
|
32
27
|
from kiln_ai.utils.config import Config
|
|
@@ -44,33 +39,64 @@ PT = TypeVar("PT", bound="KilnParentedModel")
|
|
|
44
39
|
|
|
45
40
|
# Naming conventions:
|
|
46
41
|
# 1) Names are filename safe as they may be used as file names. They are informational and not to be used in prompts/training/validation.
|
|
47
|
-
# 2)
|
|
48
|
-
|
|
49
|
-
#
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
)
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
42
|
+
# 2) Descriptions are for Kiln users to describe/understand the purpose of this object. They must never be used in prompts/training/validation. Use "instruction/requirements" instead.
|
|
43
|
+
|
|
44
|
+
# Forbidden chars are not allowed in filenames on one or more platforms.
|
|
45
|
+
# ref: https://en.wikipedia.org/wiki/Filename#Problematic_characters
|
|
46
|
+
FORBIDDEN_CHARS_REGEX = r"[/\\?%*:|\"<>.,;=\n]"
|
|
47
|
+
# Human-readable list of the characters matched by FORBIDDEN_CHARS_REGEX, shown in
# name validation error messages; keep the two in sync (includes the double quote).
FORBIDDEN_CHARS = "/ \\ ? % * : | \" < > . , ; = \\n"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def name_validator(*, min_length: int, max_length: int) -> Callable[[Any], str]:
|
|
51
|
+
def fn(name: Any) -> str:
|
|
52
|
+
if name is None:
|
|
53
|
+
raise ValueError("Name is required")
|
|
54
|
+
if not isinstance(name, str):
|
|
55
|
+
raise ValueError(f"Input should be a valid string, got {type(name)}")
|
|
56
|
+
if len(name) < min_length:
|
|
57
|
+
raise ValueError(
|
|
58
|
+
f"Name is too short. Min length is {min_length} characters, got {len(name)}"
|
|
59
|
+
)
|
|
60
|
+
if len(name) > max_length:
|
|
61
|
+
raise ValueError(
|
|
62
|
+
f"Name is too long. Max length is {max_length} characters, got {len(name)}"
|
|
63
|
+
)
|
|
64
|
+
if string_to_valid_name(name) != name:
|
|
65
|
+
raise ValueError(
|
|
66
|
+
f"Name is invalid. The name cannot contain any of the following characters: {FORBIDDEN_CHARS}, consecutive whitespace/underscores, or leading/trailing whitespace/underscores"
|
|
67
|
+
)
|
|
68
|
+
return name
|
|
69
|
+
|
|
70
|
+
return fn
|
|
63
71
|
|
|
64
72
|
|
|
65
73
|
def string_to_valid_name(name: str) -> str:
|
|
66
|
-
#
|
|
67
|
-
valid_name =
|
|
74
|
+
# https://docs.python.org/3/library/unicodedata.html#unicodedata.normalize
|
|
75
|
+
valid_name = unicodedata.normalize("NFKD", name)
|
|
76
|
+
# Replace any forbidden chars with an underscore
|
|
77
|
+
valid_name = re.sub(FORBIDDEN_CHARS_REGEX, "_", valid_name)
|
|
78
|
+
# Replace control characters with an underscore
|
|
79
|
+
valid_name = re.sub(r"[\x00-\x1F]", "_", valid_name)
|
|
80
|
+
# Replace consecutive whitespace with a single space
|
|
81
|
+
valid_name = re.sub(r"\s+", " ", valid_name)
|
|
68
82
|
# Replace consecutive underscores with a single underscore
|
|
69
83
|
valid_name = re.sub(r"_+", "_", valid_name)
|
|
70
84
|
# Remove leading and trailing underscores or whitespace
|
|
71
85
|
return valid_name.strip("_").strip()
|
|
72
86
|
|
|
73
87
|
|
|
88
|
+
# Usage:
|
|
89
|
+
# class MyModel(KilnBaseModel):
|
|
90
|
+
# name: FilenameString = Field(description="The name of the model.")
|
|
91
|
+
# name_short: FilenameStringShort = Field(description="The short name of the model.")
|
|
92
|
+
FilenameString = Annotated[
|
|
93
|
+
str, BeforeValidator(name_validator(min_length=1, max_length=120))
|
|
94
|
+
]
|
|
95
|
+
FilenameStringShort = Annotated[
|
|
96
|
+
str, BeforeValidator(name_validator(min_length=1, max_length=32))
|
|
97
|
+
]
|
|
98
|
+
|
|
99
|
+
|
|
74
100
|
class KilnBaseModel(BaseModel):
|
|
75
101
|
"""Base model for all Kiln data models with common functionality for persistence and versioning.
|
|
76
102
|
|
|
@@ -470,7 +496,7 @@ class KilnParentModel(KilnBaseModel, metaclass=ABCMeta):
|
|
|
470
496
|
ValidationError: If validation fails for the model or any of its children
|
|
471
497
|
"""
|
|
472
498
|
# Validate first, then save. Don't want error half way through, and partly persisted
|
|
473
|
-
#
|
|
499
|
+
# We should save to a tmp dir and move atomically, but need to merge directories later.
|
|
474
500
|
cls._validate_nested(data, save=False, path=path, parent=parent)
|
|
475
501
|
instance = cls._validate_nested(data, save=True, path=path, parent=parent)
|
|
476
502
|
return instance
|