kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic; see the registry's advisory for this release for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -6,22 +6,33 @@ from unittest.mock import patch
|
|
|
6
6
|
|
|
7
7
|
import pytest
|
|
8
8
|
|
|
9
|
+
from kiln_ai.adapters.ml_embedding_model_list import (
|
|
10
|
+
EmbeddingModelName,
|
|
11
|
+
KilnEmbeddingModel,
|
|
12
|
+
KilnEmbeddingModelFamily,
|
|
13
|
+
KilnEmbeddingModelProvider,
|
|
14
|
+
built_in_embedding_models,
|
|
15
|
+
)
|
|
9
16
|
from kiln_ai.adapters.ml_model_list import (
|
|
10
17
|
KilnModel,
|
|
11
18
|
KilnModelProvider,
|
|
12
19
|
ModelFamily,
|
|
13
20
|
ModelName,
|
|
14
|
-
ModelProviderName,
|
|
15
|
-
StructuredOutputMode,
|
|
16
21
|
built_in_models,
|
|
17
22
|
)
|
|
18
23
|
from kiln_ai.adapters.remote_config import (
|
|
24
|
+
KilnRemoteConfig,
|
|
19
25
|
deserialize_config_at_path,
|
|
20
26
|
dump_builtin_config,
|
|
21
27
|
load_from_url,
|
|
22
28
|
load_remote_models,
|
|
23
29
|
serialize_config,
|
|
24
30
|
)
|
|
31
|
+
from kiln_ai.datamodel.datamodel_enums import (
|
|
32
|
+
KilnMimeType,
|
|
33
|
+
ModelProviderName,
|
|
34
|
+
StructuredOutputMode,
|
|
35
|
+
)
|
|
25
36
|
|
|
26
37
|
|
|
27
38
|
@pytest.fixture
|
|
@@ -55,39 +66,69 @@ def mock_model() -> KilnModel:
|
|
|
55
66
|
)
|
|
56
67
|
|
|
57
68
|
|
|
69
|
+
@pytest.fixture
|
|
70
|
+
def mock_embedding_model() -> KilnEmbeddingModel:
|
|
71
|
+
return KilnEmbeddingModel(
|
|
72
|
+
family=KilnEmbeddingModelFamily.openai,
|
|
73
|
+
name=EmbeddingModelName.openai_text_embedding_3_small,
|
|
74
|
+
friendly_name="text-embedding-3-small",
|
|
75
|
+
providers=[
|
|
76
|
+
KilnEmbeddingModelProvider(
|
|
77
|
+
name=ModelProviderName.openai,
|
|
78
|
+
model_id="text-embedding-3-small",
|
|
79
|
+
n_dimensions=1536,
|
|
80
|
+
max_input_tokens=8192,
|
|
81
|
+
supports_custom_dimensions=True,
|
|
82
|
+
),
|
|
83
|
+
],
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
|
|
58
87
|
def test_round_trip(tmp_path):
|
|
59
88
|
path = tmp_path / "models.json"
|
|
60
|
-
serialize_config(built_in_models, path)
|
|
89
|
+
serialize_config(built_in_models, built_in_embedding_models, path)
|
|
61
90
|
loaded = deserialize_config_at_path(path)
|
|
62
|
-
assert [m.model_dump(mode="json") for m in loaded] == [
|
|
91
|
+
assert [m.model_dump(mode="json") for m in loaded.model_list] == [
|
|
63
92
|
m.model_dump(mode="json") for m in built_in_models
|
|
64
93
|
]
|
|
94
|
+
assert [m.model_dump(mode="json") for m in loaded.embedding_model_list] == [
|
|
95
|
+
m.model_dump(mode="json") for m in built_in_embedding_models
|
|
96
|
+
]
|
|
65
97
|
|
|
66
98
|
|
|
67
|
-
def test_load_from_url(mock_model):
|
|
99
|
+
def test_load_from_url(mock_model, mock_embedding_model):
|
|
68
100
|
sample_model = mock_model
|
|
69
101
|
sample = [sample_model.model_dump(mode="json")]
|
|
102
|
+
sample_embedding_model = mock_embedding_model
|
|
103
|
+
sample_embedding = [sample_embedding_model.model_dump(mode="json")]
|
|
70
104
|
|
|
71
105
|
class FakeResponse:
|
|
72
106
|
def raise_for_status(self):
|
|
73
107
|
pass
|
|
74
108
|
|
|
75
109
|
def json(self):
|
|
76
|
-
return {"model_list": sample}
|
|
110
|
+
return {"model_list": sample, "embedding_model_list": sample_embedding}
|
|
77
111
|
|
|
78
112
|
with patch(
|
|
79
113
|
"kiln_ai.adapters.remote_config.requests.get", return_value=FakeResponse()
|
|
80
114
|
):
|
|
81
|
-
|
|
115
|
+
remote_config = load_from_url("http://example.com/models.json")
|
|
82
116
|
|
|
83
|
-
assert len(
|
|
84
|
-
assert sample_model ==
|
|
117
|
+
assert len(remote_config.model_list) == 1
|
|
118
|
+
assert sample_model == remote_config.model_list[0]
|
|
119
|
+
|
|
120
|
+
assert len(remote_config.embedding_model_list) == 1
|
|
121
|
+
assert sample_embedding_model == remote_config.embedding_model_list[0]
|
|
85
122
|
|
|
86
123
|
|
|
87
|
-
def test_load_from_url_calls_deserialize_config_data(mock_model):
|
|
124
|
+
def test_load_from_url_calls_deserialize_config_data(mock_model, mock_embedding_model):
|
|
88
125
|
"""Test that load_from_url calls deserialize_config_data with the model_list from the response."""
|
|
89
126
|
sample_model_data = [mock_model.model_dump(mode="json")]
|
|
90
|
-
|
|
127
|
+
sample_embedding_model_data = [mock_embedding_model.model_dump(mode="json")]
|
|
128
|
+
response_data = {
|
|
129
|
+
"model_list": sample_model_data,
|
|
130
|
+
"embedding_model_list": sample_embedding_model_data,
|
|
131
|
+
}
|
|
91
132
|
|
|
92
133
|
class FakeResponse:
|
|
93
134
|
def raise_for_status(self):
|
|
@@ -104,7 +145,10 @@ def test_load_from_url_calls_deserialize_config_data(mock_model):
|
|
|
104
145
|
"kiln_ai.adapters.remote_config.deserialize_config_data"
|
|
105
146
|
) as mock_deserialize,
|
|
106
147
|
):
|
|
107
|
-
mock_deserialize.return_value =
|
|
148
|
+
mock_deserialize.return_value = KilnRemoteConfig(
|
|
149
|
+
model_list=[mock_model],
|
|
150
|
+
embedding_model_list=[mock_embedding_model],
|
|
151
|
+
)
|
|
108
152
|
|
|
109
153
|
result = load_from_url("http://example.com/models.json")
|
|
110
154
|
|
|
@@ -115,47 +159,81 @@ def test_load_from_url_calls_deserialize_config_data(mock_model):
|
|
|
115
159
|
mock_deserialize.assert_called_once_with(response_data)
|
|
116
160
|
|
|
117
161
|
# Verify the result is what deserialize_config_data returned
|
|
118
|
-
assert result == [mock_model]
|
|
162
|
+
assert result.model_list == [mock_model]
|
|
163
|
+
assert result.embedding_model_list == [mock_embedding_model]
|
|
119
164
|
|
|
120
165
|
|
|
121
166
|
def test_dump_builtin_config(tmp_path):
|
|
122
167
|
path = tmp_path / "out.json"
|
|
123
168
|
dump_builtin_config(path)
|
|
124
169
|
loaded = deserialize_config_at_path(path)
|
|
125
|
-
assert [m.model_dump(mode="json") for m in loaded] == [
|
|
170
|
+
assert [m.model_dump(mode="json") for m in loaded.model_list] == [
|
|
126
171
|
m.model_dump(mode="json") for m in built_in_models
|
|
127
172
|
]
|
|
173
|
+
assert [m.model_dump(mode="json") for m in loaded.embedding_model_list] == [
|
|
174
|
+
m.model_dump(mode="json") for m in built_in_embedding_models
|
|
175
|
+
]
|
|
128
176
|
|
|
129
177
|
|
|
130
|
-
|
|
131
|
-
|
|
178
|
+
async def test_load_remote_models_success(
|
|
179
|
+
monkeypatch, mock_model, mock_embedding_model
|
|
180
|
+
):
|
|
132
181
|
del os.environ["KILN_SKIP_REMOTE_MODEL_LIST"]
|
|
133
|
-
original = built_in_models.copy()
|
|
134
182
|
sample_models = [mock_model]
|
|
183
|
+
sample_embedding_models = [mock_embedding_model]
|
|
135
184
|
|
|
136
|
-
|
|
137
|
-
|
|
185
|
+
# Save original state to restore later
|
|
186
|
+
original_models = built_in_models.copy()
|
|
187
|
+
original_embedding = built_in_embedding_models.copy()
|
|
188
|
+
|
|
189
|
+
try:
|
|
190
|
+
# Mock the load_from_url function to return our test data
|
|
191
|
+
def mock_load_from_url(url):
|
|
192
|
+
return KilnRemoteConfig(
|
|
193
|
+
model_list=sample_models,
|
|
194
|
+
embedding_model_list=sample_embedding_models,
|
|
195
|
+
)
|
|
138
196
|
|
|
139
|
-
|
|
197
|
+
# Mock the function call
|
|
198
|
+
with patch(
|
|
199
|
+
"kiln_ai.adapters.remote_config.load_from_url",
|
|
200
|
+
side_effect=mock_load_from_url,
|
|
201
|
+
):
|
|
202
|
+
# Call the function
|
|
203
|
+
load_remote_models("http://example.com/models.json")
|
|
140
204
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
205
|
+
# Wait for the thread to complete
|
|
206
|
+
await asyncio.sleep(0.1)
|
|
207
|
+
|
|
208
|
+
# Verify the global state was modified as expected
|
|
209
|
+
assert built_in_models == sample_models
|
|
210
|
+
assert built_in_embedding_models == sample_embedding_models
|
|
211
|
+
finally:
|
|
212
|
+
# Restore original state to prevent test pollution
|
|
213
|
+
built_in_models[:] = original_models
|
|
214
|
+
built_in_embedding_models[:] = original_embedding
|
|
145
215
|
|
|
146
216
|
|
|
147
217
|
@pytest.mark.asyncio
|
|
148
218
|
async def test_load_remote_models_failure(monkeypatch):
|
|
149
|
-
|
|
219
|
+
# Ensure the environment variable is not set to skip remote model loading
|
|
220
|
+
monkeypatch.delenv("KILN_SKIP_REMOTE_MODEL_LIST", raising=False)
|
|
221
|
+
|
|
222
|
+
original_models = built_in_models.copy()
|
|
223
|
+
original_embedding = built_in_embedding_models.copy()
|
|
150
224
|
|
|
151
225
|
def fake_fetch(url):
|
|
152
226
|
raise RuntimeError("fail")
|
|
153
227
|
|
|
154
|
-
monkeypatch.setattr("kiln_ai.adapters.remote_config.
|
|
228
|
+
monkeypatch.setattr("kiln_ai.adapters.remote_config.requests.get", fake_fetch)
|
|
155
229
|
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
230
|
+
with patch("kiln_ai.adapters.remote_config.logger") as mock_logger:
|
|
231
|
+
load_remote_models("http://example.com/models.json")
|
|
232
|
+
assert built_in_models == original_models
|
|
233
|
+
assert built_in_embedding_models == original_embedding
|
|
234
|
+
|
|
235
|
+
# assert that logger.warning was called
|
|
236
|
+
mock_logger.warning.assert_called_once()
|
|
159
237
|
|
|
160
238
|
|
|
161
239
|
def test_deserialize_config_with_extra_keys(tmp_path, mock_model):
|
|
@@ -163,15 +241,146 @@ def test_deserialize_config_with_extra_keys(tmp_path, mock_model):
|
|
|
163
241
|
model_dict = mock_model.model_dump(mode="json")
|
|
164
242
|
model_dict["extra_key"] = "should be ignored or error"
|
|
165
243
|
model_dict["providers"][0]["extra_key"] = "should be ignored or error"
|
|
166
|
-
|
|
244
|
+
|
|
245
|
+
embedding_model_dict = built_in_embedding_models[0].model_dump(mode="json")
|
|
246
|
+
embedding_model_dict["extra_key"] = "should be ignored or error"
|
|
247
|
+
embedding_model_dict["providers"][0]["extra_key"] = "should be ignored or error"
|
|
248
|
+
|
|
249
|
+
data = {"model_list": [model_dict], "embedding_model_list": [embedding_model_dict]}
|
|
167
250
|
path = tmp_path / "extra.json"
|
|
168
251
|
path.write_text(json.dumps(data))
|
|
169
252
|
# Should NOT raise, and extra key should be ignored
|
|
170
253
|
models = deserialize_config_at_path(path)
|
|
171
|
-
assert hasattr(models[0], "family")
|
|
172
|
-
assert not hasattr(models[0], "extra_key")
|
|
173
|
-
assert hasattr(models[0], "providers")
|
|
174
|
-
assert not hasattr(models[0].providers[0], "extra_key")
|
|
254
|
+
assert hasattr(models.model_list[0], "family")
|
|
255
|
+
assert not hasattr(models.model_list[0], "extra_key")
|
|
256
|
+
assert hasattr(models.model_list[0], "providers")
|
|
257
|
+
assert not hasattr(models.model_list[0].providers[0], "extra_key")
|
|
258
|
+
assert hasattr(models.embedding_model_list[0], "family")
|
|
259
|
+
assert not hasattr(models.embedding_model_list[0], "extra_key")
|
|
260
|
+
assert hasattr(models.embedding_model_list[0], "providers")
|
|
261
|
+
assert not hasattr(models.embedding_model_list[0].providers[0], "extra_key")
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def test_multimodal_fields_specified(tmp_path):
|
|
265
|
+
model_dict = KilnModel(
|
|
266
|
+
family=ModelFamily.gpt,
|
|
267
|
+
name=ModelName.gpt_4o,
|
|
268
|
+
friendly_name="GPT-mock",
|
|
269
|
+
providers=[
|
|
270
|
+
KilnModelProvider(
|
|
271
|
+
name=ModelProviderName.openai,
|
|
272
|
+
model_id="gpt-4o",
|
|
273
|
+
structured_output_mode=StructuredOutputMode.json_schema,
|
|
274
|
+
supports_doc_extraction=True,
|
|
275
|
+
multimodal_capable=True,
|
|
276
|
+
multimodal_mime_types=[
|
|
277
|
+
KilnMimeType.JPEG,
|
|
278
|
+
KilnMimeType.PNG,
|
|
279
|
+
],
|
|
280
|
+
),
|
|
281
|
+
],
|
|
282
|
+
).model_dump(mode="json")
|
|
283
|
+
|
|
284
|
+
data = {"model_list": [model_dict], "embedding_model_list": []}
|
|
285
|
+
path = tmp_path / "extra.json"
|
|
286
|
+
path.write_text(json.dumps(data))
|
|
287
|
+
models = deserialize_config_at_path(path)
|
|
288
|
+
assert models.model_list[0].providers[0].supports_doc_extraction
|
|
289
|
+
assert models.model_list[0].providers[0].multimodal_capable
|
|
290
|
+
assert models.model_list[0].providers[0].multimodal_mime_types == [
|
|
291
|
+
KilnMimeType.JPEG,
|
|
292
|
+
KilnMimeType.PNG,
|
|
293
|
+
]
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def test_multimodal_fields_mime_type_forward_compat(tmp_path):
|
|
297
|
+
# This may happen if the current client is out of date with the remote config
|
|
298
|
+
# and we add a new mime type that the old client gets over the air
|
|
299
|
+
model_dict = KilnModel(
|
|
300
|
+
family=ModelFamily.gpt,
|
|
301
|
+
name=ModelName.gpt_4o,
|
|
302
|
+
friendly_name="GPT-mock",
|
|
303
|
+
providers=[
|
|
304
|
+
KilnModelProvider(
|
|
305
|
+
name=ModelProviderName.openai,
|
|
306
|
+
model_id="gpt-4o",
|
|
307
|
+
structured_output_mode=StructuredOutputMode.json_schema,
|
|
308
|
+
supports_doc_extraction=True,
|
|
309
|
+
multimodal_capable=True,
|
|
310
|
+
multimodal_mime_types=[
|
|
311
|
+
KilnMimeType.JPEG,
|
|
312
|
+
KilnMimeType.PNG,
|
|
313
|
+
"new/unknown-mime-type",
|
|
314
|
+
],
|
|
315
|
+
),
|
|
316
|
+
],
|
|
317
|
+
).model_dump(mode="json")
|
|
318
|
+
|
|
319
|
+
data = {"model_list": [model_dict], "embedding_model_list": []}
|
|
320
|
+
path = tmp_path / "extra.json"
|
|
321
|
+
path.write_text(json.dumps(data))
|
|
322
|
+
models = deserialize_config_at_path(path)
|
|
323
|
+
assert models.model_list[0].providers[0].supports_doc_extraction
|
|
324
|
+
assert models.model_list[0].providers[0].multimodal_capable
|
|
325
|
+
multimodal_mime_types = models.model_list[0].providers[0].multimodal_mime_types
|
|
326
|
+
assert multimodal_mime_types is not None
|
|
327
|
+
assert "new/unknown-mime-type" not in multimodal_mime_types
|
|
328
|
+
assert multimodal_mime_types == [
|
|
329
|
+
KilnMimeType.JPEG,
|
|
330
|
+
KilnMimeType.PNG,
|
|
331
|
+
]
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def test_multimodal_fields_not_specified(tmp_path):
|
|
335
|
+
model_dict = KilnModel(
|
|
336
|
+
family=ModelFamily.gpt,
|
|
337
|
+
name=ModelName.gpt_4o,
|
|
338
|
+
friendly_name="GPT-mock",
|
|
339
|
+
providers=[
|
|
340
|
+
KilnModelProvider(
|
|
341
|
+
name=ModelProviderName.openai,
|
|
342
|
+
model_id="gpt-4o",
|
|
343
|
+
structured_output_mode=StructuredOutputMode.json_schema,
|
|
344
|
+
),
|
|
345
|
+
],
|
|
346
|
+
).model_dump(mode="json")
|
|
347
|
+
|
|
348
|
+
embedding_model_dict = KilnEmbeddingModel(
|
|
349
|
+
family=KilnEmbeddingModelFamily.openai,
|
|
350
|
+
name=EmbeddingModelName.openai_text_embedding_3_small,
|
|
351
|
+
friendly_name="text-embedding-3-small",
|
|
352
|
+
providers=[
|
|
353
|
+
KilnEmbeddingModelProvider(
|
|
354
|
+
name=ModelProviderName.openai,
|
|
355
|
+
model_id="text-embedding-3-small",
|
|
356
|
+
n_dimensions=1536,
|
|
357
|
+
max_input_tokens=8192,
|
|
358
|
+
supports_custom_dimensions=True,
|
|
359
|
+
),
|
|
360
|
+
],
|
|
361
|
+
).model_dump(mode="json")
|
|
362
|
+
|
|
363
|
+
data = {"model_list": [model_dict], "embedding_model_list": [embedding_model_dict]}
|
|
364
|
+
path = tmp_path / "extra.json"
|
|
365
|
+
path.write_text(json.dumps(data))
|
|
366
|
+
remote_config = deserialize_config_at_path(path)
|
|
367
|
+
|
|
368
|
+
models = remote_config.model_list
|
|
369
|
+
assert not models[0].providers[0].supports_doc_extraction
|
|
370
|
+
assert not models[0].providers[0].multimodal_capable
|
|
371
|
+
assert models[0].providers[0].multimodal_mime_types is None
|
|
372
|
+
|
|
373
|
+
embedding_models = remote_config.embedding_model_list
|
|
374
|
+
assert len(embedding_models) == 1
|
|
375
|
+
assert embedding_models[0].family == KilnEmbeddingModelFamily.openai
|
|
376
|
+
assert embedding_models[0].name == EmbeddingModelName.openai_text_embedding_3_small
|
|
377
|
+
assert embedding_models[0].friendly_name == "text-embedding-3-small"
|
|
378
|
+
assert len(embedding_models[0].providers) == 1
|
|
379
|
+
assert embedding_models[0].providers[0].name == ModelProviderName.openai
|
|
380
|
+
assert embedding_models[0].providers[0].model_id == "text-embedding-3-small"
|
|
381
|
+
assert embedding_models[0].providers[0].n_dimensions == 1536
|
|
382
|
+
assert embedding_models[0].providers[0].max_input_tokens == 8192
|
|
383
|
+
assert embedding_models[0].providers[0].supports_custom_dimensions
|
|
175
384
|
|
|
176
385
|
|
|
177
386
|
def test_deserialize_config_with_invalid_models(tmp_path, caplog, mock_model):
|
|
@@ -249,7 +458,9 @@ def test_deserialize_config_with_invalid_models(tmp_path, caplog, mock_model):
|
|
|
249
458
|
|
|
250
459
|
# Enable logging to capture warnings
|
|
251
460
|
with caplog.at_level(logging.WARNING):
|
|
252
|
-
|
|
461
|
+
remote_config = deserialize_config_at_path(path)
|
|
462
|
+
|
|
463
|
+
models = remote_config.model_list
|
|
253
464
|
|
|
254
465
|
# Should have 4 valid models (original + 3 with provider issues but valid model structure)
|
|
255
466
|
assert len(models) == 4
|
|
@@ -309,31 +520,214 @@ def test_deserialize_config_with_invalid_models(tmp_path, caplog, mock_model):
|
|
|
309
520
|
) # Exactly 7 invalid providers across different models
|
|
310
521
|
|
|
311
522
|
|
|
312
|
-
def
|
|
523
|
+
def test_deserialize_config_with_invalid_embedding_models(
|
|
524
|
+
tmp_path, caplog, mock_embedding_model
|
|
525
|
+
):
|
|
526
|
+
"""Test comprehensive handling of invalid embedding models and providers during deserialization."""
|
|
527
|
+
|
|
528
|
+
# Create a fully valid embedding model as baseline
|
|
529
|
+
valid_embedding_model = mock_embedding_model.model_dump(mode="json")
|
|
530
|
+
|
|
531
|
+
# Case 1: Invalid embedding model - missing required field 'family'
|
|
532
|
+
invalid_embedding_model_missing_family = mock_embedding_model.model_dump(
|
|
533
|
+
mode="json"
|
|
534
|
+
)
|
|
535
|
+
del invalid_embedding_model_missing_family["family"]
|
|
536
|
+
|
|
537
|
+
# Case 2: Invalid embedding model - invalid data type for required field
|
|
538
|
+
invalid_embedding_model_wrong_type = mock_embedding_model.model_dump(mode="json")
|
|
539
|
+
invalid_embedding_model_wrong_type["name"] = (
|
|
540
|
+
None # name should be a string, not None
|
|
541
|
+
)
|
|
542
|
+
|
|
543
|
+
# Case 3: Invalid embedding model - completely malformed
|
|
544
|
+
invalid_embedding_model_malformed = {"not_a_valid_embedding_model": "at_all"}
|
|
545
|
+
|
|
546
|
+
# Case 4: Valid embedding model with one invalid provider (should keep model, skip invalid provider)
|
|
547
|
+
valid_embedding_model_invalid_provider = mock_embedding_model.model_dump(
|
|
548
|
+
mode="json"
|
|
549
|
+
)
|
|
550
|
+
valid_embedding_model_invalid_provider["name"] = (
|
|
551
|
+
"test_embedding_model_invalid_provider" # Unique name
|
|
552
|
+
)
|
|
553
|
+
valid_embedding_model_invalid_provider["providers"][0]["name"] = (
|
|
554
|
+
"unknown-provider-123"
|
|
555
|
+
)
|
|
556
|
+
|
|
557
|
+
# Case 5: Valid embedding model with mixed valid/invalid providers (should keep model and valid providers)
|
|
558
|
+
valid_embedding_model_mixed_providers = mock_embedding_model.model_dump(mode="json")
|
|
559
|
+
valid_embedding_model_mixed_providers["name"] = (
|
|
560
|
+
"test_embedding_model_mixed_providers" # Unique name
|
|
561
|
+
)
|
|
562
|
+
# Add a second provider that's valid
|
|
563
|
+
valid_provider = valid_embedding_model_mixed_providers["providers"][0].copy()
|
|
564
|
+
valid_provider["name"] = "azure_openai"
|
|
565
|
+
# Make first provider invalid
|
|
566
|
+
valid_embedding_model_mixed_providers["providers"][0]["name"] = "invalid-provider-1"
|
|
567
|
+
# Add invalid provider with missing required field
|
|
568
|
+
invalid_provider = valid_embedding_model_mixed_providers["providers"][0].copy()
|
|
569
|
+
del invalid_provider["name"]
|
|
570
|
+
# Add another invalid provider with wrong type
|
|
571
|
+
invalid_provider_2 = valid_embedding_model_mixed_providers["providers"][0].copy()
|
|
572
|
+
# Use a known boolean field on KilnModelProvider with a wrong type to force a validation error
|
|
573
|
+
invalid_provider_2["supports_structured_output"] = "not_a_boolean"
|
|
574
|
+
|
|
575
|
+
valid_embedding_model_mixed_providers["providers"] = [
|
|
576
|
+
valid_embedding_model_mixed_providers["providers"][0], # invalid name
|
|
577
|
+
valid_provider, # valid
|
|
578
|
+
invalid_provider, # missing name
|
|
579
|
+
invalid_provider_2, # wrong type
|
|
580
|
+
]
|
|
581
|
+
|
|
582
|
+
# Case 6: Valid embedding model with all invalid providers (should keep model with empty providers)
|
|
583
|
+
valid_embedding_model_all_invalid_providers = mock_embedding_model.model_dump(
|
|
584
|
+
mode="json"
|
|
585
|
+
)
|
|
586
|
+
valid_embedding_model_all_invalid_providers["name"] = (
|
|
587
|
+
"test_embedding_model_all_invalid_providers" # Unique name
|
|
588
|
+
)
|
|
589
|
+
valid_embedding_model_all_invalid_providers["providers"][0]["name"] = (
|
|
590
|
+
"unknown-provider-456"
|
|
591
|
+
)
|
|
592
|
+
|
|
593
|
+
data = {
|
|
594
|
+
"model_list": [],
|
|
595
|
+
"embedding_model_list": [
|
|
596
|
+
valid_embedding_model, # Should be kept
|
|
597
|
+
invalid_embedding_model_missing_family, # Should be skipped
|
|
598
|
+
invalid_embedding_model_wrong_type, # Should be skipped
|
|
599
|
+
invalid_embedding_model_malformed, # Should be skipped
|
|
600
|
+
valid_embedding_model_invalid_provider, # Should be kept with empty providers
|
|
601
|
+
valid_embedding_model_mixed_providers, # Should be kept with 1 valid provider
|
|
602
|
+
valid_embedding_model_all_invalid_providers, # Should be kept with empty providers
|
|
603
|
+
],
|
|
604
|
+
}
|
|
605
|
+
path = tmp_path / "mixed_embedding_models.json"
|
|
606
|
+
path.write_text(json.dumps(data))
|
|
607
|
+
|
|
608
|
+
# Enable logging to capture warnings
|
|
609
|
+
with caplog.at_level(logging.WARNING):
|
|
610
|
+
remote_config = deserialize_config_at_path(path)
|
|
611
|
+
|
|
612
|
+
embedding_models = remote_config.embedding_model_list
|
|
613
|
+
|
|
614
|
+
# Should have 4 valid embedding models (original + 3 with provider issues but valid model structure)
|
|
615
|
+
assert len(embedding_models) == 4
|
|
616
|
+
|
|
617
|
+
# Check the first embedding model is fully intact
|
|
618
|
+
assert embedding_models[0].name == mock_embedding_model.name
|
|
619
|
+
assert embedding_models[0].family == mock_embedding_model.family
|
|
620
|
+
assert (
|
|
621
|
+
len(embedding_models[0].providers) == 1
|
|
622
|
+
) # mock_embedding_model has 1 provider
|
|
623
|
+
|
|
624
|
+
# Check embedding model with invalid provider has remaining valid providers
|
|
625
|
+
embedding_model_with_invalid_provider = next(
|
|
626
|
+
m
|
|
627
|
+
for m in embedding_models
|
|
628
|
+
if m.name == valid_embedding_model_invalid_provider["name"]
|
|
629
|
+
)
|
|
630
|
+
# Should have no valid providers since the original only had one and it was invalid
|
|
631
|
+
assert len(embedding_model_with_invalid_provider.providers) == 0
|
|
632
|
+
|
|
633
|
+
# Check embedding model with mixed providers has only the valid one
|
|
634
|
+
embedding_model_with_mixed_providers = next(
|
|
635
|
+
m
|
|
636
|
+
for m in embedding_models
|
|
637
|
+
if m.name == valid_embedding_model_mixed_providers["name"]
|
|
638
|
+
)
|
|
639
|
+
assert len(embedding_model_with_mixed_providers.providers) == 1
|
|
640
|
+
assert (
|
|
641
|
+
embedding_model_with_mixed_providers.providers[0].name.value == "azure_openai"
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
# Check embedding model with all invalid providers has empty providers
|
|
645
|
+
embedding_model_with_all_invalid_providers = next(
|
|
646
|
+
m
|
|
647
|
+
for m in embedding_models
|
|
648
|
+
if m.name == valid_embedding_model_all_invalid_providers["name"]
|
|
649
|
+
)
|
|
650
|
+
assert len(embedding_model_with_all_invalid_providers.providers) == 0
|
|
651
|
+
|
|
652
|
+
# Check warning logs
|
|
653
|
+
warning_logs = [
|
|
654
|
+
record for record in caplog.records if record.levelno == logging.WARNING
|
|
655
|
+
]
|
|
656
|
+
|
|
657
|
+
# Should have warnings for:
|
|
658
|
+
# - 3 invalid embedding models (missing family, wrong type, malformed)
|
|
659
|
+
# - 1 invalid provider in case 4 (unknown-provider-123)
|
|
660
|
+
# - 3 invalid providers in case 5 (invalid-provider-1, missing name, wrong type boolean)
|
|
661
|
+
# - 1 invalid provider in case 6 (unknown-provider-456)
|
|
662
|
+
assert len(warning_logs) >= 8
|
|
663
|
+
|
|
664
|
+
# Check that warning messages contain expected content
|
|
665
|
+
model_warnings = [
|
|
666
|
+
log
|
|
667
|
+
for log in warning_logs
|
|
668
|
+
if "Failed to validate an embedding model from" in log.message
|
|
669
|
+
]
|
|
670
|
+
provider_warnings = [
|
|
671
|
+
log
|
|
672
|
+
for log in warning_logs
|
|
673
|
+
if "Failed to validate an embedding model provider" in log.message
|
|
674
|
+
]
|
|
675
|
+
|
|
676
|
+
assert len(model_warnings) == 3 # 3 completely invalid embedding models
|
|
677
|
+
assert (
|
|
678
|
+
len(provider_warnings) == 5
|
|
679
|
+
) # Exactly 5 invalid providers across different embedding models
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
def test_deserialize_config_empty_provider_list(
|
|
683
|
+
tmp_path, mock_model, mock_embedding_model
|
|
684
|
+
):
|
|
313
685
|
"""Test that models with empty provider lists are handled correctly."""
|
|
314
686
|
model_with_empty_providers = mock_model.model_dump(mode="json")
|
|
315
687
|
model_with_empty_providers["providers"] = []
|
|
688
|
+
embedding_model_with_empty_providers = mock_embedding_model.model_dump(mode="json")
|
|
689
|
+
embedding_model_with_empty_providers["providers"] = []
|
|
316
690
|
|
|
317
|
-
data = {
|
|
691
|
+
data = {
|
|
692
|
+
"model_list": [model_with_empty_providers],
|
|
693
|
+
"embedding_model_list": [embedding_model_with_empty_providers],
|
|
694
|
+
}
|
|
318
695
|
path = tmp_path / "empty_providers.json"
|
|
319
696
|
path.write_text(json.dumps(data))
|
|
320
697
|
|
|
321
|
-
|
|
698
|
+
remote_config = deserialize_config_at_path(path)
|
|
699
|
+
models = remote_config.model_list
|
|
322
700
|
assert len(models) == 1
|
|
323
701
|
assert len(models[0].providers) == 0
|
|
324
702
|
|
|
703
|
+
embedding_models = remote_config.embedding_model_list
|
|
704
|
+
assert len(embedding_models) == 1
|
|
705
|
+
assert len(embedding_models[0].providers) == 0
|
|
325
706
|
|
|
326
|
-
|
|
707
|
+
|
|
708
|
+
def test_deserialize_config_missing_provider_field(
|
|
709
|
+
tmp_path, caplog, mock_model, mock_embedding_model
|
|
710
|
+
):
|
|
327
711
|
"""Test that models missing the providers field are handled correctly."""
|
|
328
712
|
model_without_providers = mock_model.model_dump(mode="json")
|
|
329
713
|
del model_without_providers["providers"]
|
|
330
714
|
|
|
331
|
-
|
|
715
|
+
embedding_model_without_providers = mock_embedding_model.model_dump(mode="json")
|
|
716
|
+
del embedding_model_without_providers["providers"]
|
|
717
|
+
|
|
718
|
+
data = {
|
|
719
|
+
"model_list": [model_without_providers],
|
|
720
|
+
"embedding_model_list": [embedding_model_without_providers],
|
|
721
|
+
}
|
|
332
722
|
path = tmp_path / "no_providers.json"
|
|
333
723
|
path.write_text(json.dumps(data))
|
|
334
724
|
|
|
335
725
|
with caplog.at_level(logging.WARNING):
|
|
336
|
-
|
|
726
|
+
remote_config = deserialize_config_at_path(path)
|
|
727
|
+
|
|
728
|
+
models = remote_config.model_list
|
|
729
|
+
|
|
730
|
+
embedding_models = remote_config.embedding_model_list
|
|
337
731
|
|
|
338
732
|
# Model should be kept with empty providers (deserialize_config handles missing providers gracefully)
|
|
339
733
|
assert len(models) == 1
|
|
@@ -346,8 +740,20 @@ def test_deserialize_config_missing_provider_field(tmp_path, caplog, mock_model)
|
|
|
346
740
|
]
|
|
347
741
|
assert len(warning_logs) == 0
|
|
348
742
|
|
|
743
|
+
assert len(embedding_models) == 1
|
|
744
|
+
assert len(embedding_models[0].providers) == 0
|
|
745
|
+
assert embedding_models[0].name == mock_embedding_model.name
|
|
349
746
|
|
|
350
|
-
|
|
747
|
+
# Should not have any warnings since the function handles missing providers gracefully
|
|
748
|
+
warning_logs = [
|
|
749
|
+
record for record in caplog.records if record.levelno == logging.WARNING
|
|
750
|
+
]
|
|
751
|
+
assert len(warning_logs) == 0
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
def test_deserialize_config_provider_with_extra_fields(
|
|
755
|
+
tmp_path, mock_model, mock_embedding_model
|
|
756
|
+
):
|
|
351
757
|
"""Test that providers with extra unknown fields are handled gracefully."""
|
|
352
758
|
model_with_extra_provider_fields = mock_model.model_dump(mode="json")
|
|
353
759
|
model_with_extra_provider_fields["providers"][0]["unknown_field"] = (
|
|
@@ -357,38 +763,80 @@ def test_deserialize_config_provider_with_extra_fields(tmp_path, mock_model):
|
|
|
357
763
|
"nested": "data"
|
|
358
764
|
}
|
|
359
765
|
|
|
360
|
-
|
|
766
|
+
embedding_model_with_extra_provider_fields = mock_embedding_model.model_dump(
|
|
767
|
+
mode="json"
|
|
768
|
+
)
|
|
769
|
+
embedding_model_with_extra_provider_fields["providers"][0]["unknown_field"] = (
|
|
770
|
+
"should_be_ignored"
|
|
771
|
+
)
|
|
772
|
+
embedding_model_with_extra_provider_fields["providers"][0]["another_extra"] = {
|
|
773
|
+
"nested": "data"
|
|
774
|
+
}
|
|
775
|
+
|
|
776
|
+
data = {
|
|
777
|
+
"model_list": [model_with_extra_provider_fields],
|
|
778
|
+
"embedding_model_list": [embedding_model_with_extra_provider_fields],
|
|
779
|
+
}
|
|
361
780
|
path = tmp_path / "extra_provider_fields.json"
|
|
362
781
|
path.write_text(json.dumps(data))
|
|
363
782
|
|
|
364
|
-
|
|
783
|
+
remote_config = deserialize_config_at_path(path)
|
|
784
|
+
models = remote_config.model_list
|
|
785
|
+
embedding_models = remote_config.embedding_model_list
|
|
786
|
+
|
|
365
787
|
assert len(models) == 1
|
|
366
788
|
assert len(models[0].providers) == 3 # mock_model has 3 providers
|
|
367
789
|
# Extra fields should be ignored, not present in the final object
|
|
368
790
|
assert not hasattr(models[0].providers[0], "unknown_field")
|
|
369
791
|
assert not hasattr(models[0].providers[0], "another_extra")
|
|
370
792
|
|
|
793
|
+
assert len(embedding_models) == 1
|
|
794
|
+
assert (
|
|
795
|
+
len(embedding_models[0].providers) == 1
|
|
796
|
+
) # mock_embedding_model has 1 provider
|
|
797
|
+
# Extra fields should be ignored, not present in the final object
|
|
798
|
+
assert not hasattr(embedding_models[0].providers[0], "unknown_field")
|
|
799
|
+
assert not hasattr(embedding_models[0].providers[0], "another_extra")
|
|
800
|
+
|
|
371
801
|
|
|
372
|
-
def test_deserialize_config_model_with_extra_fields(
|
|
802
|
+
def test_deserialize_config_model_with_extra_fields(
|
|
803
|
+
tmp_path, mock_model, mock_embedding_model
|
|
804
|
+
):
|
|
373
805
|
"""Test that models with extra unknown fields are handled gracefully."""
|
|
374
806
|
model_with_extra_fields = mock_model.model_dump(mode="json")
|
|
375
807
|
model_with_extra_fields["future_field"] = "should_be_ignored"
|
|
376
808
|
model_with_extra_fields["complex_extra"] = {"nested": {"data": [1, 2, 3]}}
|
|
377
809
|
|
|
378
|
-
|
|
810
|
+
embedding_model_with_extra_fields = mock_embedding_model.model_dump(mode="json")
|
|
811
|
+
embedding_model_with_extra_fields["future_field"] = "should_be_ignored"
|
|
812
|
+
embedding_model_with_extra_fields["complex_extra"] = {"nested": {"data": [1, 2, 3]}}
|
|
813
|
+
|
|
814
|
+
data = {
|
|
815
|
+
"model_list": [model_with_extra_fields],
|
|
816
|
+
"embedding_model_list": [embedding_model_with_extra_fields],
|
|
817
|
+
}
|
|
379
818
|
path = tmp_path / "extra_model_fields.json"
|
|
380
819
|
path.write_text(json.dumps(data))
|
|
381
820
|
|
|
382
|
-
|
|
821
|
+
remote_config = deserialize_config_at_path(path)
|
|
822
|
+
models = remote_config.model_list
|
|
823
|
+
embedding_models = remote_config.embedding_model_list
|
|
824
|
+
|
|
383
825
|
assert len(models) == 1
|
|
384
826
|
assert models[0].name == mock_model.name
|
|
385
827
|
# Extra fields should be ignored, not present in the final object
|
|
386
828
|
assert not hasattr(models[0], "future_field")
|
|
387
829
|
assert not hasattr(models[0], "complex_extra")
|
|
388
830
|
|
|
831
|
+
assert len(embedding_models) == 1
|
|
832
|
+
assert embedding_models[0].name == mock_embedding_model.name
|
|
833
|
+
# Extra fields should be ignored, not present in the final object
|
|
834
|
+
assert not hasattr(embedding_models[0], "future_field")
|
|
835
|
+
assert not hasattr(embedding_models[0], "complex_extra")
|
|
836
|
+
|
|
389
837
|
|
|
390
838
|
def test_deserialize_config_mixed_valid_invalid_providers_single_model(
|
|
391
|
-
tmp_path, caplog, mock_model
|
|
839
|
+
tmp_path, caplog, mock_model, mock_embedding_model
|
|
392
840
|
):
|
|
393
841
|
"""Test a single model with a mix of valid and invalid providers in detail."""
|
|
394
842
|
model = mock_model.model_dump(mode="json")
|
|
@@ -417,12 +865,42 @@ def test_deserialize_config_mixed_valid_invalid_providers_single_model(
|
|
|
417
865
|
invalid_provider_wrong_type,
|
|
418
866
|
]
|
|
419
867
|
|
|
420
|
-
|
|
868
|
+
# Create embedding model with mixed valid/invalid providers
|
|
869
|
+
embedding_model = mock_embedding_model.model_dump(mode="json")
|
|
870
|
+
|
|
871
|
+
# Create a mix of embedding provider scenarios
|
|
872
|
+
valid_embedding_provider_1 = embedding_model["providers"][0].copy()
|
|
873
|
+
valid_embedding_provider_1["name"] = "openai"
|
|
874
|
+
|
|
875
|
+
valid_embedding_provider_2 = embedding_model["providers"][0].copy()
|
|
876
|
+
valid_embedding_provider_2["name"] = "azure_openai"
|
|
877
|
+
|
|
878
|
+
invalid_embedding_provider_unknown_name = embedding_model["providers"][0].copy()
|
|
879
|
+
invalid_embedding_provider_unknown_name["name"] = "nonexistent_embedding_provider"
|
|
880
|
+
|
|
881
|
+
invalid_embedding_provider_missing_name = embedding_model["providers"][0].copy()
|
|
882
|
+
del invalid_embedding_provider_missing_name["name"]
|
|
883
|
+
|
|
884
|
+
invalid_embedding_provider_wrong_type = embedding_model["providers"][0].copy()
|
|
885
|
+
invalid_embedding_provider_wrong_type["n_dimensions"] = "not_a_number"
|
|
886
|
+
|
|
887
|
+
embedding_model["providers"] = [
|
|
888
|
+
valid_embedding_provider_1,
|
|
889
|
+
invalid_embedding_provider_unknown_name,
|
|
890
|
+
valid_embedding_provider_2,
|
|
891
|
+
invalid_embedding_provider_missing_name,
|
|
892
|
+
invalid_embedding_provider_wrong_type,
|
|
893
|
+
]
|
|
894
|
+
|
|
895
|
+
data = {"model_list": [model], "embedding_model_list": [embedding_model]}
|
|
421
896
|
path = tmp_path / "mixed_providers_single.json"
|
|
422
897
|
path.write_text(json.dumps(data))
|
|
423
898
|
|
|
424
899
|
with caplog.at_level(logging.WARNING):
|
|
425
|
-
|
|
900
|
+
remote_config = deserialize_config_at_path(path)
|
|
901
|
+
|
|
902
|
+
models = remote_config.model_list
|
|
903
|
+
embedding_models = remote_config.embedding_model_list
|
|
426
904
|
|
|
427
905
|
# Should have 1 model with 2 valid providers
|
|
428
906
|
assert len(models) == 1
|
|
@@ -430,24 +908,41 @@ def test_deserialize_config_mixed_valid_invalid_providers_single_model(
|
|
|
430
908
|
assert models[0].providers[0].name.value == "openai"
|
|
431
909
|
assert models[0].providers[1].name.value == "azure_openai"
|
|
432
910
|
|
|
433
|
-
# Should have
|
|
434
|
-
|
|
911
|
+
# Should have 1 embedding model with 2 valid providers
|
|
912
|
+
assert len(embedding_models) == 1
|
|
913
|
+
assert len(embedding_models[0].providers) == 2
|
|
914
|
+
assert embedding_models[0].providers[0].name.value == "openai"
|
|
915
|
+
assert embedding_models[0].providers[1].name.value == "azure_openai"
|
|
916
|
+
|
|
917
|
+
# Should have logged 3 model provider validation warnings + 3 embedding model provider validation warnings = 6 total
|
|
918
|
+
model_provider_warnings = [
|
|
435
919
|
log
|
|
436
920
|
for log in caplog.records
|
|
437
921
|
if log.levelno == logging.WARNING
|
|
438
922
|
and "Failed to validate a model provider" in log.message
|
|
439
923
|
]
|
|
440
|
-
|
|
924
|
+
embedding_provider_warnings = [
|
|
925
|
+
log
|
|
926
|
+
for log in caplog.records
|
|
927
|
+
if log.levelno == logging.WARNING
|
|
928
|
+
and "Failed to validate an embedding model provider" in log.message
|
|
929
|
+
]
|
|
930
|
+
assert len(model_provider_warnings) == 3
|
|
931
|
+
assert len(embedding_provider_warnings) == 3
|
|
441
932
|
|
|
442
933
|
|
|
443
934
|
def test_deserialize_config_empty_json_structures(tmp_path):
|
|
444
935
|
"""Test various empty JSON structures."""
|
|
445
936
|
# Test empty model_list
|
|
446
|
-
data = {"model_list": []}
|
|
937
|
+
data = {"model_list": [], "embedding_model_list": []}
|
|
447
938
|
path = tmp_path / "empty_model_list.json"
|
|
448
939
|
path.write_text(json.dumps(data))
|
|
449
|
-
|
|
940
|
+
remote_config = deserialize_config_at_path(path)
|
|
941
|
+
models = remote_config.model_list
|
|
942
|
+
embedding_models = remote_config.embedding_model_list
|
|
943
|
+
|
|
450
944
|
assert len(models) == 0
|
|
945
|
+
assert len(embedding_models) == 0
|
|
451
946
|
|
|
452
947
|
# Test empty object with no model_list key
|
|
453
948
|
path = tmp_path / "empty_object.json"
|
|
@@ -480,7 +975,7 @@ def test_backwards_compatibility_with_v0_19(tmp_path):
|
|
|
480
975
|
|
|
481
976
|
# Create JSON with current version
|
|
482
977
|
current_json_path = tmp_path / "current_models.json"
|
|
483
|
-
serialize_config(built_in_models, current_json_path)
|
|
978
|
+
serialize_config(built_in_models, built_in_embedding_models, current_json_path)
|
|
484
979
|
|
|
485
980
|
# Test script using uv inline script metadata to install v0.19
|
|
486
981
|
test_script = f'''# /// script
|