kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai may be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from kiln_ai.adapters.embedding.embedding_registry import embedding_adapter_from_type
|
|
6
|
+
from kiln_ai.adapters.ml_embedding_model_list import (
|
|
7
|
+
EmbeddingModelName,
|
|
8
|
+
KilnEmbeddingModel,
|
|
9
|
+
KilnEmbeddingModelFamily,
|
|
10
|
+
KilnEmbeddingModelProvider,
|
|
11
|
+
built_in_embedding_models,
|
|
12
|
+
built_in_embedding_models_from_provider,
|
|
13
|
+
get_model_by_name,
|
|
14
|
+
)
|
|
15
|
+
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
|
|
16
|
+
from kiln_ai.datamodel.embedding import EmbeddingConfig
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@pytest.fixture
def litellm_adapter():
    """Embedding adapter fixture backed by OpenAI's text-embedding-3-small."""
    config = EmbeddingConfig(
        name="test-embedding",
        model_provider_name=ModelProviderName.openai,
        model_name=EmbeddingModelName.openai_text_embedding_3_small,
        properties={},
    )
    return embedding_adapter_from_type(config)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_all_embedding_models_and_providers() -> List[tuple[str, str]]:
    """Enumerate every (model name, provider name) pair in the built-in registry."""
    pairs: List[tuple[str, str]] = []
    for model in built_in_embedding_models:
        for provider in model.providers:
            pairs.append((model.name, provider.name))
    return pairs
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class TestKilnEmbeddingModelProvider:
    """Test cases for the KilnEmbeddingModelProvider model."""

    def test_basic_provider_creation(self):
        """A provider built with every field keeps each value as given."""
        provider = KilnEmbeddingModelProvider(
            name=ModelProviderName.openai,
            model_id="text-embedding-3-small",
            max_input_tokens=8192,
            n_dimensions=1536,
            supports_custom_dimensions=True,
        )

        assert provider.name == ModelProviderName.openai
        assert provider.model_id == "text-embedding-3-small"
        assert provider.max_input_tokens == 8192
        assert provider.n_dimensions == 1536
        assert provider.supports_custom_dimensions is True

    def test_provider_with_optional_fields_unspecified(self):
        """Omitted optional fields fall back to their defaults (None / False)."""
        provider = KilnEmbeddingModelProvider(
            name=ModelProviderName.gemini_api,
            model_id="text-embedding-004",
            n_dimensions=768,
        )

        assert provider.name == ModelProviderName.gemini_api
        assert provider.model_id == "text-embedding-004"
        assert provider.n_dimensions == 768
        # Defaults: no token limit recorded, custom dimensions unsupported.
        assert provider.max_input_tokens is None
        assert provider.supports_custom_dimensions is False
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class TestKilnEmbeddingModel:
    """Test cases for the KilnEmbeddingModel model."""

    def test_basic_model_creation(self):
        """A model with one provider exposes its fields and provider list."""
        single_provider = KilnEmbeddingModelProvider(
            name=ModelProviderName.openai,
            model_id="text-embedding-3-small",
            n_dimensions=1536,
            max_input_tokens=8192,
        )

        model = KilnEmbeddingModel(
            family=KilnEmbeddingModelFamily.openai,
            name=EmbeddingModelName.openai_text_embedding_3_small,
            friendly_name="Text Embedding 3 Small",
            providers=[single_provider],
        )

        assert model.family == KilnEmbeddingModelFamily.openai
        assert model.name == EmbeddingModelName.openai_text_embedding_3_small
        assert model.friendly_name == "Text Embedding 3 Small"
        assert len(model.providers) == 1
        assert model.providers[0].name == ModelProviderName.openai

    def test_model_with_multiple_providers(self):
        """A model may list multiple providers; order is preserved."""
        openai_provider = KilnEmbeddingModelProvider(
            name=ModelProviderName.openai,
            model_id="model-1",
            n_dimensions=1536,
            max_input_tokens=8192,
        )
        anthropic_provider = KilnEmbeddingModelProvider(
            name=ModelProviderName.anthropic,
            model_id="model-1",
            n_dimensions=1536,
            max_input_tokens=8192,
        )

        model = KilnEmbeddingModel(
            family=KilnEmbeddingModelFamily.openai,
            name=EmbeddingModelName.openai_text_embedding_3_small,
            friendly_name="text-embedding-3-small",
            providers=[openai_provider, anthropic_provider],
        )

        assert len(model.providers) == 2
        assert model.providers[0].name == ModelProviderName.openai
        assert model.providers[1].name == ModelProviderName.anthropic
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class TestGetModelByName:
    """Test cases for get_model_by_name."""

    def test_get_nonexistent_model_raises_error(self):
        """Test that getting a nonexistent model raises ValueError."""
        with pytest.raises(
            ValueError, match="Embedding model nonexistent_model not found"
        ):
            get_model_by_name("nonexistent_model")  # type: ignore

    @pytest.mark.parametrize(
        "model_name",
        [model.name for model in built_in_embedding_models],
    )
    def test_model_retrieval(self, model_name):
        """Test retrieving models with parametrized test cases.

        Fixed: the original asserted `model.family == model.family` and
        `model.friendly_name == model.friendly_name`, which are tautologies
        and can never fail. Assert against the lookup key and that the
        descriptive fields are populated instead.
        """
        model = get_model_by_name(model_name)
        assert model.name == model_name
        assert model.family is not None
        assert model.friendly_name
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class TestBuiltInEmbeddingModelsFromProvider:
    """Test cases for built_in_embedding_models_from_provider."""

    @pytest.mark.parametrize(
        "model_name,provider_name", get_all_embedding_models_and_providers()
    )
    def test_get_all_existing_models_and_providers(self, model_name, provider_name):
        """Every built-in (model, provider) pair resolves to its registry entry.

        Fixed: the original compared `provider.model_id == provider.model_id`
        (and similar) — tautologies that can never fail. Compare the returned
        provider against the canonical entry from the built-in registry.
        """
        provider = built_in_embedding_models_from_provider(provider_name, model_name)

        assert provider is not None
        assert provider.name == provider_name

        model = get_model_by_name(model_name)
        expected = next(p for p in model.providers if p.name == provider_name)
        assert provider.model_id == expected.model_id
        assert provider.n_dimensions == expected.n_dimensions
        assert provider.max_input_tokens == expected.max_input_tokens
        assert (
            provider.supports_custom_dimensions == expected.supports_custom_dimensions
        )

    def test_get_nonexistent_model_returns_none(self):
        """Test that getting a nonexistent model returns None."""
        provider = built_in_embedding_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name="nonexistent_model",
        )
        assert provider is None

    def test_get_wrong_provider_for_model_returns_none(self):
        """Test that getting wrong provider for a model returns None."""
        provider = built_in_embedding_models_from_provider(
            provider_name=ModelProviderName.gemini_api,
            model_name=EmbeddingModelName.openai_text_embedding_3_small,
        )
        assert provider is None
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class TestGenerateEmbedding:
    """Test cases for generating embeddings against live providers (paid)."""

    @pytest.mark.parametrize(
        "model_name,provider_name", get_all_embedding_models_and_providers()
    )
    @pytest.mark.paid
    async def test_generate_embedding(self, model_name, provider_name):
        """A single input yields one vector with the provider's default size."""
        model_provider = built_in_embedding_models_from_provider(
            provider_name, model_name
        )
        assert model_provider is not None

        adapter = embedding_adapter_from_type(
            EmbeddingConfig(
                name="test-embedding",
                model_provider_name=provider_name,
                model_name=model_name,
                properties={},
            )
        )
        result = await adapter.generate_embeddings(["Hello, world!"])
        assert len(result.embeddings) == 1
        assert len(result.embeddings[0].vector) == model_provider.n_dimensions

    @pytest.mark.parametrize(
        "model_name,provider_name", get_all_embedding_models_and_providers()
    )
    @pytest.mark.paid
    async def test_generate_embedding_with_user_supplied_dimensions(
        self, model_name, provider_name
    ):
        """A custom `dimensions` property shrinks the returned vector."""
        model_provider = built_in_embedding_models_from_provider(
            provider_name=provider_name,
            model_name=model_name,
        )
        assert model_provider is not None

        if not model_provider.supports_custom_dimensions:
            pytest.skip("Model does not support custom dimensions")

        # Ask for half the provider's maximum dimensionality.
        max_dimensions = model_provider.n_dimensions
        dimensions_target = max_dimensions // 2

        adapter = embedding_adapter_from_type(
            EmbeddingConfig(
                name="test-embedding",
                model_provider_name=provider_name,
                model_name=model_name,
                properties={"dimensions": dimensions_target},
            )
        )
        result = await adapter.generate_embeddings(["Hello, world!"])
        assert len(result.embeddings) == 1
        assert len(result.embeddings[0].vector) == dimensions_target
