kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic; review the details in this report before upgrading.

Files changed (133)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +494 -0
  38. kiln_ai/adapters/ml_model_list.py +876 -18
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
  41. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  42. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  43. kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
  44. kiln_ai/adapters/ollama_tools.py +69 -12
  45. kiln_ai/adapters/provider_tools.py +190 -46
  46. kiln_ai/adapters/rag/deduplication.py +49 -0
  47. kiln_ai/adapters/rag/progress.py +252 -0
  48. kiln_ai/adapters/rag/rag_runners.py +844 -0
  49. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  50. kiln_ai/adapters/rag/test_progress.py +785 -0
  51. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  52. kiln_ai/adapters/remote_config.py +80 -8
  53. kiln_ai/adapters/test_adapter_registry.py +579 -86
  54. kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
  55. kiln_ai/adapters/test_ml_model_list.py +202 -0
  56. kiln_ai/adapters/test_ollama_tools.py +340 -1
  57. kiln_ai/adapters/test_prompt_builders.py +1 -1
  58. kiln_ai/adapters/test_provider_tools.py +199 -8
  59. kiln_ai/adapters/test_remote_config.py +551 -56
  60. kiln_ai/adapters/vector_store/__init__.py +1 -0
  61. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  62. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  63. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  64. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  65. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  66. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  67. kiln_ai/datamodel/__init__.py +16 -13
  68. kiln_ai/datamodel/basemodel.py +201 -4
  69. kiln_ai/datamodel/chunk.py +158 -0
  70. kiln_ai/datamodel/datamodel_enums.py +27 -0
  71. kiln_ai/datamodel/embedding.py +64 -0
  72. kiln_ai/datamodel/external_tool_server.py +206 -54
  73. kiln_ai/datamodel/extraction.py +317 -0
  74. kiln_ai/datamodel/project.py +33 -1
  75. kiln_ai/datamodel/rag.py +79 -0
  76. kiln_ai/datamodel/task.py +5 -0
  77. kiln_ai/datamodel/task_output.py +41 -11
  78. kiln_ai/datamodel/test_attachment.py +649 -0
  79. kiln_ai/datamodel/test_basemodel.py +270 -14
  80. kiln_ai/datamodel/test_chunk_models.py +317 -0
  81. kiln_ai/datamodel/test_dataset_split.py +1 -1
  82. kiln_ai/datamodel/test_datasource.py +50 -0
  83. kiln_ai/datamodel/test_embedding_models.py +448 -0
  84. kiln_ai/datamodel/test_eval_model.py +6 -6
  85. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  86. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  87. kiln_ai/datamodel/test_extraction_model.py +501 -0
  88. kiln_ai/datamodel/test_rag.py +641 -0
  89. kiln_ai/datamodel/test_task.py +35 -1
  90. kiln_ai/datamodel/test_tool_id.py +187 -1
  91. kiln_ai/datamodel/test_vector_store.py +320 -0
  92. kiln_ai/datamodel/tool_id.py +58 -0
  93. kiln_ai/datamodel/vector_store.py +141 -0
  94. kiln_ai/tools/base_tool.py +12 -3
  95. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  96. kiln_ai/tools/kiln_task_tool.py +158 -0
  97. kiln_ai/tools/mcp_server_tool.py +2 -2
  98. kiln_ai/tools/mcp_session_manager.py +51 -22
  99. kiln_ai/tools/rag_tools.py +164 -0
  100. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  101. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  102. kiln_ai/tools/test_mcp_session_manager.py +187 -227
  103. kiln_ai/tools/test_rag_tools.py +929 -0
  104. kiln_ai/tools/test_tool_registry.py +290 -7
  105. kiln_ai/tools/tool_registry.py +69 -16
  106. kiln_ai/utils/__init__.py +3 -0
  107. kiln_ai/utils/async_job_runner.py +62 -17
  108. kiln_ai/utils/config.py +2 -2
  109. kiln_ai/utils/env.py +15 -0
  110. kiln_ai/utils/filesystem.py +14 -0
  111. kiln_ai/utils/filesystem_cache.py +60 -0
  112. kiln_ai/utils/litellm.py +94 -0
  113. kiln_ai/utils/lock.py +100 -0
  114. kiln_ai/utils/mime_type.py +38 -0
  115. kiln_ai/utils/open_ai_types.py +19 -2
  116. kiln_ai/utils/pdf_utils.py +59 -0
  117. kiln_ai/utils/test_async_job_runner.py +151 -35
  118. kiln_ai/utils/test_env.py +142 -0
  119. kiln_ai/utils/test_filesystem_cache.py +316 -0
  120. kiln_ai/utils/test_litellm.py +206 -0
  121. kiln_ai/utils/test_lock.py +185 -0
  122. kiln_ai/utils/test_mime_type.py +66 -0
  123. kiln_ai/utils/test_open_ai_types.py +88 -12
  124. kiln_ai/utils/test_pdf_utils.py +86 -0
  125. kiln_ai/utils/test_uuid.py +111 -0
  126. kiln_ai/utils/test_validation.py +524 -0
  127. kiln_ai/utils/uuid.py +9 -0
  128. kiln_ai/utils/validation.py +90 -0
  129. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
  130. kiln_ai-0.22.0.dist-info/RECORD +213 -0
  131. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  132. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  133. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,239 @@
