kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai has been flagged as potentially problematic; see the package's registry page for more details.
- kiln_ai/adapters/__init__.py +8 -2
- kiln_ai/adapters/adapter_registry.py +43 -208
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +9 -3
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +2 -4
- kiln_ai/adapters/eval/test_g_eval.py +4 -5
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +761 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +205 -47
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +657 -85
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +251 -1
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +254 -8
- kiln_ai/adapters/test_remote_config.py +651 -58
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +39 -34
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +28 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +40 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +4 -4
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +320 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +105 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +246 -0
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +562 -0
- kiln_ai/tools/tool_registry.py +85 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +24 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +131 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.19.0.dist-info/RECORD +0 -115
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,429 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from kiln_ai.adapters.ml_embedding_model_list import (
|
|
4
|
+
EmbeddingModelName,
|
|
5
|
+
KilnEmbeddingModel,
|
|
6
|
+
KilnEmbeddingModelFamily,
|
|
7
|
+
KilnEmbeddingModelProvider,
|
|
8
|
+
built_in_embedding_models,
|
|
9
|
+
built_in_embedding_models_from_provider,
|
|
10
|
+
get_model_by_name,
|
|
11
|
+
)
|
|
12
|
+
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TestEmbeddingModelName:
    """Checks the string values carried by the EmbeddingModelName enum."""

    def test_enum_values(self):
        """Each member compares equal to its own snake_case identifier string."""
        expected_pairs = [
            (
                EmbeddingModelName.openai_text_embedding_3_small,
                "openai_text_embedding_3_small",
            ),
            (
                EmbeddingModelName.openai_text_embedding_3_large,
                "openai_text_embedding_3_large",
            ),
            (
                EmbeddingModelName.gemini_text_embedding_004,
                "gemini_text_embedding_004",
            ),
        ]
        for member, raw_value in expected_pairs:
            assert member == raw_value
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TestKilnEmbeddingModelProvider:
    """Construction tests for the KilnEmbeddingModelProvider model."""

    def test_basic_provider_creation(self):
        """Every explicitly supplied field is stored unchanged."""
        openai_provider = KilnEmbeddingModelProvider(
            name=ModelProviderName.openai,
            model_id="text-embedding-3-small",
            max_input_tokens=8192,
            n_dimensions=1536,
            supports_custom_dimensions=True,
        )

        assert openai_provider.name == ModelProviderName.openai
        assert openai_provider.model_id == "text-embedding-3-small"
        assert openai_provider.max_input_tokens == 8192
        assert openai_provider.n_dimensions == 1536
        assert openai_provider.supports_custom_dimensions is True

    def test_provider_with_optional_fields_unspecified(self):
        """Omitting the optional fields leaves them at their defaults."""
        gemini_provider = KilnEmbeddingModelProvider(
            name=ModelProviderName.gemini_api,
            model_id="text-embedding-004",
            n_dimensions=768,
        )

        # Values that were passed in round-trip unchanged.
        assert gemini_provider.name == ModelProviderName.gemini_api
        assert gemini_provider.model_id == "text-embedding-004"
        assert gemini_provider.n_dimensions == 768

        # Fields that were not passed take their defaults.
        assert gemini_provider.max_input_tokens is None
        assert gemini_provider.supports_custom_dimensions is False
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class TestKilnEmbeddingModel:
    """Test cases for KilnEmbeddingModel model"""

    def test_basic_model_creation(self):
        """Test creating a basic model with required fields"""
        providers = [
            KilnEmbeddingModelProvider(
                name=ModelProviderName.openai,
                model_id="text-embedding-3-small",
                n_dimensions=1536,
                max_input_tokens=8192,
            )
        ]

        model = KilnEmbeddingModel(
            family=KilnEmbeddingModelFamily.openai,
            name=EmbeddingModelName.openai_text_embedding_3_small,
            friendly_name="Text Embedding 3 Small",
            providers=providers,
        )

        assert model.family == KilnEmbeddingModelFamily.openai
        assert model.name == EmbeddingModelName.openai_text_embedding_3_small
        assert model.friendly_name == "Text Embedding 3 Small"
        assert len(model.providers) == 1
        assert model.providers[0].name == ModelProviderName.openai

    def test_model_with_multiple_providers(self):
        """Test that a model keeps all of its providers, in the order supplied.

        Both providers are ones that serve embedding models in this suite
        (openai and gemini_api). Anthropic is not an embedding provider. Each
        provider gets a distinct model_id, so the assertions can tell the
        entries apart.
        """
        providers = [
            KilnEmbeddingModelProvider(
                name=ModelProviderName.openai,
                model_id="model-1",
                n_dimensions=1536,
                max_input_tokens=8192,
            ),
            KilnEmbeddingModelProvider(
                name=ModelProviderName.gemini_api,
                model_id="model-2",
                n_dimensions=768,
                max_input_tokens=2048,
            ),
        ]

        model = KilnEmbeddingModel(
            family=KilnEmbeddingModelFamily.openai,
            name=EmbeddingModelName.openai_text_embedding_3_small,
            friendly_name="Text Embedding 3 Small",
            providers=providers,
        )

        assert len(model.providers) == 2
        assert model.providers[0].name == ModelProviderName.openai
        assert model.providers[0].model_id == "model-1"
        assert model.providers[1].name == ModelProviderName.gemini_api
        assert model.providers[1].model_id == "model-2"
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class TestEmbeddingModelsList:
    """Test cases for the built_in_embedding_models list"""

    def test_embedding_models_not_empty(self):
        """Test that the built_in_embedding_models list is not empty"""
        assert len(built_in_embedding_models) > 0

    def test_all_models_have_required_fields(self):
        """Test that all models in the list have required fields"""
        for model in built_in_embedding_models:
            assert hasattr(model, "family")
            assert hasattr(model, "name")
            assert hasattr(model, "friendly_name")
            assert hasattr(model, "providers")
            assert isinstance(model.name, str)
            assert isinstance(model.friendly_name, str)
            assert isinstance(model.providers, list)
            assert len(model.providers) > 0

    def test_all_providers_have_required_fields(self):
        """Test that every provider of every built-in model is well-formed.

        Checks that:
        - name is a ModelProviderName;
        - model_id is a non-empty string;
        - n_dimensions is a positive int;
        - max_input_tokens is None or positive.
        """
        for model in built_in_embedding_models:
            for provider in model.providers:
                # Identify the offending entry when an assertion fails.
                context = f"{model.name} / {provider.name}"
                assert hasattr(provider, "name")
                assert isinstance(provider.name, ModelProviderName), context
                assert isinstance(provider.model_id, str), context
                assert provider.model_id, context
                assert isinstance(provider.n_dimensions, int), context
                assert provider.n_dimensions > 0, context
                assert (
                    provider.max_input_tokens is None or provider.max_input_tokens > 0
                ), context

    def test_model_names_are_unique(self):
        """Test that all model names in the list are unique"""
        model_names = [model.name for model in built_in_embedding_models]
        assert len(model_names) == len(set(model_names))

    def test_specific_models_exist(self):
        """Test that specific expected models exist in the list"""
        model_names = [model.name for model in built_in_embedding_models]

        assert EmbeddingModelName.openai_text_embedding_3_small in model_names
        assert EmbeddingModelName.openai_text_embedding_3_large in model_names
        assert EmbeddingModelName.gemini_text_embedding_004 in model_names

    def test_openai_embedding_models(self):
        """Test specific OpenAI embedding models"""
        openai_models = [
            model
            for model in built_in_embedding_models
            if model.family == KilnEmbeddingModelFamily.openai
        ]

        assert len(openai_models) >= 2  # Should have at least 2 OpenAI models

        # Check for specific OpenAI models
        openai_model_names = [model.name for model in openai_models]
        assert EmbeddingModelName.openai_text_embedding_3_small in openai_model_names
        assert EmbeddingModelName.openai_text_embedding_3_large in openai_model_names

    def test_gemini_embedding_models(self):
        """Test specific Gemini embedding models"""
        gemini_models = [
            model
            for model in built_in_embedding_models
            if model.family == KilnEmbeddingModelFamily.gemini
        ]

        assert len(gemini_models) >= 1  # Should have at least 1 Gemini model

        # Check for specific Gemini model
        gemini_model_names = [model.name for model in gemini_models]
        assert EmbeddingModelName.gemini_text_embedding_004 in gemini_model_names

    def test_openai_text_embedding_3_small_details(self):
        """Test specific details of OpenAI text-embedding-3-small model"""
        model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_small)

        assert model.family == KilnEmbeddingModelFamily.openai
        assert model.friendly_name == "Text Embedding 3 Small"
        assert len(model.providers) == 1

        provider = model.providers[0]
        assert provider.name == ModelProviderName.openai
        assert provider.model_id == "text-embedding-3-small"
        assert provider.n_dimensions == 1536
        assert provider.max_input_tokens == 8192
        assert provider.supports_custom_dimensions is True

    def test_openai_text_embedding_3_large_details(self):
        """Test specific details of OpenAI text-embedding-3-large model"""
        model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_large)

        assert model.family == KilnEmbeddingModelFamily.openai
        assert model.friendly_name == "Text Embedding 3 Large"
        assert len(model.providers) == 1

        provider = model.providers[0]
        assert provider.name == ModelProviderName.openai
        assert provider.model_id == "text-embedding-3-large"
        assert provider.n_dimensions == 3072
        assert provider.max_input_tokens == 8192
        assert provider.supports_custom_dimensions is True

    def test_gemini_text_embedding_004_details(self):
        """Test specific details of Gemini text-embedding-004 model"""
        model = get_model_by_name(EmbeddingModelName.gemini_text_embedding_004)

        assert model.family == KilnEmbeddingModelFamily.gemini
        assert model.friendly_name == "Text Embedding 004"
        assert len(model.providers) == 1

        provider = model.providers[0]
        assert provider.name == ModelProviderName.gemini_api
        assert provider.model_id == "text-embedding-004"
        assert provider.n_dimensions == 768
        assert provider.max_input_tokens == 2048
        assert provider.supports_custom_dimensions is False
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class TestGetModelByName:
    """Test cases for get_model_by_name function"""

    def test_get_existing_model(self):
        """Test getting an existing model by name"""
        model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_small)
        assert model.name == EmbeddingModelName.openai_text_embedding_3_small
        assert model.family == KilnEmbeddingModelFamily.openai

    def test_get_all_existing_models(self):
        """Test getting all existing models by name"""
        for model_name in EmbeddingModelName:
            model = get_model_by_name(model_name)
            assert model.name == model_name

    def test_get_nonexistent_model_raises_error(self):
        """Test that getting a nonexistent model raises ValueError"""
        with pytest.raises(
            ValueError, match="Embedding model nonexistent_model not found"
        ):
            get_model_by_name("nonexistent_model")

    def test_get_model_with_invalid_enum_value(self):
        """Test that a provider model_id is rejected as a Kiln model name.

        "text-embedding-3-small" is the OpenAI provider's model_id, not an
        EmbeddingModelName value ("openai_text_embedding_3_small"). Mixing the
        two up is a realistic mistake, and it covers a different input than
        the arbitrary-string case above.
        """
        # The pattern contains no regex metacharacters, so it matches literally.
        with pytest.raises(
            ValueError, match="Embedding model text-embedding-3-small not found"
        ):
            get_model_by_name("text-embedding-3-small")

    @pytest.mark.parametrize(
        "model_name,expected_family,expected_friendly_name",
        [
            (
                EmbeddingModelName.openai_text_embedding_3_small,
                KilnEmbeddingModelFamily.openai,
                "Text Embedding 3 Small",
            ),
            (
                EmbeddingModelName.openai_text_embedding_3_large,
                KilnEmbeddingModelFamily.openai,
                "Text Embedding 3 Large",
            ),
            (
                EmbeddingModelName.gemini_text_embedding_004,
                KilnEmbeddingModelFamily.gemini,
                "Text Embedding 004",
            ),
        ],
    )
    def test_parametrized_model_retrieval(
        self, model_name, expected_family, expected_friendly_name
    ):
        """Test retrieving models with parametrized test cases"""
        model = get_model_by_name(model_name)
        assert model.family == expected_family
        assert model.friendly_name == expected_friendly_name
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class TestBuiltInEmbeddingModelsFromProvider:
    """Test cases for built_in_embedding_models_from_provider function"""

    def test_get_existing_provider_for_model(self):
        """Test getting an existing provider for a model"""
        provider = built_in_embedding_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name=EmbeddingModelName.openai_text_embedding_3_small,
        )

        assert provider is not None
        assert provider.name == ModelProviderName.openai
        assert provider.model_id == "text-embedding-3-small"
        assert provider.n_dimensions == 1536

    def test_get_all_existing_provider_model_combinations(self):
        """Test that every (provider, model) pair in the built-in list resolves.

        The pairs are derived from built_in_embedding_models, not hard-coded.
        Models or providers added to the list later are therefore covered
        automatically.
        """
        combinations = [
            (provider.name, model.name)
            for model in built_in_embedding_models
            for provider in model.providers
        ]
        # Guard against the comprehension silently yielding nothing.
        assert combinations

        for provider_name, model_name in combinations:
            provider = built_in_embedding_models_from_provider(
                provider_name, model_name
            )
            assert provider is not None, f"{provider_name} / {model_name} not found"
            assert provider.name == provider_name

    def test_get_nonexistent_provider_returns_none(self):
        """Test that getting a nonexistent provider returns None"""
        provider = built_in_embedding_models_from_provider(
            provider_name=ModelProviderName.anthropic,  # Not used for embeddings
            model_name=EmbeddingModelName.openai_text_embedding_3_small,
        )
        assert provider is None

    def test_get_nonexistent_model_returns_none(self):
        """Test that getting a nonexistent model returns None"""
        provider = built_in_embedding_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name="nonexistent_model",
        )
        assert provider is None

    def test_get_wrong_provider_for_model_returns_none(self):
        """Test that getting wrong provider for a model returns None"""
        provider = built_in_embedding_models_from_provider(
            provider_name=ModelProviderName.gemini_api,
            model_name=EmbeddingModelName.openai_text_embedding_3_small,
        )
        assert provider is None

    def test_get_openai_text_embedding_3_small_provider_details(self):
        """Test specific details of OpenAI text-embedding-3-small provider"""
        provider = built_in_embedding_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name=EmbeddingModelName.openai_text_embedding_3_small,
        )

        assert provider is not None
        assert provider.name == ModelProviderName.openai
        assert provider.model_id == "text-embedding-3-small"
        assert provider.n_dimensions == 1536
        assert provider.max_input_tokens == 8192
        assert provider.supports_custom_dimensions is True

    def test_get_openai_text_embedding_3_large_provider_details(self):
        """Test specific details of OpenAI text-embedding-3-large provider"""
        provider = built_in_embedding_models_from_provider(
            provider_name=ModelProviderName.openai,
            model_name=EmbeddingModelName.openai_text_embedding_3_large,
        )

        assert provider is not None
        assert provider.name == ModelProviderName.openai
        assert provider.model_id == "text-embedding-3-large"
        assert provider.n_dimensions == 3072
        assert provider.max_input_tokens == 8192
        assert provider.supports_custom_dimensions is True

    def test_get_gemini_text_embedding_004_provider_details(self):
        """Test specific details of Gemini text-embedding-004 provider"""
        provider = built_in_embedding_models_from_provider(
            provider_name=ModelProviderName.gemini_api,
            model_name=EmbeddingModelName.gemini_text_embedding_004,
        )

        assert provider is not None
        assert provider.name == ModelProviderName.gemini_api
        assert provider.model_id == "text-embedding-004"
        assert provider.n_dimensions == 768
        assert provider.max_input_tokens == 2048
        assert provider.supports_custom_dimensions is False

    @pytest.mark.parametrize(
        "provider_name,model_name,expected_model_id,expected_dimensions",
        [
            (
                ModelProviderName.openai,
                EmbeddingModelName.openai_text_embedding_3_small,
                "text-embedding-3-small",
                1536,
            ),
            (
                ModelProviderName.openai,
                EmbeddingModelName.openai_text_embedding_3_large,
                "text-embedding-3-large",
                3072,
            ),
            (
                ModelProviderName.gemini_api,
                EmbeddingModelName.gemini_text_embedding_004,
                "text-embedding-004",
                768,
            ),
        ],
    )
    def test_parametrized_provider_retrieval(
        self, provider_name, model_name, expected_model_id, expected_dimensions
    ):
        """Test retrieving providers with parametrized test cases"""
        provider = built_in_embedding_models_from_provider(provider_name, model_name)

        assert provider is not None
        assert provider.model_id == expected_model_id
        assert provider.n_dimensions == expected_dimensions
|