|
|
@@ -5,6 +5,7 @@ import pytest
|
|
|
5
5
|
from kiln_ai.adapters.ml_model_list import (
|
|
6
6
|
ModelName,
|
|
7
7
|
built_in_models,
|
|
8
|
+
built_in_models_from_provider,
|
|
8
9
|
default_structured_output_mode_for_model_provider,
|
|
9
10
|
get_model_by_name,
|
|
10
11
|
)
|
|
@@ -161,6 +162,186 @@ class TestDefaultStructuredOutputModeForModelProvider:
|
|
|
161
162
|
assert result == first_provider.structured_output_mode
|
|
162
163
|
|
|
163
164
|
|
|
165
|
+
class TestBuiltInModelsFromProvider:
    """Test cases for the built_in_models_from_provider function."""

    def test_valid_model_and_provider_returns_provider(self):
        """A valid model/provider pair resolves to the matching configuration."""
        # GPT 4.1 is served directly by OpenAI.
        result = built_in_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name="gpt_4_1",
        )
        assert result is not None
        assert result.name == ModelProviderName.openai
        assert result.model_id == "gpt-4.1"
        assert result.supports_logprobs is True
        assert result.suggested_for_data_gen is True

    def test_valid_model_different_provider_returns_correct_provider(self):
        """The same model yields provider-specific configurations."""
        openai_provider = built_in_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name="gpt_4_1",
        )
        openrouter_provider = built_in_models_from_provider(
            provider_name=ModelProviderName.openrouter,
            model_name="gpt_4_1",
        )

        assert openai_provider is not None
        assert openrouter_provider is not None
        assert openai_provider.name == ModelProviderName.openai
        assert openrouter_provider.name == ModelProviderName.openrouter
        # Each provider uses its own model identifier scheme.
        assert openai_provider.model_id == "gpt-4.1"
        assert openrouter_provider.model_id == "openai/gpt-4.1"

    def test_invalid_model_name_returns_none(self):
        """An unknown model name returns None rather than raising."""
        result = built_in_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name="invalid_model_name",
        )
        assert result is None

    def test_valid_model_invalid_provider_returns_none(self):
        """A provider that does not serve the model returns None."""
        result = built_in_models_from_provider(
            provider_name=ModelProviderName.gemini_api,  # GPT 4.1 doesn't have gemini_api provider
            model_name="gpt_4_1",
        )
        assert result is None

    def test_model_with_single_provider(self):
        """The first registered provider of a model can be retrieved."""
        model = get_model_by_name(ModelName.gpt_4_1_nano)
        assert len(model.providers) >= 1  # Verify it has providers

        first_provider = model.providers[0]
        result = built_in_models_from_provider(
            provider_name=first_provider.name,
            model_name="gpt_4_1_nano",
        )
        assert result is not None
        assert result.name == first_provider.name
        assert result.model_id == first_provider.model_id

    def test_model_with_multiple_providers(self):
        """Every provider of a multi-provider model can be looked up."""
        # GPT 4.1 is available via OpenAI, OpenRouter, and Azure.
        openai_result = built_in_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name="gpt_4_1",
        )
        openrouter_result = built_in_models_from_provider(
            provider_name=ModelProviderName.openrouter,
            model_name="gpt_4_1",
        )
        azure_result = built_in_models_from_provider(
            provider_name=ModelProviderName.azure_openai,
            model_name="gpt_4_1",
        )

        assert openai_result is not None
        assert openrouter_result is not None
        assert azure_result is not None
        assert openai_result.name == ModelProviderName.openai
        assert openrouter_result.name == ModelProviderName.openrouter
        assert azure_result.name == ModelProviderName.azure_openai

    def test_case_sensitive_model_name(self):
        """Model name matching is case sensitive: wrong case returns None."""
        result = built_in_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name="GPT_4_1",  # Wrong case
        )
        assert result is None

    def test_provider_specific_attributes(self):
        """Provider-specific capability flags come through on the result."""
        result = built_in_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name="gpt_4_1",
        )
        assert result is not None
        assert result.multimodal_capable is True
        assert result.supports_doc_extraction is True
        assert result.multimodal_mime_types is not None
        assert len(result.multimodal_mime_types) > 0

    def test_provider_without_special_attributes(self):
        """A simpler provider configuration reports capabilities as False."""
        result = built_in_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name="gpt_4_1_nano",
        )
        assert result is not None
        assert result.multimodal_capable is False  # Should be False for nano
        assert result.supports_doc_extraction is False

    @pytest.mark.parametrize(
        "model_name,provider,expected_model_id",
        [
            ("gpt_4o", ModelProviderName.openai, "gpt-4o"),
            (
                "claude_3_5_haiku",
                ModelProviderName.anthropic,
                "claude-3-5-haiku-20241022",
            ),
            ("gemini_2_5_pro", ModelProviderName.gemini_api, "gemini-2.5-pro"),
            ("llama_3_1_8b", ModelProviderName.groq, "llama-3.1-8b-instant"),
        ],
    )
    def test_parametrized_valid_combinations(
        self, model_name, provider, expected_model_id
    ):
        """Spot-check several known-good model/provider/model-id triples."""
        result = built_in_models_from_provider(
            provider_name=provider,
            model_name=model_name,
        )
        assert result is not None
        assert result.name == provider
        assert result.model_id == expected_model_id

    def test_empty_string_model_name(self):
        """An empty model name returns None."""
        result = built_in_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name="",
        )
        assert result is None

    def test_none_model_name(self):
        """A None model name returns None."""
        result = built_in_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name=None,  # type: ignore
        )
        assert result is None

    def test_all_built_in_models_have_valid_providers(self):
        """Every built-in model's providers round-trip through the lookup."""
        for model in built_in_models:
            assert len(model.providers) > 0, f"Model {model.name} has no providers"
            for provider in model.providers:
                # Test that we can retrieve each provider
                result = built_in_models_from_provider(
                    provider_name=provider.name,
                    model_name=model.name,
                )
                assert result is not None, (
                    f"Could not retrieve provider {provider.name} for model {model.name}"
                )
                assert result.name == provider.name
                assert result.model_id == provider.model_id