1
+ from typing import List
2
+
3
+ import pytest
4
+
5
+ from kiln_ai.adapters.embedding.embedding_registry import embedding_adapter_from_type
6
+ from kiln_ai.adapters.ml_embedding_model_list import (
7
+ EmbeddingModelName,
8
+ KilnEmbeddingModel,
9
+ KilnEmbeddingModelFamily,
10
+ KilnEmbeddingModelProvider,
11
+ built_in_embedding_models,
12
+ built_in_embedding_models_from_provider,
13
+ get_model_by_name,
14
+ )
15
+ from kiln_ai.datamodel.datamodel_enums import ModelProviderName
16
+ from kiln_ai.datamodel.embedding import EmbeddingConfig
17
+
18
+
19
+ @pytest.fixture
20
+ def litellm_adapter():
21
+ adapter = embedding_adapter_from_type(
22
+ EmbeddingConfig(
23
+ name="test-embedding",
24
+ model_provider_name=ModelProviderName.openai,
25
+ model_name=EmbeddingModelName.openai_text_embedding_3_small,
26
+ properties={},
27
+ )
28
+ )
29
+ return adapter
30
+
31
+
32
+ def get_all_embedding_models_and_providers() -> List[tuple[str, str]]:
33
+ return [
34
+ (model.name, provider.name)
35
+ for model in built_in_embedding_models
36
+ for provider in model.providers
37
+ ]
38
+
39
+
40
+ class TestKilnEmbeddingModelProvider:
41
+ """Test cases for KilnEmbeddingModelProvider model"""
42
+
43
+ def test_basic_provider_creation(self):
44
+ """Test creating a basic provider with required fields"""
45
+ provider = KilnEmbeddingModelProvider(
46
+ name=ModelProviderName.openai,
47
+ model_id="text-embedding-3-small",
48
+ max_input_tokens=8192,
49
+ n_dimensions=1536,
50
+ supports_custom_dimensions=True,
51
+ )
52
+
53
+ assert provider.name == ModelProviderName.openai
54
+ assert provider.model_id == "text-embedding-3-small"
55
+ assert provider.max_input_tokens == 8192
56
+ assert provider.n_dimensions == 1536
57
+ assert provider.supports_custom_dimensions is True
58
+
59
+ def test_provider_with_optional_fields_unspecified(self):
60
+ """Test creating a provider with optional fields not specified"""
61
+ provider = KilnEmbeddingModelProvider(
62
+ name=ModelProviderName.gemini_api,
63
+ model_id="text-embedding-004",
64
+ n_dimensions=768,
65
+ )
66
+
67
+ assert provider.name == ModelProviderName.gemini_api
68
+ assert provider.model_id == "text-embedding-004"
69
+ assert provider.max_input_tokens is None
70
+ assert provider.n_dimensions == 768
71
+ assert provider.supports_custom_dimensions is False
72
+
73
+
74
+ class TestKilnEmbeddingModel:
75
+ """Test cases for KilnEmbeddingModel model"""
76
+
77
+ def test_basic_model_creation(self):
78
+ """Test creating a basic model with required fields"""
79
+ providers = [
80
+ KilnEmbeddingModelProvider(
81
+ name=ModelProviderName.openai,
82
+ model_id="text-embedding-3-small",
83
+ n_dimensions=1536,
84
+ max_input_tokens=8192,
85
+ )
86
+ ]
87
+
88
+ model = KilnEmbeddingModel(
89
+ family=KilnEmbeddingModelFamily.openai,
90
+ name=EmbeddingModelName.openai_text_embedding_3_small,
91
+ friendly_name="Text Embedding 3 Small",
92
+ providers=providers,
93
+ )
94
+
95
+ assert model.family == KilnEmbeddingModelFamily.openai
96
+ assert model.name == EmbeddingModelName.openai_text_embedding_3_small
97
+ assert model.friendly_name == "Text Embedding 3 Small"
98
+ assert len(model.providers) == 1
99
+ assert model.providers[0].name == ModelProviderName.openai
100
+
101
+ def test_model_with_multiple_providers(self):
102
+ """Test creating a model with multiple providers"""
103
+ providers = [
104
+ KilnEmbeddingModelProvider(
105
+ name=ModelProviderName.openai,
106
+ model_id="model-1",
107
+ n_dimensions=1536,
108
+ max_input_tokens=8192,
109
+ ),
110
+ KilnEmbeddingModelProvider(
111
+ name=ModelProviderName.anthropic,
112
+ model_id="model-1",
113
+ n_dimensions=1536,
114
+ max_input_tokens=8192,
115
+ ),
116
+ ]
117
+
118
+ model = KilnEmbeddingModel(
119
+ family=KilnEmbeddingModelFamily.openai,
120
+ name=EmbeddingModelName.openai_text_embedding_3_small,
121
+ friendly_name="text-embedding-3-small",
122
+ providers=providers,
123
+ )
124
+
125
+ assert len(model.providers) == 2
126
+ assert model.providers[0].name == ModelProviderName.openai
127
+ assert model.providers[1].name == ModelProviderName.anthropic
128
+
129
+
130
+ class TestGetModelByName:
131
+ def test_get_nonexistent_model_raises_error(self):
132
+ """Test that getting a nonexistent model raises ValueError"""
133
+ with pytest.raises(
134
+ ValueError, match="Embedding model nonexistent_model not found"
135
+ ):
136
+ get_model_by_name("nonexistent_model") # type: ignore
137
+
138
+ @pytest.mark.parametrize(
139
+ "model_name",
140
+ [model.name for model in built_in_embedding_models],
141
+ )
142
+ def test_model_retrieval(self, model_name):
143
+ """Test retrieving models with parametrized test cases"""
144
+ model = get_model_by_name(model_name)
145
+ assert model.family == model.family
146
+ assert model.friendly_name == model.friendly_name
147
+
148
+
149
+ class TestBuiltInEmbeddingModelsFromProvider:
150
+ @pytest.mark.parametrize(
151
+ "model_name,provider_name", get_all_embedding_models_and_providers()
152
+ )
153
+ def test_get_all_existing_models_and_providers(self, model_name, provider_name):
154
+ provider = built_in_embedding_models_from_provider(provider_name, model_name)
155
+
156
+ assert provider is not None
157
+ assert provider.name == provider_name
158
+ assert provider.model_id == provider.model_id
159
+ assert provider.n_dimensions == provider.n_dimensions
160
+ assert provider.max_input_tokens == provider.max_input_tokens
161
+ assert (
162
+ provider.supports_custom_dimensions == provider.supports_custom_dimensions
163
+ )
164
+
165
+ def test_get_nonexistent_model_returns_none(self):
166
+ """Test that getting a nonexistent model returns None"""
167
+ provider = built_in_embedding_models_from_provider(
168
+ provider_name=ModelProviderName.openai,
169
+ model_name="nonexistent_model",
170
+ )
171
+ assert provider is None
172
+
173
+ def test_get_wrong_provider_for_model_returns_none(self):
174
+ """Test that getting wrong provider for a model returns None"""
175
+ provider = built_in_embedding_models_from_provider(
176
+ provider_name=ModelProviderName.gemini_api,
177
+ model_name=EmbeddingModelName.openai_text_embedding_3_small,
178
+ )
179
+ assert provider is None
180
+
181
+
182
+ class TestGenerateEmbedding:
183
+ """Test cases for generate_embedding function"""
184
+
185
+ @pytest.mark.parametrize(
186
+ "model_name,provider_name", get_all_embedding_models_and_providers()
187
+ )
188
+ @pytest.mark.paid
189
+ async def test_generate_embedding(self, model_name, provider_name):
190
+ """Test generating an embedding"""
191
+ model_provider = built_in_embedding_models_from_provider(
192
+ provider_name, model_name
193
+ )
194
+ assert model_provider is not None
195
+
196
+ embedding = embedding_adapter_from_type(
197
+ EmbeddingConfig(
198
+ name="test-embedding",
199
+ model_provider_name=provider_name,
200
+ model_name=model_name,
201
+ properties={},
202
+ )
203
+ )
204
+ embedding = await embedding.generate_embeddings(["Hello, world!"])
205
+ assert len(embedding.embeddings) == 1
206
+ assert len(embedding.embeddings[0].vector) == model_provider.n_dimensions
207
+
208
+ @pytest.mark.parametrize(
209
+ "model_name,provider_name", get_all_embedding_models_and_providers()
210
+ )
211
+ @pytest.mark.paid
212
+ async def test_generate_embedding_with_user_supplied_dimensions(
213
+ self, model_name, provider_name
214
+ ):
215
+ """Test generating an embedding with user supplied dimensions"""
216
+ model_provider = built_in_embedding_models_from_provider(
217
+ provider_name=provider_name,
218
+ model_name=model_name,
219
+ )
220
+ assert model_provider is not None
221
+
222
+ if not model_provider.supports_custom_dimensions:
223
+ pytest.skip("Model does not support custom dimensions")
224
+
225
+ # max dim
226
+ max_dimensions = model_provider.n_dimensions
227
+ dimensions_target = max_dimensions // 2
228
+
229
+ embedding = embedding_adapter_from_type(
230
+ EmbeddingConfig(
231
+ name="test-embedding",
232
+ model_provider_name=provider_name,
233
+ model_name=model_name,
234
+ properties={"dimensions": dimensions_target},
235
+ )
236
+ )
237
+ embedding = await embedding.generate_embeddings(["Hello, world!"])
238
+ assert len(embedding.embeddings) == 1
239
+ assert len(embedding.embeddings[0].vector) == dimensions_target
@@ -5,6 +5,7 @@ import pytest
5
5
  from kiln_ai.adapters.ml_model_list import (
6
6
  ModelName,
7
7
  built_in_models,
8
+ built_in_models_from_provider,
8
9
  default_structured_output_mode_for_model_provider,
9
10
  get_model_by_name,
10
11
  )
