kiln-ai 0.20.1__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic.

Files changed (117)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  38. kiln_ai/adapters/ml_model_list.py +382 -4
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +7 -69
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +1 -1
  41. kiln_ai/adapters/model_adapters/test_structured_output.py +3 -1
  42. kiln_ai/adapters/ollama_tools.py +69 -12
  43. kiln_ai/adapters/provider_tools.py +190 -46
  44. kiln_ai/adapters/rag/deduplication.py +49 -0
  45. kiln_ai/adapters/rag/progress.py +252 -0
  46. kiln_ai/adapters/rag/rag_runners.py +844 -0
  47. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  48. kiln_ai/adapters/rag/test_progress.py +785 -0
  49. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  50. kiln_ai/adapters/remote_config.py +80 -8
  51. kiln_ai/adapters/test_adapter_registry.py +579 -86
  52. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  53. kiln_ai/adapters/test_ml_model_list.py +212 -0
  54. kiln_ai/adapters/test_ollama_tools.py +340 -1
  55. kiln_ai/adapters/test_prompt_builders.py +1 -1
  56. kiln_ai/adapters/test_provider_tools.py +199 -8
  57. kiln_ai/adapters/test_remote_config.py +551 -56
  58. kiln_ai/adapters/vector_store/__init__.py +1 -0
  59. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  60. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  61. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  62. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  63. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  64. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  65. kiln_ai/datamodel/__init__.py +16 -13
  66. kiln_ai/datamodel/basemodel.py +170 -1
  67. kiln_ai/datamodel/chunk.py +158 -0
  68. kiln_ai/datamodel/datamodel_enums.py +27 -0
  69. kiln_ai/datamodel/embedding.py +64 -0
  70. kiln_ai/datamodel/extraction.py +303 -0
  71. kiln_ai/datamodel/project.py +33 -1
  72. kiln_ai/datamodel/rag.py +79 -0
  73. kiln_ai/datamodel/test_attachment.py +649 -0
  74. kiln_ai/datamodel/test_basemodel.py +1 -1
  75. kiln_ai/datamodel/test_chunk_models.py +317 -0
  76. kiln_ai/datamodel/test_dataset_split.py +1 -1
  77. kiln_ai/datamodel/test_embedding_models.py +448 -0
  78. kiln_ai/datamodel/test_eval_model.py +6 -6
  79. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  80. kiln_ai/datamodel/test_extraction_model.py +470 -0
  81. kiln_ai/datamodel/test_rag.py +641 -0
  82. kiln_ai/datamodel/test_tool_id.py +81 -0
  83. kiln_ai/datamodel/test_vector_store.py +320 -0
  84. kiln_ai/datamodel/tool_id.py +22 -0
  85. kiln_ai/datamodel/vector_store.py +141 -0
  86. kiln_ai/tools/mcp_session_manager.py +4 -1
  87. kiln_ai/tools/rag_tools.py +157 -0
  88. kiln_ai/tools/test_mcp_session_manager.py +1 -1
  89. kiln_ai/tools/test_rag_tools.py +848 -0
  90. kiln_ai/tools/test_tool_registry.py +91 -2
  91. kiln_ai/tools/tool_registry.py +21 -0
  92. kiln_ai/utils/__init__.py +3 -0
  93. kiln_ai/utils/async_job_runner.py +62 -17
  94. kiln_ai/utils/config.py +2 -2
  95. kiln_ai/utils/env.py +15 -0
  96. kiln_ai/utils/filesystem.py +14 -0
  97. kiln_ai/utils/filesystem_cache.py +60 -0
  98. kiln_ai/utils/litellm.py +94 -0
  99. kiln_ai/utils/lock.py +100 -0
  100. kiln_ai/utils/mime_type.py +38 -0
  101. kiln_ai/utils/pdf_utils.py +38 -0
  102. kiln_ai/utils/test_async_job_runner.py +151 -35
  103. kiln_ai/utils/test_env.py +142 -0
  104. kiln_ai/utils/test_filesystem_cache.py +316 -0
  105. kiln_ai/utils/test_litellm.py +206 -0
  106. kiln_ai/utils/test_lock.py +185 -0
  107. kiln_ai/utils/test_mime_type.py +66 -0
  108. kiln_ai/utils/test_pdf_utils.py +73 -0
  109. kiln_ai/utils/test_uuid.py +111 -0
  110. kiln_ai/utils/test_validation.py +524 -0
  111. kiln_ai/utils/uuid.py +9 -0
  112. kiln_ai/utils/validation.py +90 -0
  113. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +7 -1
  114. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  115. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  116. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  117. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/test_ml_embedding_model_list.py
@@ -0,0 +1,429 @@
+import pytest
+
+from kiln_ai.adapters.ml_embedding_model_list import (
+    EmbeddingModelName,
+    KilnEmbeddingModel,
+    KilnEmbeddingModelFamily,
+    KilnEmbeddingModelProvider,
+    built_in_embedding_models,
+    built_in_embedding_models_from_provider,
+    get_model_by_name,
+)
+from kiln_ai.datamodel.datamodel_enums import ModelProviderName
+
+
+class TestEmbeddingModelName:
+    """Test cases for EmbeddingModelName enum"""
+
+    def test_enum_values(self):
+        """Test that enum values are correctly defined"""
+        assert (
+            EmbeddingModelName.openai_text_embedding_3_small
+            == "openai_text_embedding_3_small"
+        )
+        assert (
+            EmbeddingModelName.openai_text_embedding_3_large
+            == "openai_text_embedding_3_large"
+        )
+        assert (
+            EmbeddingModelName.gemini_text_embedding_004 == "gemini_text_embedding_004"
+        )
+
+
+class TestKilnEmbeddingModelProvider:
+    """Test cases for KilnEmbeddingModelProvider model"""
+
+    def test_basic_provider_creation(self):
+        """Test creating a basic provider with required fields"""
+        provider = KilnEmbeddingModelProvider(
+            name=ModelProviderName.openai,
+            model_id="text-embedding-3-small",
+            max_input_tokens=8192,
+            n_dimensions=1536,
+            supports_custom_dimensions=True,
+        )
+
+        assert provider.name == ModelProviderName.openai
+        assert provider.model_id == "text-embedding-3-small"
+        assert provider.max_input_tokens == 8192
+        assert provider.n_dimensions == 1536
+        assert provider.supports_custom_dimensions is True
+
+    def test_provider_with_optional_fields_unspecified(self):
+        """Test creating a provider with optional fields not specified"""
+        provider = KilnEmbeddingModelProvider(
+            name=ModelProviderName.gemini_api,
+            model_id="text-embedding-004",
+            n_dimensions=768,
+        )
+
+        assert provider.name == ModelProviderName.gemini_api
+        assert provider.model_id == "text-embedding-004"
+        assert provider.max_input_tokens is None
+        assert provider.n_dimensions == 768
+        assert provider.supports_custom_dimensions is False
+
+
+class TestKilnEmbeddingModel:
+    """Test cases for KilnEmbeddingModel model"""
+
+    def test_basic_model_creation(self):
+        """Test creating a basic model with required fields"""
+        providers = [
+            KilnEmbeddingModelProvider(
+                name=ModelProviderName.openai,
+                model_id="text-embedding-3-small",
+                n_dimensions=1536,
+                max_input_tokens=8192,
+            )
+        ]
+
+        model = KilnEmbeddingModel(
+            family=KilnEmbeddingModelFamily.openai,
+            name=EmbeddingModelName.openai_text_embedding_3_small,
+            friendly_name="Text Embedding 3 Small",
+            providers=providers,
+        )
+
+        assert model.family == KilnEmbeddingModelFamily.openai
+        assert model.name == EmbeddingModelName.openai_text_embedding_3_small
+        assert model.friendly_name == "Text Embedding 3 Small"
+        assert len(model.providers) == 1
+        assert model.providers[0].name == ModelProviderName.openai
+
+    def test_model_with_multiple_providers(self):
+        """Test creating a model with multiple providers"""
+        providers = [
+            KilnEmbeddingModelProvider(
+                name=ModelProviderName.openai,
+                model_id="model-1",
+                n_dimensions=1536,
+                max_input_tokens=8192,
+            ),
+            KilnEmbeddingModelProvider(
+                name=ModelProviderName.anthropic,
+                model_id="model-1",
+                n_dimensions=1536,
+                max_input_tokens=8192,
+            ),
+        ]
+
+        model = KilnEmbeddingModel(
+            family=KilnEmbeddingModelFamily.openai,
+            name=EmbeddingModelName.openai_text_embedding_3_small,
+            friendly_name="text-embedding-3-small",
+            providers=providers,
+        )
+
+        assert len(model.providers) == 2
+        assert model.providers[0].name == ModelProviderName.openai
+        assert model.providers[1].name == ModelProviderName.anthropic
+
+
+class TestEmbeddingModelsList:
+    """Test cases for the embedding_models list"""
+
+    def test_embedding_models_not_empty(self):
+        """Test that the embedding_models list is not empty"""
+        assert len(built_in_embedding_models) > 0
+
+    def test_all_models_have_required_fields(self):
+        """Test that all models in the list have required fields"""
+        for model in built_in_embedding_models:
+            assert hasattr(model, "family")
+            assert hasattr(model, "name")
+            assert hasattr(model, "friendly_name")
+            assert hasattr(model, "providers")
+            assert isinstance(model.name, str)
+            assert isinstance(model.friendly_name, str)
+            assert isinstance(model.providers, list)
+            assert len(model.providers) > 0
+
+    def test_all_providers_have_required_fields(self):
+        """Test that all providers in all models have required fields"""
+        for model in built_in_embedding_models:
+            for provider in model.providers:
+                assert hasattr(provider, "name")
+                assert isinstance(provider.name, ModelProviderName)
+
+    def test_model_names_are_unique(self):
+        """Test that all model names in the list are unique"""
+        model_names = [model.name for model in built_in_embedding_models]
+        assert len(model_names) == len(set(model_names))
+
+    def test_specific_models_exist(self):
+        """Test that specific expected models exist in the list"""
+        model_names = [model.name for model in built_in_embedding_models]
+
+        assert EmbeddingModelName.openai_text_embedding_3_small in model_names
+        assert EmbeddingModelName.openai_text_embedding_3_large in model_names
+        assert EmbeddingModelName.gemini_text_embedding_004 in model_names
+
+    def test_openai_embedding_models(self):
+        """Test specific OpenAI embedding models"""
+        openai_models = [
+            model
+            for model in built_in_embedding_models
+            if model.family == KilnEmbeddingModelFamily.openai
+        ]
+
+        assert len(openai_models) >= 2  # Should have at least 2 OpenAI models
+
+        # Check for specific OpenAI models
+        openai_model_names = [model.name for model in openai_models]
+        assert EmbeddingModelName.openai_text_embedding_3_small in openai_model_names
+        assert EmbeddingModelName.openai_text_embedding_3_large in openai_model_names
+
+    def test_gemini_embedding_models(self):
+        """Test specific Gemini embedding models"""
+        gemini_models = [
+            model
+            for model in built_in_embedding_models
+            if model.family == KilnEmbeddingModelFamily.gemini
+        ]
+
+        assert len(gemini_models) >= 1  # Should have at least 1 Gemini model
+
+        # Check for specific Gemini model
+        gemini_model_names = [model.name for model in gemini_models]
+        assert EmbeddingModelName.gemini_text_embedding_004 in gemini_model_names
+
+    def test_openai_text_embedding_3_small_details(self):
+        """Test specific details of OpenAI text-embedding-3-small model"""
+        model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_small)
+
+        assert model.family == KilnEmbeddingModelFamily.openai
+        assert model.friendly_name == "Text Embedding 3 Small"
+        assert len(model.providers) == 1
+
+        provider = model.providers[0]
+        assert provider.name == ModelProviderName.openai
+        assert provider.model_id == "text-embedding-3-small"
+        assert provider.n_dimensions == 1536
+        assert provider.max_input_tokens == 8192
+        assert provider.supports_custom_dimensions is True
+
+    def test_openai_text_embedding_3_large_details(self):
+        """Test specific details of OpenAI text-embedding-3-large model"""
+        model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_large)
+
+        assert model.family == KilnEmbeddingModelFamily.openai
+        assert model.friendly_name == "Text Embedding 3 Large"
+        assert len(model.providers) == 1
+
+        provider = model.providers[0]
+        assert provider.name == ModelProviderName.openai
+        assert provider.model_id == "text-embedding-3-large"
+        assert provider.n_dimensions == 3072
+        assert provider.max_input_tokens == 8192
+        assert provider.supports_custom_dimensions is True
+
+    def test_gemini_text_embedding_004_details(self):
+        """Test specific details of Gemini text-embedding-004 model"""
+        model = get_model_by_name(EmbeddingModelName.gemini_text_embedding_004)
+
+        assert model.family == KilnEmbeddingModelFamily.gemini
+        assert model.friendly_name == "Text Embedding 004"
+        assert len(model.providers) == 1
+
+        provider = model.providers[0]
+        assert provider.name == ModelProviderName.gemini_api
+        assert provider.model_id == "text-embedding-004"
+        assert provider.n_dimensions == 768
+        assert provider.max_input_tokens == 2048
+        assert provider.supports_custom_dimensions is False
+
+
+class TestGetModelByName:
+    """Test cases for get_model_by_name function"""
+
+    def test_get_existing_model(self):
+        """Test getting an existing model by name"""
+        model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_small)
+        assert model.name == EmbeddingModelName.openai_text_embedding_3_small
+        assert model.family == KilnEmbeddingModelFamily.openai
+
+    def test_get_all_existing_models(self):
+        """Test getting all existing models by name"""
+        for model_name in EmbeddingModelName:
+            model = get_model_by_name(model_name)
+            assert model.name == model_name
+
+    def test_get_nonexistent_model_raises_error(self):
+        """Test that getting a nonexistent model raises ValueError"""
+        with pytest.raises(
+            ValueError, match="Embedding model nonexistent_model not found"
+        ):
+            get_model_by_name("nonexistent_model")
+
+    def test_get_model_with_invalid_enum_value(self):
+        """Test that getting a model with invalid enum value raises ValueError"""
+        with pytest.raises(ValueError, match="Embedding model invalid_enum not found"):
+            get_model_by_name("invalid_enum")
+
+    @pytest.mark.parametrize(
+        "model_name,expected_family,expected_friendly_name",
+        [
+            (
+                EmbeddingModelName.openai_text_embedding_3_small,
+                KilnEmbeddingModelFamily.openai,
+                "Text Embedding 3 Small",
+            ),
+            (
+                EmbeddingModelName.openai_text_embedding_3_large,
+                KilnEmbeddingModelFamily.openai,
+                "Text Embedding 3 Large",
+            ),
+            (
+                EmbeddingModelName.gemini_text_embedding_004,
+                KilnEmbeddingModelFamily.gemini,
+                "Text Embedding 004",
+            ),
+        ],
+    )
+    def test_parametrized_model_retrieval(
+        self, model_name, expected_family, expected_friendly_name
+    ):
+        """Test retrieving models with parametrized test cases"""
+        model = get_model_by_name(model_name)
+        assert model.family == expected_family
+        assert model.friendly_name == expected_friendly_name
+
+
+class TestBuiltInEmbeddingModelsFromProvider:
+    """Test cases for built_in_embedding_models_from_provider function"""
+
+    def test_get_existing_provider_for_model(self):
+        """Test getting an existing provider for a model"""
+        provider = built_in_embedding_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name=EmbeddingModelName.openai_text_embedding_3_small,
+        )
+
+        assert provider is not None
+        assert provider.name == ModelProviderName.openai
+        assert provider.model_id == "text-embedding-3-small"
+        assert provider.n_dimensions == 1536
+
+    def test_get_all_existing_provider_model_combinations(self):
+        """Test getting all existing provider-model combinations"""
+        combinations = [
+            (
+                ModelProviderName.openai,
+                EmbeddingModelName.openai_text_embedding_3_small,
+            ),
+            (
+                ModelProviderName.openai,
+                EmbeddingModelName.openai_text_embedding_3_large,
+            ),
+            (
+                ModelProviderName.gemini_api,
+                EmbeddingModelName.gemini_text_embedding_004,
+            ),
+        ]
+
+        for provider_name, model_name in combinations:
+            provider = built_in_embedding_models_from_provider(
+                provider_name, model_name
+            )
+            assert provider is not None
+            assert provider.name == provider_name
+
+    def test_get_nonexistent_provider_returns_none(self):
+        """Test that getting a nonexistent provider returns None"""
+        provider = built_in_embedding_models_from_provider(
+            provider_name=ModelProviderName.anthropic,  # Not used for embeddings
+            model_name=EmbeddingModelName.openai_text_embedding_3_small,
+        )
+        assert provider is None
+
+    def test_get_nonexistent_model_returns_none(self):
+        """Test that getting a nonexistent model returns None"""
+        provider = built_in_embedding_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name="nonexistent_model",
+        )
+        assert provider is None
+
+    def test_get_wrong_provider_for_model_returns_none(self):
+        """Test that getting wrong provider for a model returns None"""
+        provider = built_in_embedding_models_from_provider(
+            provider_name=ModelProviderName.gemini_api,
+            model_name=EmbeddingModelName.openai_text_embedding_3_small,
+        )
+        assert provider is None
+
+    def test_get_openai_text_embedding_3_small_provider_details(self):
+        """Test specific details of OpenAI text-embedding-3-small provider"""
+        provider = built_in_embedding_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name=EmbeddingModelName.openai_text_embedding_3_small,
+        )
+
+        assert provider is not None
+        assert provider.name == ModelProviderName.openai
+        assert provider.model_id == "text-embedding-3-small"
+        assert provider.n_dimensions == 1536
+        assert provider.max_input_tokens == 8192
+        assert provider.supports_custom_dimensions is True
+
+    def test_get_openai_text_embedding_3_large_provider_details(self):
+        """Test specific details of OpenAI text-embedding-3-large provider"""
+        provider = built_in_embedding_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name=EmbeddingModelName.openai_text_embedding_3_large,
+        )

+        assert provider is not None
+        assert provider.name == ModelProviderName.openai
+        assert provider.model_id == "text-embedding-3-large"
+        assert provider.n_dimensions == 3072
+        assert provider.max_input_tokens == 8192
+        assert provider.supports_custom_dimensions is True
+
+    def test_get_gemini_text_embedding_004_provider_details(self):
+        """Test specific details of Gemini text-embedding-004 provider"""
+        provider = built_in_embedding_models_from_provider(
+            provider_name=ModelProviderName.gemini_api,
+            model_name=EmbeddingModelName.gemini_text_embedding_004,
+        )
+
+        assert provider is not None
+        assert provider.name == ModelProviderName.gemini_api
+        assert provider.model_id == "text-embedding-004"
+        assert provider.n_dimensions == 768
+        assert provider.max_input_tokens == 2048
+        assert provider.supports_custom_dimensions is False
+
+    @pytest.mark.parametrize(
+        "provider_name,model_name,expected_model_id,expected_dimensions",
+        [
+            (
+                ModelProviderName.openai,
+                EmbeddingModelName.openai_text_embedding_3_small,
+                "text-embedding-3-small",
+                1536,
+            ),
+            (
+                ModelProviderName.openai,
+                EmbeddingModelName.openai_text_embedding_3_large,
+                "text-embedding-3-large",
+                3072,
+            ),
+            (
+                ModelProviderName.gemini_api,
+                EmbeddingModelName.gemini_text_embedding_004,
+                "text-embedding-004",
+                768,
+            ),
+        ],
+    )
+    def test_parametrized_provider_retrieval(
+        self, provider_name, model_name, expected_model_id, expected_dimensions
+    ):
+        """Test retrieving providers with parametrized test cases"""
+        provider = built_in_embedding_models_from_provider(provider_name, model_name)
+
+        assert provider is not None
+        assert provider.model_id == expected_model_id
+        assert provider.n_dimensions == expected_dimensions
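
The new test file above exercises the embedding model registry end to end. As a quick orientation, here is a minimal usage sketch; it assumes only the names the diff itself introduces (get_model_by_name, built_in_embedding_models_from_provider, EmbeddingModelName, ModelProviderName) and the field values asserted in the tests, not any published documentation.

# Minimal sketch of the embedding-registry lookups exercised by the tests above.
# Imports, signatures, and field names mirror the diff; anything beyond what the
# tests assert is an assumption.
from kiln_ai.adapters.ml_embedding_model_list import (
    EmbeddingModelName,
    built_in_embedding_models_from_provider,
    get_model_by_name,
)
from kiln_ai.datamodel.datamodel_enums import ModelProviderName

# Look up a built-in embedding model by its enum name.
model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_small)
print(model.friendly_name)  # "Text Embedding 3 Small" per the assertions above

# Resolve the provider-specific configuration; returns None when the provider
# does not serve that model (e.g. Anthropic for this OpenAI embedding).
provider = built_in_embedding_models_from_provider(
    provider_name=ModelProviderName.openai,
    model_name=EmbeddingModelName.openai_text_embedding_3_small,
)
if provider is not None:
    print(provider.model_id, provider.n_dimensions, provider.max_input_tokens)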
kiln_ai/adapters/test_ml_model_list.py
@@ -5,6 +5,7 @@ import pytest
 from kiln_ai.adapters.ml_model_list import (
     ModelName,
     built_in_models,
+    built_in_models_from_provider,
     default_structured_output_mode_for_model_provider,
     get_model_by_name,
 )
@@ -161,6 +162,186 @@ class TestDefaultStructuredOutputModeForModelProvider:
         assert result == first_provider.structured_output_mode


+class TestBuiltInModelsFromProvider:
+    """Test cases for built_in_models_from_provider function"""
+
+    def test_valid_model_and_provider_returns_provider(self):
+        """Test that valid model and provider returns the correct provider configuration"""
+        # GPT 4.1 has OpenAI provider
+        result = built_in_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name="gpt_4_1",
+        )
+        assert result is not None
+        assert result.name == ModelProviderName.openai
+        assert result.model_id == "gpt-4.1"
+        assert result.supports_logprobs is True
+        assert result.suggested_for_data_gen is True
+
+    def test_valid_model_different_provider_returns_correct_provider(self):
+        """Test that different providers for the same model return different configurations"""
+        # GPT 4.1 has multiple providers with different configurations
+        openai_provider = built_in_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name="gpt_4_1",
+        )
+        openrouter_provider = built_in_models_from_provider(
+            provider_name=ModelProviderName.openrouter,
+            model_name="gpt_4_1",
+        )
+
+        assert openai_provider is not None
+        assert openrouter_provider is not None
+        assert openai_provider.name == ModelProviderName.openai
+        assert openrouter_provider.name == ModelProviderName.openrouter
+        assert openai_provider.model_id == "gpt-4.1"
+        assert openrouter_provider.model_id == "openai/gpt-4.1"
+
+    def test_invalid_model_name_returns_none(self):
+        """Test that invalid model name returns None"""
+        result = built_in_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name="invalid_model_name",
+        )
+        assert result is None
+
+    def test_valid_model_invalid_provider_returns_none(self):
+        """Test that valid model but invalid provider returns None"""
+        result = built_in_models_from_provider(
+            provider_name=ModelProviderName.gemini_api,  # GPT 4.1 doesn't have gemini_api provider
+            model_name="gpt_4_1",
+        )
+        assert result is None
+
+    def test_model_with_single_provider(self):
+        """Test model that only has one provider"""
+        # Find a model with only one provider for this test
+        model = get_model_by_name(ModelName.gpt_4_1_nano)
+        assert len(model.providers) >= 1  # Verify it has providers
+
+        first_provider = model.providers[0]
+        result = built_in_models_from_provider(
+            provider_name=first_provider.name,
+            model_name="gpt_4_1_nano",
+        )
+        assert result is not None
+        assert result.name == first_provider.name
+        assert result.model_id == first_provider.model_id
+
+    def test_model_with_multiple_providers(self):
+        """Test model that has multiple providers"""
+        # GPT 4.1 has multiple providers
+        openai_result = built_in_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name="gpt_4_1",
+        )
+        openrouter_result = built_in_models_from_provider(
+            provider_name=ModelProviderName.openrouter,
+            model_name="gpt_4_1",
+        )
+        azure_result = built_in_models_from_provider(
+            provider_name=ModelProviderName.azure_openai,
+            model_name="gpt_4_1",
+        )
+
+        assert openai_result is not None
+        assert openrouter_result is not None
+        assert azure_result is not None
+        assert openai_result.name == ModelProviderName.openai
+        assert openrouter_result.name == ModelProviderName.openrouter
+        assert azure_result.name == ModelProviderName.azure_openai
+
+    def test_case_sensitive_model_name(self):
+        """Test that model name matching is case sensitive"""
+        # Should return None for case-mismatched model name
+        result = built_in_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name="GPT_4_1",  # Wrong case
+        )
+        assert result is None
+
+    def test_provider_specific_attributes(self):
+        """Test that provider-specific attributes are correctly returned"""
+        # Test a provider with specific attributes like multimodal_capable
+        result = built_in_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name="gpt_4_1",
+        )
+        assert result is not None
+        assert result.multimodal_capable is True
+        assert result.supports_doc_extraction is True
+        assert result.multimodal_mime_types is not None
+        assert len(result.multimodal_mime_types) > 0
+
+    def test_provider_without_special_attributes(self):
+        """Test provider that doesn't have special attributes"""
+        # Test a simpler provider configuration
+        result = built_in_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name="gpt_4_1_nano",
+        )
+        assert result is not None
+        assert result.multimodal_capable is False  # Should be False for nano
+        assert result.supports_doc_extraction is False
+
+    @pytest.mark.parametrize(
+        "model_name,provider,expected_model_id",
+        [
+            ("gpt_4o", ModelProviderName.openai, "gpt-4o"),
+            (
+                "claude_3_5_haiku",
+                ModelProviderName.anthropic,
+                "claude-3-5-haiku-20241022",
+            ),
+            ("gemini_2_5_pro", ModelProviderName.gemini_api, "gemini-2.5-pro"),
+            ("llama_3_1_8b", ModelProviderName.groq, "llama-3.1-8b-instant"),
+        ],
+    )
+    def test_parametrized_valid_combinations(
+        self, model_name, provider, expected_model_id
+    ):
+        """Test multiple valid model/provider combinations"""
+        result = built_in_models_from_provider(
+            provider_name=provider,
+            model_name=model_name,
+        )
+        assert result is not None
+        assert result.name == provider
+        assert result.model_id == expected_model_id
+
+    def test_empty_string_model_name(self):
+        """Test that empty string model name returns None"""
+        result = built_in_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name="",
+        )
+        assert result is None
+
+    def test_none_model_name(self):
+        """Test that None model name returns None"""
+        result = built_in_models_from_provider(
+            provider_name=ModelProviderName.openai,
+            model_name=None,  # type: ignore
+        )
+        assert result is None
+
+    def test_all_built_in_models_have_valid_providers(self):
+        """Test that all built-in models have at least one valid provider"""
+        for model in built_in_models:
+            assert len(model.providers) > 0, f"Model {model.name} has no providers"
+            for provider in model.providers:
+                # Test that we can retrieve each provider
+                result = built_in_models_from_provider(
+                    provider_name=provider.name,
+                    model_name=model.name,
+                )
+                assert result is not None, (
+                    f"Could not retrieve provider {provider.name} for model {model.name}"
+                )
+                assert result.name == provider.name
+                assert result.model_id == provider.model_id
+
+
 def test_uncensored():
     """Test that uncensored is set correctly"""
     model = get_model_by_name(ModelName.grok_3_mini)
@@ -179,6 +360,37 @@ def test_uncensored():
         assert provider.suggested_for_uncensored_data_gen


+def test_multimodal_capable():
+    """Test that multimodal_capable is set correctly"""
+    model = get_model_by_name(ModelName.gpt_4_1)
+    for provider in model.providers:
+        assert provider.multimodal_capable
+        assert provider.supports_doc_extraction
+        assert provider.multimodal_mime_types is not None
+        assert len(provider.multimodal_mime_types) > 0
+
+
+def test_no_empty_multimodal_mime_types():
+    """Ensure that multimodal fields are self-consistent as they are interdependent"""
+    for model in built_in_models:
+        for provider in model.providers:
+            # a multimodal model should always have mime types it supports
+            if provider.multimodal_capable:
+                assert provider.multimodal_mime_types is not None
+                assert len(provider.multimodal_mime_types) > 0
+
+            # a model that specifies mime types is necessarily multimodal
+            if (
+                provider.multimodal_mime_types is not None
+                and len(provider.multimodal_mime_types) > 0
+            ):
+                assert provider.multimodal_capable
+
+            # a model that supports doc extraction is necessarily multimodal
+            if provider.supports_doc_extraction:
+                assert provider.multimodal_capable
+
+
 def test_no_reasoning_for_structured_output():
     """Test that no reasoning is returned for structured output"""
     # get all models
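
The same lookup pattern applies to the chat-model registry changed in this release. A companion sketch, again assuming only the built_in_models_from_provider signature and the values asserted in the tests above:

# Minimal sketch of the new built_in_models_from_provider lookup. Only names and
# values asserted in the diff are used; everything else is an assumption.
from kiln_ai.adapters.ml_model_list import built_in_models_from_provider
from kiln_ai.datamodel.datamodel_enums import ModelProviderName

provider = built_in_models_from_provider(
    provider_name=ModelProviderName.openai,
    model_name="gpt_4_1",
)
if provider is None:
    # Unknown model names, or providers that do not serve the model, return None.
    raise SystemExit("model/provider combination not found")

print(provider.model_id)            # "gpt-4.1" per the tests above
print(provider.multimodal_capable)  # True for GPT-4.1 on OpenAI per the tests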