|
|
343
|
+
|
|
344
|
+
|
|
164
345
|
def test_uncensored():
|
|
165
346
|
"""Test that uncensored is set correctly"""
|
|
166
347
|
model = get_model_by_name(ModelName.grok_3_mini)
|
|
@@ -179,6 +360,27 @@ def test_uncensored():
|
|
|
179
360
|
assert provider.suggested_for_uncensored_data_gen
|
|
180
361
|
|
|
181
362
|
|
|
363
|
+
def test_no_empty_multimodal_mime_types():
    """Ensure that multimodal fields are self-consistent as they are interdependent."""
    for model in built_in_models:
        for provider in model.providers:
            mime_types = provider.multimodal_mime_types

            # A multimodal model must declare at least one supported mime type.
            if provider.multimodal_capable:
                assert mime_types is not None
                assert len(mime_types) > 0

            # Conversely, declaring mime types implies multimodal capability.
            if mime_types:
                assert provider.multimodal_capable

            # Doc extraction support also implies multimodal capability.
            if provider.supports_doc_extraction:
                assert provider.multimodal_capable
|
|
382
|
+
|
|
383
|
+
|
|
182
384
|
def test_no_reasoning_for_structured_output():
|
|
183
385
|
"""Test that no reasoning is returned for structured output"""
|
|
184
386
|
# get all models
|