@@ -161,6 +162,186 @@ class TestDefaultStructuredOutputModeForModelProvider:
161
162
  assert result == first_provider.structured_output_mode
162
163
 
163
164
 
165
+ class TestBuiltInModelsFromProvider:
166
+ """Test cases for built_in_models_from_provider function"""
167
+
168
+ def test_valid_model_and_provider_returns_provider(self):
169
+ """Test that valid model and provider returns the correct provider configuration"""
170
+ # GPT 4.1 has OpenAI provider
171
+ result = built_in_models_from_provider(
172
+ provider_name=ModelProviderName.openai,
173
+ model_name="gpt_4_1",
174
+ )
175
+ assert result is not None
176
+ assert result.name == ModelProviderName.openai
177
+ assert result.model_id == "gpt-4.1"
178
+ assert result.supports_logprobs is True
179
+ assert result.suggested_for_data_gen is True
180
+
181
+ def test_valid_model_different_provider_returns_correct_provider(self):
182
+ """Test that different providers for the same model return different configurations"""
183
+ # GPT 4.1 has multiple providers with different configurations
184
+ openai_provider = built_in_models_from_provider(
185
+ provider_name=ModelProviderName.openai,
186
+ model_name="gpt_4_1",
187
+ )
188
+ openrouter_provider = built_in_models_from_provider(
189
+ provider_name=ModelProviderName.openrouter,
190
+ model_name="gpt_4_1",
191
+ )
192
+
193
+ assert openai_provider is not None
194
+ assert openrouter_provider is not None
195
+ assert openai_provider.name == ModelProviderName.openai
196
+ assert openrouter_provider.name == ModelProviderName.openrouter
197
+ assert openai_provider.model_id == "gpt-4.1"
198
+ assert openrouter_provider.model_id == "openai/gpt-4.1"
199
+
200
+ def test_invalid_model_name_returns_none(self):
201
+ """Test that invalid model name returns None"""
202
+ result = built_in_models_from_provider(
203
+ provider_name=ModelProviderName.openai,
204
+ model_name="invalid_model_name",
205
+ )
206
+ assert result is None
207
+
208
+ def test_valid_model_invalid_provider_returns_none(self):
209
+ """Test that valid model but invalid provider returns None"""
210
+ result = built_in_models_from_provider(
211
+ provider_name=ModelProviderName.gemini_api, # GPT 4.1 doesn't have gemini_api provider
212
+ model_name="gpt_4_1",
213
+ )
214
+ assert result is None
215
+
216
+ def test_model_with_single_provider(self):
217
+ """Test model that only has one provider"""
218
+ # Find a model with only one provider for this test
219
+ model = get_model_by_name(ModelName.gpt_4_1_nano)
220
+ assert len(model.providers) >= 1 # Verify it has providers
221
+
222
+ first_provider = model.providers[0]
223
+ result = built_in_models_from_provider(
224
+ provider_name=first_provider.name,
225
+ model_name="gpt_4_1_nano",
226
+ )
227
+ assert result is not None
228
+ assert result.name == first_provider.name
229
+ assert result.model_id == first_provider.model_id
230
+
231
+ def test_model_with_multiple_providers(self):
232
+ """Test model that has multiple providers"""
233
+ # GPT 4.1 has multiple providers
234
+ openai_result = built_in_models_from_provider(
235
+ provider_name=ModelProviderName.openai,
236
+ model_name="gpt_4_1",
237
+ )
238
+ openrouter_result = built_in_models_from_provider(
239
+ provider_name=ModelProviderName.openrouter,
240
+ model_name="gpt_4_1",
241
+ )
242
+ azure_result = built_in_models_from_provider(
243
+ provider_name=ModelProviderName.azure_openai,
244
+ model_name="gpt_4_1",
245
+ )
246
+
247
+ assert openai_result is not None
248
+ assert openrouter_result is not None
249
+ assert azure_result is not None
250
+ assert openai_result.name == ModelProviderName.openai
251
+ assert openrouter_result.name == ModelProviderName.openrouter
252
+ assert azure_result.name == ModelProviderName.azure_openai
253
+
254
+ def test_case_sensitive_model_name(self):
255
+ """Test that model name matching is case sensitive"""
256
+ # Should return None for case-mismatched model name
257
+ result = built_in_models_from_provider(
258
+ provider_name=ModelProviderName.openai,
259
+ model_name="GPT_4_1", # Wrong case
260
+ )
261
+ assert result is None
262
+
263
+ def test_provider_specific_attributes(self):
264
+ """Test that provider-specific attributes are correctly returned"""
265
+ # Test a provider with specific attributes like multimodal_capable
266
+ result = built_in_models_from_provider(
267
+ provider_name=ModelProviderName.openai,
268
+ model_name="gpt_4_1",
269
+ )
270
+ assert result is not None
271
+ assert result.multimodal_capable is True
272
+ assert result.supports_doc_extraction is True
273
+ assert result.multimodal_mime_types is not None
274
+ assert len(result.multimodal_mime_types) > 0
275
+
276
+ def test_provider_without_special_attributes(self):
277
+ """Test provider that doesn't have special attributes"""
278
+ # Test a simpler provider configuration
279
+ result = built_in_models_from_provider(
280
+ provider_name=ModelProviderName.openai,
281
+ model_name="gpt_4_1_nano",
282
+ )
283
+ assert result is not None
284
+ assert result.multimodal_capable is False # Should be False for nano
285
+ assert result.supports_doc_extraction is False
286
+
287
+ @pytest.mark.parametrize(
288
+ "model_name,provider,expected_model_id",
289
+ [
290
+ ("gpt_4o", ModelProviderName.openai, "gpt-4o"),
291
+ (
292
+ "claude_3_5_haiku",
293
+ ModelProviderName.anthropic,
294
+ "claude-3-5-haiku-20241022",
295
+ ),
296
+ ("gemini_2_5_pro", ModelProviderName.gemini_api, "gemini-2.5-pro"),
297
+ ("llama_3_1_8b", ModelProviderName.groq, "llama-3.1-8b-instant"),
298
+ ],
299
+ )
300
+ def test_parametrized_valid_combinations(
301
+ self, model_name, provider, expected_model_id
302
+ ):
303
+ """Test multiple valid model/provider combinations"""
304
+ result = built_in_models_from_provider(
305
+ provider_name=provider,
306
+ model_name=model_name,
307
+ )
308
+ assert result is not None
309
+ assert result.name == provider
310
+ assert result.model_id == expected_model_id
311
+
312
+ def test_empty_string_model_name(self):
313
+ """Test that empty string model name returns None"""
314
+ result = built_in_models_from_provider(
315
+ provider_name=ModelProviderName.openai,
316
+ model_name="",
317
+ )
318
+ assert result is None
319
+
320
+ def test_none_model_name(self):
321
+ """Test that None model name returns None"""
322
+ result = built_in_models_from_provider(
323
+ provider_name=ModelProviderName.openai,
324
+ model_name=None, # type: ignore
325
+ )
326
+ assert result is None
327
+
328
+ def test_all_built_in_models_have_valid_providers(self):
329
+ """Test that all built-in models have at least one valid provider"""
330
+ for model in built_in_models:
331
+ assert len(model.providers) > 0, f"Model {model.name} has no providers"
332
+ for provider in model.providers:
333
+ # Test that we can retrieve each provider
334
+ result = built_in_models_from_provider(
335
+ provider_name=provider.name,
336
+ model_name=model.name,
337
+ )
338
+ assert result is not None, (
339
+ f"Could not retrieve provider {provider.name} for model {model.name}"
340
+ )
341
+ assert result.name == provider.name
342
+ assert result.model_id == provider.model_id
343
+
344
+
164
345
  def test_uncensored():
