kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
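
The release adds a new embedding subsystem under kiln_ai/adapters/embedding/. As a rough orientation before the long diff below, here is a hedged sketch of how the new adapter appears to be used, assembled only from the constructors and calls exercised in the test file shown further down (EmbeddingConfig, LiteLlmCoreConfig, LitellmEmbeddingAdapter, generate_embeddings); it is not official usage documentation, and the API-key handling is an assumption.

```python
# Hypothetical usage sketch assembled from the calls exercised in the tests below;
# not official documentation for kiln-ai 0.22.0.
import asyncio

from kiln_ai.adapters.embedding.litellm_embedding_adapter import LitellmEmbeddingAdapter
from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
from kiln_ai.datamodel.embedding import EmbeddingConfig


async def main() -> None:
    # EmbeddingConfig / LiteLlmCoreConfig fields mirror the test fixtures.
    config = EmbeddingConfig(
        name="example-embedding",
        model_provider_name=ModelProviderName.openai,
        model_name="openai_text_embedding_3_small",
        properties={},  # e.g. {"dimensions": 512} for models that support it
    )
    adapter = LitellmEmbeddingAdapter(config, litellm_core_config=LiteLlmCoreConfig())

    # generate_embeddings() batches internally and returns embeddings plus usage.
    result = await adapter.generate_embeddings(["some text to embed"])
    print(len(result.embeddings[0].vector))


# Assumes an OpenAI API key is available (the paid tests below export OPENAI_API_KEY).
asyncio.run(main())
```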
kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py

@@ -0,0 +1,1149 @@

import os
from unittest.mock import AsyncMock, patch

import pytest
from litellm import Usage
from litellm.types.utils import EmbeddingResponse

from kiln_ai.adapters.embedding.base_embedding_adapter import Embedding
from kiln_ai.adapters.embedding.litellm_embedding_adapter import (
    MAX_BATCH_SIZE,
    EmbeddingOptions,
    LitellmEmbeddingAdapter,
    validate_map_to_embeddings,
)
from kiln_ai.adapters.ml_embedding_model_list import KilnEmbeddingModelProvider
from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
from kiln_ai.datamodel.embedding import EmbeddingConfig
from kiln_ai.utils.config import Config


@pytest.fixture
def mock_embedding_config():
    return EmbeddingConfig(
        name="test-embedding",
        model_provider_name=ModelProviderName.openai,
        model_name="openai_text_embedding_3_small",
        properties={},
    )


@pytest.fixture
def mock_litellm_core_config():
    return LiteLlmCoreConfig()


@pytest.fixture
def mock_litellm_adapter(mock_embedding_config, mock_litellm_core_config):
    return LitellmEmbeddingAdapter(
        mock_embedding_config, litellm_core_config=mock_litellm_core_config
    )


class TestEmbeddingOptions:
    """Test the EmbeddingOptions class."""

    def test_default_values(self):
        """Test that EmbeddingOptions has correct default values."""
        options = EmbeddingOptions()
        assert options.dimensions is None

    def test_with_dimensions(self):
        """Test EmbeddingOptions with dimensions set."""
        options = EmbeddingOptions(dimensions=1536)
        assert options.dimensions == 1536

    def test_model_dump_excludes_none(self):
        """Test that model_dump excludes None values."""
        options = EmbeddingOptions()
        dumped = options.model_dump(exclude_none=True)
        assert "dimensions" not in dumped

        options_with_dim = EmbeddingOptions(dimensions=1536)
        dumped_with_dim = options_with_dim.model_dump(exclude_none=True)
        assert "dimensions" in dumped_with_dim
        assert dumped_with_dim["dimensions"] == 1536


class TestLitellmEmbeddingAdapter:
    """Test the LitellmEmbeddingAdapter class."""

    def test_init_success(self, mock_embedding_config, mock_litellm_core_config):
        """Test successful initialization of the adapter."""
        adapter = LitellmEmbeddingAdapter(
            mock_embedding_config, mock_litellm_core_config
        )
        assert adapter.embedding_config == mock_embedding_config

    def test_build_options_no_dimensions(self, mock_litellm_adapter):
        """Test build_options when no dimensions are specified."""
        options = mock_litellm_adapter.build_options()
        assert options.dimensions is None

    def test_build_options_with_dimensions(
        self, mock_embedding_config, mock_litellm_core_config
    ):
        """Test build_options when dimensions are specified."""
        mock_embedding_config.properties = {"dimensions": 1536}
        adapter = LitellmEmbeddingAdapter(
            mock_embedding_config, litellm_core_config=mock_litellm_core_config
        )
        options = adapter.build_options()
        assert options.dimensions == 1536

    async def test_generate_embeddings_with_completion_kwargs(
        self, mock_embedding_config, mock_litellm_core_config
    ):
        """Test that completion_kwargs are properly passed to litellm.aembedding."""
        # Set up litellm_core_config with additional options
        mock_litellm_core_config.additional_body_options = {"custom_param": "value"}
        mock_litellm_core_config.base_url = "https://custom-api.example.com"
        mock_litellm_core_config.default_headers = {
            "Authorization": "Bearer custom-token"
        }

        adapter = LitellmEmbeddingAdapter(
            mock_embedding_config, litellm_core_config=mock_litellm_core_config
        )

        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
        ]
        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ) as mock_aembedding:
            await adapter._generate_embeddings(["test text"])

            # Verify litellm.aembedding was called with completion_kwargs
            call_args = mock_aembedding.call_args
            assert call_args[1]["custom_param"] == "value"
            assert call_args[1]["api_base"] == "https://custom-api.example.com"
            assert call_args[1]["default_headers"] == {
                "Authorization": "Bearer custom-token"
            }

    async def test_generate_embeddings_with_partial_completion_kwargs(
        self, mock_embedding_config, mock_litellm_core_config
    ):
        """Test that completion_kwargs work when only some options are set."""
        # Set only additional_body_options
        mock_litellm_core_config.additional_body_options = {"timeout": 30}
        mock_litellm_core_config.base_url = None
        mock_litellm_core_config.default_headers = None

        adapter = LitellmEmbeddingAdapter(
            mock_embedding_config, litellm_core_config=mock_litellm_core_config
        )

        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
        ]
        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ) as mock_aembedding:
            await adapter._generate_embeddings(["test text"])

            # Verify only the set options are passed
            call_args = mock_aembedding.call_args
            assert call_args[1]["timeout"] == 30
            assert "api_base" not in call_args[1]
            assert "default_headers" not in call_args[1]

    async def test_generate_embeddings_with_empty_completion_kwargs(
        self, mock_embedding_config, mock_litellm_core_config
    ):
        """Test that completion_kwargs work when all options are None/empty."""
        # Ensure all options are None/empty
        mock_litellm_core_config.additional_body_options = None
        mock_litellm_core_config.base_url = None
        mock_litellm_core_config.default_headers = None

        adapter = LitellmEmbeddingAdapter(
            mock_embedding_config, litellm_core_config=mock_litellm_core_config
        )

        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
        ]
        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ) as mock_aembedding:
            await adapter._generate_embeddings(["test text"])

            # Verify no completion_kwargs are passed
            call_args = mock_aembedding.call_args
            assert "api_base" not in call_args[1]
            assert "default_headers" not in call_args[1]
            # Should only have the basic parameters
            assert "model" in call_args[1]
            assert "input" in call_args[1]

    async def test_generate_embeddings_empty_list(self, mock_litellm_adapter):
        """Test embed method with empty text list."""
        result = await mock_litellm_adapter.generate_embeddings([])
        assert result.embeddings == []
        assert result.usage is None

    async def test_generate_embeddings_success(self, mock_litellm_adapter):
        """Test successful embedding generation."""
        # mock the response type
        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
            {"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]},
        ]
        mock_response.usage = Usage(prompt_tokens=10, total_tokens=10)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ):
            result = await mock_litellm_adapter._generate_embeddings(["text1", "text2"])

        assert len(result.embeddings) == 2
        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]
        assert result.embeddings[1].vector == [0.4, 0.5, 0.6]
        assert result.usage == mock_response.usage

    async def test_generate_embeddings_for_batch_success(self, mock_litellm_adapter):
        """Test successful embedding generation for a single batch."""
        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
            {"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]},
        ]
        mock_response.usage = Usage(prompt_tokens=10, total_tokens=10)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ):
            result = await mock_litellm_adapter._generate_embeddings_for_batch(
                ["text1", "text2"]
            )

        assert len(result.embeddings) == 2
        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]
        assert result.embeddings[1].vector == [0.4, 0.5, 0.6]
        assert result.usage == mock_response.usage

    async def test_generate_embeddings_for_batch_with_completion_kwargs(
        self, mock_embedding_config, mock_litellm_core_config
    ):
        """Test that completion_kwargs are properly passed to litellm.aembedding in batch method."""
        # Set up litellm_core_config with additional options
        mock_litellm_core_config.additional_body_options = {"custom_param": "value"}
        mock_litellm_core_config.base_url = "https://custom-api.example.com"
        mock_litellm_core_config.default_headers = {
            "Authorization": "Bearer custom-token"
        }

        adapter = LitellmEmbeddingAdapter(
            mock_embedding_config, litellm_core_config=mock_litellm_core_config
        )

        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
        ]
        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ) as mock_aembedding:
            await adapter._generate_embeddings_for_batch(["test text"])

            # Verify litellm.aembedding was called with completion_kwargs
            call_args = mock_aembedding.call_args
            assert call_args[1]["custom_param"] == "value"
            assert call_args[1]["api_base"] == "https://custom-api.example.com"
            assert call_args[1]["default_headers"] == {
                "Authorization": "Bearer custom-token"
            }

    async def test_generate_embeddings_with_dimensions(
        self, mock_embedding_config, mock_litellm_core_config
    ):
        """Test embedding with dimensions specified."""
        mock_embedding_config.properties = {"dimensions": 1536}
        adapter = LitellmEmbeddingAdapter(
            mock_embedding_config, litellm_core_config=mock_litellm_core_config
        )

        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1] * 1536}
        ]
        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ) as mock_aembedding:
            result = await adapter._generate_embeddings(["test text"])

            # Verify litellm.aembedding was called with correct parameters
            mock_aembedding.assert_called_once_with(
                model="openai/text-embedding-3-small",
                input=["test text"],
                dimensions=1536,
            )

        assert len(result.embeddings) == 1
        assert len(result.embeddings[0].vector) == 1536
        assert result.usage == mock_response.usage

    async def test_generate_embeddings_batch_size_exceeded(self, mock_litellm_adapter):
        """Test that embedding fails when batch size is exceeded in individual batch."""
        # This test now tests the _generate_embeddings_for_batch method directly
        # since the main _generate_embeddings method now handles batching automatically
        large_text_list = ["text"] * (MAX_BATCH_SIZE + 1)

        with pytest.raises(
            ValueError,
            match=f"Too many input texts, max batch size is {MAX_BATCH_SIZE}, got {MAX_BATCH_SIZE + 1}",
        ):
            await mock_litellm_adapter._generate_embeddings_for_batch(large_text_list)

    async def test_generate_embeddings_response_length_mismatch(
        self, mock_litellm_adapter
    ):
        """Test that embedding fails when response data length doesn't match input."""
        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
        ]  # Only one embedding

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ):
            with pytest.raises(
                RuntimeError,
                match=r"Expected the number of embeddings in the response to be 2, got 1.",
            ):
                await mock_litellm_adapter._generate_embeddings(["text1", "text2"])

    async def test_generate_embeddings_litellm_exception(self, mock_litellm_adapter):
        """Test that litellm exceptions are properly raised."""
        with patch(
            "litellm.aembedding",
            new_callable=AsyncMock,
            side_effect=Exception("litellm error"),
        ):
            with pytest.raises(Exception, match="litellm error"):
                await mock_litellm_adapter._generate_embeddings(["test text"])

    async def test_generate_embeddings_sorts_by_index(self, mock_litellm_adapter):
        """Test that embeddings are sorted by index."""
        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 2, "embedding": [0.3, 0.4, 0.5]},
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
            {"object": "embedding", "index": 1, "embedding": [0.2, 0.3, 0.4]},
        ]
        mock_response.usage = Usage(prompt_tokens=15, total_tokens=15)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ):
            result = await mock_litellm_adapter._generate_embeddings(
                ["text1", "text2", "text3"]
            )

        # Verify embeddings are sorted by index
        assert len(result.embeddings) == 3
        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]  # index 0
        assert result.embeddings[1].vector == [0.2, 0.3, 0.4]  # index 1
        assert result.embeddings[2].vector == [0.3, 0.4, 0.5]  # index 2

    async def test_generate_embeddings_single_text(self, mock_litellm_adapter):
        """Test embedding a single text."""
        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
        ]
        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ) as mock_aembedding:
            result = await mock_litellm_adapter._generate_embeddings(["single text"])

            # The call should not include dimensions since the fixture has empty properties
            mock_aembedding.assert_called_once_with(
                model="openai/text-embedding-3-small",
                input=["single text"],
            )

        assert len(result.embeddings) == 1
        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]
        assert result.usage == mock_response.usage

    async def test_generate_embeddings_max_batch_size(self, mock_litellm_adapter):
        """Test embedding with exactly the maximum batch size."""
        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": i, "embedding": [0.1, 0.2, 0.3]}
            for i in range(MAX_BATCH_SIZE)
        ]
        mock_response.usage = Usage(
            prompt_tokens=MAX_BATCH_SIZE * 5, total_tokens=MAX_BATCH_SIZE * 5
        )

        large_text_list = ["text"] * MAX_BATCH_SIZE

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ):
            result = await mock_litellm_adapter._generate_embeddings(large_text_list)

        assert len(result.embeddings) == MAX_BATCH_SIZE
        assert result.usage == mock_response.usage

    async def test_generate_embeddings_multiple_batches(self, mock_litellm_adapter):
        """Test that embedding properly handles multiple batches."""
        # Create a list that will require multiple batches
        total_texts = MAX_BATCH_SIZE * 2 + 50  # 2 full batches + 50 more
        text_list = [f"text_{i}" for i in range(total_texts)]

        # Mock responses for each batch
        batch1_response = AsyncMock(spec=EmbeddingResponse)
        batch1_response.data = [
            {"object": "embedding", "index": i, "embedding": [0.1, 0.2, 0.3]}
            for i in range(MAX_BATCH_SIZE)
        ]
        batch1_response.usage = Usage(prompt_tokens=100, total_tokens=100)

        batch2_response = AsyncMock(spec=EmbeddingResponse)
        batch2_response.data = [
            {"object": "embedding", "index": i, "embedding": [0.4, 0.5, 0.6]}
            for i in range(MAX_BATCH_SIZE)
        ]
        batch2_response.usage = Usage(prompt_tokens=100, total_tokens=100)

        batch3_response = AsyncMock(spec=EmbeddingResponse)
        batch3_response.data = [
            {"object": "embedding", "index": i, "embedding": [0.7, 0.8, 0.9]}
            for i in range(50)
        ]
        batch3_response.usage = Usage(prompt_tokens=50, total_tokens=50)

        # Mock litellm.aembedding to return different responses based on input size
        async def mock_aembedding(*args, **kwargs):
            input_size = len(kwargs.get("input", []))
            if input_size == MAX_BATCH_SIZE:
                if len(mock_aembedding.call_count) == 0:
                    mock_aembedding.call_count.append(1)
                    return batch1_response
                else:
                    mock_aembedding.call_count.append(1)
                    return batch2_response
            else:
                return batch3_response

        mock_aembedding.call_count = []

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, side_effect=mock_aembedding
        ):
            result = await mock_litellm_adapter._generate_embeddings(text_list)

        # Should have all embeddings combined
        assert len(result.embeddings) == total_texts

        # Should have combined usage from all batches
        assert result.usage is not None
        assert result.usage.prompt_tokens == 250  # 100 + 100 + 50
        assert result.usage.total_tokens == 250  # 100 + 100 + 50

        # Verify embeddings are in the right order
        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]  # First batch
        assert result.embeddings[MAX_BATCH_SIZE].vector == [
            0.4,
            0.5,
            0.6,
        ]  # Second batch
        assert result.embeddings[MAX_BATCH_SIZE * 2].vector == [
            0.7,
            0.8,
            0.9,
        ]  # Third batch

    async def test_generate_embeddings_batching_edge_cases(self, mock_litellm_adapter):
        """Test batching edge cases like empty lists and single items."""
        # Test empty list
        result = await mock_litellm_adapter._generate_embeddings([])
        assert result.embeddings == []
        assert result.usage is not None
        assert result.usage.prompt_tokens == 0
        assert result.usage.total_tokens == 0

        # Test single item (should still go through batching logic)
        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
        ]
        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)

        with patch("litellm.aembedding", return_value=mock_response):
            result = await mock_litellm_adapter._generate_embeddings(["single text"])

        assert len(result.embeddings) == 1
        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]
        assert result.usage == mock_response.usage

    async def test_generate_embeddings_batching_with_mixed_usage(
        self, mock_litellm_adapter
    ):
        """Test batching when some responses have usage and others don't."""
        # Create a list that will require multiple batches
        text_list = ["text"] * (MAX_BATCH_SIZE + 10)

        # First batch with usage
        batch1_response = AsyncMock(spec=EmbeddingResponse)
        batch1_response.data = [
            {"object": "embedding", "index": i, "embedding": [0.1, 0.2, 0.3]}
            for i in range(MAX_BATCH_SIZE)
        ]
        batch1_response.usage = Usage(prompt_tokens=100, total_tokens=100)

        # Second batch without usage
        batch2_response = AsyncMock(spec=EmbeddingResponse)
        batch2_response.data = [
            {"object": "embedding", "index": i, "embedding": [0.4, 0.5, 0.6]}
            for i in range(10)
        ]
        batch2_response.usage = None

        # Mock litellm.aembedding to return different responses based on input size
        async def mock_aembedding(*args, **kwargs):
            input_size = len(kwargs.get("input", []))
            if input_size == MAX_BATCH_SIZE:
                return batch1_response
            else:
                return batch2_response

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, side_effect=mock_aembedding
        ):
            result = await mock_litellm_adapter._generate_embeddings(text_list)

        # Should have all embeddings combined
        assert len(result.embeddings) == MAX_BATCH_SIZE + 10

        # Should have None usage since one batch has None usage
        assert result.usage is None

    async def test_generate_embeddings_batching_with_all_usage(
        self, mock_litellm_adapter
    ):
        """Test batching when all responses have usage information."""
        # Create a list that will require multiple batches
        text_list = ["text"] * (MAX_BATCH_SIZE + 10)

        # First batch with usage
        batch1_response = AsyncMock(spec=EmbeddingResponse)
        batch1_response.data = [
            {"object": "embedding", "index": i, "embedding": [0.1, 0.2, 0.3]}
            for i in range(MAX_BATCH_SIZE)
        ]
        batch1_response.usage = Usage(prompt_tokens=100, total_tokens=100)

        # Second batch with usage
        batch2_response = AsyncMock(spec=EmbeddingResponse)
        batch2_response.data = [
            {"object": "embedding", "index": i, "embedding": [0.4, 0.5, 0.6]}
            for i in range(10)
        ]
        batch2_response.usage = Usage(prompt_tokens=50, total_tokens=50)

        # Mock litellm.aembedding to return different responses based on input size
        async def mock_aembedding(*args, **kwargs):
            input_size = len(kwargs.get("input", []))
            if input_size == MAX_BATCH_SIZE:
                return batch1_response
            else:
                return batch2_response

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, side_effect=mock_aembedding
        ):
            result = await mock_litellm_adapter._generate_embeddings(text_list)

        # Should have all embeddings combined
        assert len(result.embeddings) == MAX_BATCH_SIZE + 10

        # Should have combined usage since all batches have usage
        assert result.usage is not None
        assert result.usage.prompt_tokens == 150  # 100 + 50
        assert result.usage.total_tokens == 150  # 100 + 50

    def test_embedding_config_inheritance(
        self, mock_embedding_config, mock_litellm_core_config
    ):
        """Test that the adapter properly inherits from BaseEmbeddingAdapter."""
        adapter = LitellmEmbeddingAdapter(
            mock_embedding_config, litellm_core_config=mock_litellm_core_config
        )
        assert adapter.embedding_config == mock_embedding_config

    async def test_generate_embeddings_method_integration(self, mock_litellm_adapter):
        """Test the public embed method integration."""
        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
        ]
        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ):
            result = await mock_litellm_adapter.generate_embeddings(["test text"])

        assert len(result.embeddings) == 1
        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]
        assert result.usage == mock_response.usage


class TestLitellmEmbeddingAdapterEdgeCases:
    """Test edge cases and error conditions."""

    async def test_generate_embeddings_with_none_usage(
        self, mock_embedding_config, mock_litellm_core_config
    ):
        """Test embedding when litellm returns None usage."""
        adapter = LitellmEmbeddingAdapter(
            mock_embedding_config, litellm_core_config=mock_litellm_core_config
        )
        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
        ]
        mock_response.usage = None

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ):
            result = await adapter._generate_embeddings(["test text"])

        assert len(result.embeddings) == 1
        # With the new logic, if any response has None usage, the result has None usage
        assert result.usage is None

    async def test_generate_embeddings_with_empty_embedding_vector(
        self, mock_embedding_config, mock_litellm_core_config
    ):
        """Test embedding with empty vector."""
        adapter = LitellmEmbeddingAdapter(
            mock_embedding_config, litellm_core_config=mock_litellm_core_config
        )
        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [{"object": "embedding", "index": 0, "embedding": []}]
        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ):
            result = await adapter._generate_embeddings(["test text"])

        assert len(result.embeddings) == 1
        assert result.embeddings[0].vector == []

    async def test_generate_embeddings_with_duplicate_indices(
        self, mock_embedding_config, mock_litellm_core_config
    ):
        """Test embedding with duplicate indices (should still work due to sorting)."""
        adapter = LitellmEmbeddingAdapter(
            mock_embedding_config, litellm_core_config=mock_litellm_core_config
        )
        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
            {
                "object": "embedding",
                "index": 0,
                "embedding": [0.4, 0.5, 0.6],
            },  # Duplicate index
        ]
        mock_response.usage = Usage(prompt_tokens=10, total_tokens=10)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ):
            result = await adapter._generate_embeddings(["text1", "text2"])

        # Both embeddings should be present and match the order in response.data
        assert len(result.embeddings) == 2
        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]
        assert result.embeddings[1].vector == [0.4, 0.5, 0.6]

    async def test_generate_embeddings_with_complex_properties(
        self, mock_embedding_config, mock_litellm_core_config
    ):
        """Test embedding with complex properties (only dimensions should be used)."""
        mock_embedding_config.properties = {
            "dimensions": 1536,
            "custom_property": "value",
            "numeric_property": 42,
            "boolean_property": True,
        }
        adapter = LitellmEmbeddingAdapter(
            mock_embedding_config, litellm_core_config=mock_litellm_core_config
        )

        mock_response = AsyncMock(spec=EmbeddingResponse)
        mock_response.data = [
            {"object": "embedding", "index": 0, "embedding": [0.1] * 1536}
        ]
        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)

        with patch(
            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
        ) as mock_aembedding:
            await adapter._generate_embeddings(["test text"])

            # Only dimensions should be passed to litellm
            call_args = mock_aembedding.call_args
            assert call_args[1]["dimensions"] == 1536
            # Other properties should not be passed
            assert "custom_property" not in call_args[1]
            assert "numeric_property" not in call_args[1]
            assert "boolean_property" not in call_args[1]


@pytest.mark.paid
@pytest.mark.parametrize(
    "provider,model_name,expected_dim",
    [
        (ModelProviderName.openai, "openai_text_embedding_3_small", 1536),
        (ModelProviderName.openai, "openai_text_embedding_3_large", 3072),
        (ModelProviderName.gemini_api, "gemini_text_embedding_004", 768),
    ],
)
@pytest.mark.asyncio
async def test_paid_generate_embeddings_basic(
    provider, model_name, expected_dim, mock_litellm_core_config
):
    openai_key = Config.shared().open_ai_api_key
    if not openai_key:
        pytest.skip("OPENAI_API_KEY not set")
    # Set the API key for litellm
    os.environ["OPENAI_API_KEY"] = openai_key

    # gemini key
    gemini_key = Config.shared().gemini_api_key
    if not gemini_key:
        pytest.skip("GEMINI_API_KEY not set")
    os.environ["GEMINI_API_KEY"] = gemini_key

    config = EmbeddingConfig(
        name="paid-embedding",
        model_provider_name=provider,
        model_name=model_name,
        properties={},
    )
    adapter = LitellmEmbeddingAdapter(
        config, litellm_core_config=mock_litellm_core_config
    )
    text = ["Kiln is an open-source evaluation platform for LLMs."]
    result = await adapter.generate_embeddings(text)
    assert len(result.embeddings) == 1
    assert isinstance(result.embeddings[0].vector, list)
    assert len(result.embeddings[0].vector) == expected_dim
    assert all(isinstance(x, float) for x in result.embeddings[0].vector)


# test model_provider
def test_model_provider(mock_litellm_core_config):
    mock_embedding_config = EmbeddingConfig(
        name="test",
        model_provider_name=ModelProviderName.openai,
        model_name="openai_text_embedding_3_small",
        properties={},
    )
    adapter = LitellmEmbeddingAdapter(
        mock_embedding_config, litellm_core_config=mock_litellm_core_config
    )
    assert adapter.model_provider.name == ModelProviderName.openai
    assert adapter.model_provider.model_id == "text-embedding-3-small"


def test_model_provider_gemini(mock_litellm_core_config):
    config = EmbeddingConfig(
        name="test",
        model_provider_name=ModelProviderName.gemini_api,
        model_name="gemini_text_embedding_004",
        properties={},
    )
    adapter = LitellmEmbeddingAdapter(
        config, litellm_core_config=mock_litellm_core_config
    )
    assert adapter.model_provider.name == ModelProviderName.gemini_api
    assert adapter.model_provider.model_id == "text-embedding-004"


@pytest.mark.parametrize(
    "provider,model_name,expected_model_id",
    [
        (
            ModelProviderName.gemini_api,
            "gemini_text_embedding_004",
            "gemini/text-embedding-004",
        ),
        (
            ModelProviderName.openai,
            "openai_text_embedding_3_small",
            "openai/text-embedding-3-small",
        ),
    ],
)
def test_litellm_model_id(
    provider, model_name, expected_model_id, mock_litellm_core_config
):
    config = EmbeddingConfig(
        name="test",
        model_provider_name=provider,
        model_name=model_name,
        properties={},
    )
    adapter = LitellmEmbeddingAdapter(
        config, litellm_core_config=mock_litellm_core_config
    )
    assert adapter.litellm_model_id == expected_model_id


def test_litellm_model_id_custom_provider_without_base_url(mock_litellm_core_config):
    """Test that custom providers without base_url raise an error."""
    config = EmbeddingConfig(
        name="test",
        model_provider_name=ModelProviderName.openai_compatible,
        model_name="some-model",
        properties={},
    )
    adapter = LitellmEmbeddingAdapter(
        config, litellm_core_config=mock_litellm_core_config
    )

    with pytest.raises(
        ValueError,
        match="Embedding model some-model not found in the list of built-in models",
    ):
        adapter.model_provider


def test_litellm_model_id_custom_provider_with_base_url(mock_litellm_core_config):
    """Test that custom providers with base_url work correctly."""
    # Set up a custom provider with base_url
    mock_litellm_core_config.base_url = "https://custom-api.example.com"

    config = EmbeddingConfig(
        name="test",
        model_provider_name=ModelProviderName.openai_compatible,
        model_name="some-model",
        properties={},
    )
    adapter = LitellmEmbeddingAdapter(
        config, litellm_core_config=mock_litellm_core_config
    )

    with pytest.raises(
        ValueError,
        match="Embedding model some-model not found in the list of built-in models",
    ):
        adapter.model_provider


def test_litellm_model_id_custom_provider_ollama_with_base_url():
    """Test that ollama provider with base_url works correctly."""

    # Create a mock provider that would be found in the built-in models
    # We need to mock the built_in_embedding_models_from_provider function
    with patch(
        "kiln_ai.adapters.embedding.litellm_embedding_adapter.built_in_embedding_models_from_provider"
    ) as mock_built_in:
        mock_built_in.return_value = KilnEmbeddingModelProvider(
            name=ModelProviderName.ollama,
            model_id="test-model",
            n_dimensions=768,
        )

        config = EmbeddingConfig(
            name="test",
            model_provider_name=ModelProviderName.ollama,
            model_name="test-model",
            properties={},
        )

        # With base_url - should work
        litellm_core_config_with_url = LiteLlmCoreConfig(
            base_url="http://localhost:11434"
        )
        adapter = LitellmEmbeddingAdapter(
            config, litellm_core_config=litellm_core_config_with_url
        )

        # Should not raise an error
        model_id = adapter.litellm_model_id
        assert model_id == "openai/test-model"


def test_litellm_model_id_custom_provider_ollama_without_base_url():
    """Test that ollama provider without base_url raises an error."""
    from kiln_ai.adapters.ml_embedding_model_list import KilnEmbeddingModelProvider

    # Create a mock provider that would be found in the built-in models
    with patch(
        "kiln_ai.adapters.embedding.litellm_embedding_adapter.built_in_embedding_models_from_provider"
    ) as mock_built_in:
        mock_built_in.return_value = KilnEmbeddingModelProvider(
            name=ModelProviderName.ollama,
            model_id="test-model",
            n_dimensions=768,
        )

        config = EmbeddingConfig(
            name="test",
            model_provider_name=ModelProviderName.ollama,
            model_name="test-model",
            properties={},
        )

        # Without base_url - should raise an error
        litellm_core_config_without_url = LiteLlmCoreConfig(base_url=None)
        adapter = LitellmEmbeddingAdapter(
            config, litellm_core_config=litellm_core_config_without_url
        )

        with pytest.raises(
            ValueError,
            match="Provider ollama must have an explicit base URL",
        ):
            adapter.litellm_model_id


def test_litellm_model_id_custom_provider_openai_compatible_with_base_url():
    """Test that openai_compatible provider with base_url works correctly."""
    from kiln_ai.adapters.ml_embedding_model_list import KilnEmbeddingModelProvider

    # Create a mock provider that would be found in the built-in models
    with patch(
        "kiln_ai.adapters.embedding.litellm_embedding_adapter.built_in_embedding_models_from_provider"
    ) as mock_built_in:
        mock_built_in.return_value = KilnEmbeddingModelProvider(
            name=ModelProviderName.openai_compatible,
            model_id="test-model",
            n_dimensions=768,
        )

        config = EmbeddingConfig(
            name="test",
            model_provider_name=ModelProviderName.openai_compatible,
            model_name="test-model",
            properties={},
        )

        # With base_url - should work
        litellm_core_config_with_url = LiteLlmCoreConfig(
            base_url="https://custom-api.example.com"
        )
        adapter = LitellmEmbeddingAdapter(
            config, litellm_core_config=litellm_core_config_with_url
        )

        # Should not raise an error
        model_id = adapter.litellm_model_id
        assert model_id == "openai/test-model"


def test_litellm_model_id_custom_provider_openai_compatible_without_base_url():
    """Test that openai_compatible provider without base_url raises an error."""

    # Create a mock provider that would be found in the built-in models
    with patch(
        "kiln_ai.adapters.embedding.litellm_embedding_adapter.built_in_embedding_models_from_provider"
    ) as mock_built_in:
        mock_built_in.return_value = KilnEmbeddingModelProvider(
            name=ModelProviderName.openai_compatible,
            model_id="test-model",
            n_dimensions=768,
        )

        config = EmbeddingConfig(
            name="test",
            model_provider_name=ModelProviderName.openai_compatible,
            model_name="test-model",
            properties={},
        )

        # Without base_url - should raise an error
        litellm_core_config_without_url = LiteLlmCoreConfig(base_url=None)
        adapter = LitellmEmbeddingAdapter(
            config, litellm_core_config=litellm_core_config_without_url
        )

        with pytest.raises(
            ValueError,
            match="Provider openai_compatible must have an explicit base URL",
        ):
            adapter.litellm_model_id


@pytest.mark.paid
@pytest.mark.parametrize(
    "provider,model_name,expected_dim",
    [
        (ModelProviderName.openai, "openai_text_embedding_3_small", 256),
        (ModelProviderName.openai, "openai_text_embedding_3_small", 512),
        (ModelProviderName.openai, "openai_text_embedding_3_large", 256),
        (ModelProviderName.openai, "openai_text_embedding_3_large", 512),
        (ModelProviderName.openai, "openai_text_embedding_3_large", 1024),
        (ModelProviderName.openai, "openai_text_embedding_3_large", 2048),
    ],
)
@pytest.mark.asyncio
async def test_paid_generate_embeddings_with_custom_dimensions_supported(
    provider, model_name, expected_dim, mock_litellm_core_config
):
    """
    Some models support custom dimensions - where the provider shortens the dimensions to match
    the desired custom number of dimensions. Ref: https://openai.com/index/new-embedding-models-and-api-updates/
    """
    api_key = Config.shared().open_ai_api_key or os.environ.get("OPENAI_API_KEY")
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key

    config = EmbeddingConfig(
        name="paid-embedding",
        model_provider_name=provider,
        model_name=model_name,
        properties={"dimensions": expected_dim},
    )
    adapter = LitellmEmbeddingAdapter(
        config, litellm_core_config=mock_litellm_core_config
    )
    text = ["Kiln is an open-source evaluation platform for LLMs."]
    result = await adapter.generate_embeddings(text)
    assert len(result.embeddings) == 1
    assert isinstance(result.embeddings[0].vector, list)
    assert len(result.embeddings[0].vector) == expected_dim
    assert all(isinstance(x, float) for x in result.embeddings[0].vector)


def test_validate_map_to_embeddings():
    mock_response = AsyncMock(spec=EmbeddingResponse)
    mock_response.data = [
        {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
        {"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]},
    ]
    expected_embeddings = [
        Embedding(vector=[0.1, 0.2, 0.3]),
        Embedding(vector=[0.4, 0.5, 0.6]),
    ]
    result = validate_map_to_embeddings(mock_response, 2)
    assert result == expected_embeddings


def test_validate_map_to_embeddings_invalid_length():
    mock_response = AsyncMock(spec=EmbeddingResponse)
    mock_response.data = [
        {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
    ]
    with pytest.raises(
        RuntimeError,
        match=r"Expected the number of embeddings in the response to be 2, got 1.",
    ):
        validate_map_to_embeddings(mock_response, 2)


def test_validate_map_to_embeddings_invalid_object_type():
    mock_response = AsyncMock(spec=EmbeddingResponse)
    mock_response.data = [
        {"object": "not_embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
    ]
    with pytest.raises(
        RuntimeError,
        match=r"Embedding response data has an unexpected shape. Property 'object' is not 'embedding'. Got not_embedding.",
    ):
        validate_map_to_embeddings(mock_response, 1)


def test_validate_map_to_embeddings_invalid_embedding_type():
    mock_response = AsyncMock(spec=EmbeddingResponse)
    mock_response.data = [
        {"object": "embedding", "index": 0, "embedding": "not_a_list"},
    ]
    mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
    with pytest.raises(
        RuntimeError,
        match=r"Embedding response data has an unexpected shape. Property 'embedding' is not a list. Got <class 'str'>.",
    ):
        validate_map_to_embeddings(mock_response, 1)

    # missing embedding
    mock_response = AsyncMock(spec=EmbeddingResponse)
    mock_response.data = [
        {"object": "embedding", "index": 0},
    ]
    with pytest.raises(
        RuntimeError,
        match=r"Embedding response data has an unexpected shape. Property 'embedding' is None in response data item.",
    ):
        validate_map_to_embeddings(mock_response, 1)


def test_validate_map_to_embeddings_invalid_index_type():
    mock_response = AsyncMock(spec=EmbeddingResponse)
    mock_response.data = [
        {"object": "embedding", "index": "not_an_int", "embedding": [0.1, 0.2, 0.3]},
    ]
    mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
    with pytest.raises(
        RuntimeError,
        match=r"Embedding response data has an unexpected shape. Property 'index' is not an integer. Got <class 'str'>.",
    ):
        validate_map_to_embeddings(mock_response, 1)

    # missing index
    mock_response = AsyncMock(spec=EmbeddingResponse)
    mock_response.data = [
        {"object": "embedding", "embedding": [0.1, 0.2, 0.3]},
    ]
    with pytest.raises(
        RuntimeError,
        match=r"Embedding response data has an unexpected shape. Property 'index' is None in response data item.",
    ):
        validate_map_to_embeddings(mock_response, 1)


def test_validate_map_to_embeddings_sorting():
    mock_response = AsyncMock(spec=EmbeddingResponse)
    mock_response.data = [
        {"object": "embedding", "index": 2, "embedding": [0.3, 0.4, 0.5]},
        {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
        {"object": "embedding", "index": 1, "embedding": [0.2, 0.3, 0.4]},
    ]
    expected_embeddings = [
        Embedding(vector=[0.1, 0.2, 0.3]),
        Embedding(vector=[0.2, 0.3, 0.4]),
        Embedding(vector=[0.3, 0.4, 0.5]),
    ]
    result = validate_map_to_embeddings(mock_response, 3)
    assert result == expected_embeddings


def test_generate_embeddings_response_not_embedding_response():
    response = AsyncMock()
    response.data = [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}]
    response.usage = Usage(prompt_tokens=5, total_tokens=5)
    with pytest.raises(
        RuntimeError,
        match=r"Expected EmbeddingResponse, got <class 'unittest.mock.AsyncMock'>.",
    ):
        validate_map_to_embeddings(response, 1)