165
346
  """Test that uncensored is set correctly"""
166
347
  model = get_model_by_name(ModelName.grok_3_mini)
@@ -179,6 +360,27 @@ def test_uncensored():
179
360
  assert provider.suggested_for_uncensored_data_gen
180
361
 
181
362
 
363
+ def test_no_empty_multimodal_mime_types():
364
+ """Ensure that multimodal fields are self-consistent as they are interdependent"""
365
+ for model in built_in_models:
366
+ for provider in model.providers:
367
+ # a multimodal model should always have mime types it supports
368
+ if provider.multimodal_capable:
369
+ assert provider.multimodal_mime_types is not None
370
+ assert len(provider.multimodal_mime_types) > 0
371
+
372
+ # a model that specifies mime types is necessarily multimodal
373
+ if (
374
+ provider.multimodal_mime_types is not None
375
+ and len(provider.multimodal_mime_types) > 0
376
+ ):
377
+ assert provider.multimodal_capable
378
+
379
+ # a model that supports doc extraction is necessarily multimodal
380
+ if provider.supports_doc_extraction:
381
+ assert provider.multimodal_capable
382
+
383
+
182
384
  def test_no_reasoning_for_structured_output():
183
385
  """Test that no reasoning is returned for structured output"""
184
386
  # get all models