kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (133)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +494 -0
  38. kiln_ai/adapters/ml_model_list.py +876 -18
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
  41. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  42. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  43. kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
  44. kiln_ai/adapters/ollama_tools.py +69 -12
  45. kiln_ai/adapters/provider_tools.py +190 -46
  46. kiln_ai/adapters/rag/deduplication.py +49 -0
  47. kiln_ai/adapters/rag/progress.py +252 -0
  48. kiln_ai/adapters/rag/rag_runners.py +844 -0
  49. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  50. kiln_ai/adapters/rag/test_progress.py +785 -0
  51. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  52. kiln_ai/adapters/remote_config.py +80 -8
  53. kiln_ai/adapters/test_adapter_registry.py +579 -86
  54. kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
  55. kiln_ai/adapters/test_ml_model_list.py +202 -0
  56. kiln_ai/adapters/test_ollama_tools.py +340 -1
  57. kiln_ai/adapters/test_prompt_builders.py +1 -1
  58. kiln_ai/adapters/test_provider_tools.py +199 -8
  59. kiln_ai/adapters/test_remote_config.py +551 -56
  60. kiln_ai/adapters/vector_store/__init__.py +1 -0
  61. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  62. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  63. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  64. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  65. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  66. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  67. kiln_ai/datamodel/__init__.py +16 -13
  68. kiln_ai/datamodel/basemodel.py +201 -4
  69. kiln_ai/datamodel/chunk.py +158 -0
  70. kiln_ai/datamodel/datamodel_enums.py +27 -0
  71. kiln_ai/datamodel/embedding.py +64 -0
  72. kiln_ai/datamodel/external_tool_server.py +206 -54
  73. kiln_ai/datamodel/extraction.py +317 -0
  74. kiln_ai/datamodel/project.py +33 -1
  75. kiln_ai/datamodel/rag.py +79 -0
  76. kiln_ai/datamodel/task.py +5 -0
  77. kiln_ai/datamodel/task_output.py +41 -11
  78. kiln_ai/datamodel/test_attachment.py +649 -0
  79. kiln_ai/datamodel/test_basemodel.py +270 -14
  80. kiln_ai/datamodel/test_chunk_models.py +317 -0
  81. kiln_ai/datamodel/test_dataset_split.py +1 -1
  82. kiln_ai/datamodel/test_datasource.py +50 -0
  83. kiln_ai/datamodel/test_embedding_models.py +448 -0
  84. kiln_ai/datamodel/test_eval_model.py +6 -6
  85. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  86. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  87. kiln_ai/datamodel/test_extraction_model.py +501 -0
  88. kiln_ai/datamodel/test_rag.py +641 -0
  89. kiln_ai/datamodel/test_task.py +35 -1
  90. kiln_ai/datamodel/test_tool_id.py +187 -1
  91. kiln_ai/datamodel/test_vector_store.py +320 -0
  92. kiln_ai/datamodel/tool_id.py +58 -0
  93. kiln_ai/datamodel/vector_store.py +141 -0
  94. kiln_ai/tools/base_tool.py +12 -3
  95. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  96. kiln_ai/tools/kiln_task_tool.py +158 -0
  97. kiln_ai/tools/mcp_server_tool.py +2 -2
  98. kiln_ai/tools/mcp_session_manager.py +51 -22
  99. kiln_ai/tools/rag_tools.py +164 -0
  100. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  101. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  102. kiln_ai/tools/test_mcp_session_manager.py +187 -227
  103. kiln_ai/tools/test_rag_tools.py +929 -0
  104. kiln_ai/tools/test_tool_registry.py +290 -7
  105. kiln_ai/tools/tool_registry.py +69 -16
  106. kiln_ai/utils/__init__.py +3 -0
  107. kiln_ai/utils/async_job_runner.py +62 -17
  108. kiln_ai/utils/config.py +2 -2
  109. kiln_ai/utils/env.py +15 -0
  110. kiln_ai/utils/filesystem.py +14 -0
  111. kiln_ai/utils/filesystem_cache.py +60 -0
  112. kiln_ai/utils/litellm.py +94 -0
  113. kiln_ai/utils/lock.py +100 -0
  114. kiln_ai/utils/mime_type.py +38 -0
  115. kiln_ai/utils/open_ai_types.py +19 -2
  116. kiln_ai/utils/pdf_utils.py +59 -0
  117. kiln_ai/utils/test_async_job_runner.py +151 -35
  118. kiln_ai/utils/test_env.py +142 -0
  119. kiln_ai/utils/test_filesystem_cache.py +316 -0
  120. kiln_ai/utils/test_litellm.py +206 -0
  121. kiln_ai/utils/test_lock.py +185 -0
  122. kiln_ai/utils/test_mime_type.py +66 -0
  123. kiln_ai/utils/test_open_ai_types.py +88 -12
  124. kiln_ai/utils/test_pdf_utils.py +86 -0
  125. kiln_ai/utils/test_uuid.py +111 -0
  126. kiln_ai/utils/test_validation.py +524 -0
  127. kiln_ai/utils/uuid.py +9 -0
  128. kiln_ai/utils/validation.py +90 -0
  129. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
  130. kiln_ai-0.22.0.dist-info/RECORD +213 -0
  131. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  132. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  133. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py
@@ -0,0 +1,1149 @@
+import os
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from litellm import Usage
+from litellm.types.utils import EmbeddingResponse
+
+from kiln_ai.adapters.embedding.base_embedding_adapter import Embedding
+from kiln_ai.adapters.embedding.litellm_embedding_adapter import (
+    MAX_BATCH_SIZE,
+    EmbeddingOptions,
+    LitellmEmbeddingAdapter,
+    validate_map_to_embeddings,
+)
+from kiln_ai.adapters.ml_embedding_model_list import KilnEmbeddingModelProvider
+from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
+from kiln_ai.datamodel.datamodel_enums import ModelProviderName
+from kiln_ai.datamodel.embedding import EmbeddingConfig
+from kiln_ai.utils.config import Config
+
+
+@pytest.fixture
+def mock_embedding_config():
+    return EmbeddingConfig(
+        name="test-embedding",
+        model_provider_name=ModelProviderName.openai,
+        model_name="openai_text_embedding_3_small",
+        properties={},
+    )
+
+
+@pytest.fixture
+def mock_litellm_core_config():
+    return LiteLlmCoreConfig()
+
+
+@pytest.fixture
+def mock_litellm_adapter(mock_embedding_config, mock_litellm_core_config):
+    return LitellmEmbeddingAdapter(
+        mock_embedding_config, litellm_core_config=mock_litellm_core_config
+    )
+
+
+class TestEmbeddingOptions:
+    """Test the EmbeddingOptions class."""
+
+    def test_default_values(self):
+        """Test that EmbeddingOptions has correct default values."""
+        options = EmbeddingOptions()
+        assert options.dimensions is None
+
+    def test_with_dimensions(self):
+        """Test EmbeddingOptions with dimensions set."""
+        options = EmbeddingOptions(dimensions=1536)
+        assert options.dimensions == 1536
+
+    def test_model_dump_excludes_none(self):
+        """Test that model_dump excludes None values."""
+        options = EmbeddingOptions()
+        dumped = options.model_dump(exclude_none=True)
+        assert "dimensions" not in dumped
+
+        options_with_dim = EmbeddingOptions(dimensions=1536)
+        dumped_with_dim = options_with_dim.model_dump(exclude_none=True)
+        assert "dimensions" in dumped_with_dim
+        assert dumped_with_dim["dimensions"] == 1536
+
+
+class TestLitellmEmbeddingAdapter:
+    """Test the LitellmEmbeddingAdapter class."""
+
+    def test_init_success(self, mock_embedding_config, mock_litellm_core_config):
+        """Test successful initialization of the adapter."""
+        adapter = LitellmEmbeddingAdapter(
+            mock_embedding_config, mock_litellm_core_config
+        )
+        assert adapter.embedding_config == mock_embedding_config
+
+    def test_build_options_no_dimensions(self, mock_litellm_adapter):
+        """Test build_options when no dimensions are specified."""
+        options = mock_litellm_adapter.build_options()
+        assert options.dimensions is None
+
+    def test_build_options_with_dimensions(
+        self, mock_embedding_config, mock_litellm_core_config
+    ):
+        """Test build_options when dimensions are specified."""
+        mock_embedding_config.properties = {"dimensions": 1536}
+        adapter = LitellmEmbeddingAdapter(
+            mock_embedding_config, litellm_core_config=mock_litellm_core_config
+        )
+        options = adapter.build_options()
+        assert options.dimensions == 1536
+
+    async def test_generate_embeddings_with_completion_kwargs(
+        self, mock_embedding_config, mock_litellm_core_config
+    ):
+        """Test that completion_kwargs are properly passed to litellm.aembedding."""
+        # Set up litellm_core_config with additional options
+        mock_litellm_core_config.additional_body_options = {"custom_param": "value"}
+        mock_litellm_core_config.base_url = "https://custom-api.example.com"
+        mock_litellm_core_config.default_headers = {
+            "Authorization": "Bearer custom-token"
+        }
+
+        adapter = LitellmEmbeddingAdapter(
+            mock_embedding_config, litellm_core_config=mock_litellm_core_config
+        )
+
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
+        ]
+        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ) as mock_aembedding:
+            await adapter._generate_embeddings(["test text"])
+
+        # Verify litellm.aembedding was called with completion_kwargs
+        call_args = mock_aembedding.call_args
+        assert call_args[1]["custom_param"] == "value"
+        assert call_args[1]["api_base"] == "https://custom-api.example.com"
+        assert call_args[1]["default_headers"] == {
+            "Authorization": "Bearer custom-token"
+        }
+
+    async def test_generate_embeddings_with_partial_completion_kwargs(
+        self, mock_embedding_config, mock_litellm_core_config
+    ):
+        """Test that completion_kwargs work when only some options are set."""
+        # Set only additional_body_options
+        mock_litellm_core_config.additional_body_options = {"timeout": 30}
+        mock_litellm_core_config.base_url = None
+        mock_litellm_core_config.default_headers = None
+
+        adapter = LitellmEmbeddingAdapter(
+            mock_embedding_config, litellm_core_config=mock_litellm_core_config
+        )
+
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
+        ]
+        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ) as mock_aembedding:
+            await adapter._generate_embeddings(["test text"])
+
+        # Verify only the set options are passed
+        call_args = mock_aembedding.call_args
+        assert call_args[1]["timeout"] == 30
+        assert "api_base" not in call_args[1]
+        assert "default_headers" not in call_args[1]
+
+    async def test_generate_embeddings_with_empty_completion_kwargs(
+        self, mock_embedding_config, mock_litellm_core_config
+    ):
+        """Test that completion_kwargs work when all options are None/empty."""
+        # Ensure all options are None/empty
+        mock_litellm_core_config.additional_body_options = None
+        mock_litellm_core_config.base_url = None
+        mock_litellm_core_config.default_headers = None
+
+        adapter = LitellmEmbeddingAdapter(
+            mock_embedding_config, litellm_core_config=mock_litellm_core_config
+        )
+
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
+        ]
+        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ) as mock_aembedding:
+            await adapter._generate_embeddings(["test text"])
+
+        # Verify no completion_kwargs are passed
+        call_args = mock_aembedding.call_args
+        assert "api_base" not in call_args[1]
+        assert "default_headers" not in call_args[1]
+        # Should only have the basic parameters
+        assert "model" in call_args[1]
+        assert "input" in call_args[1]
+
+    async def test_generate_embeddings_empty_list(self, mock_litellm_adapter):
+        """Test embed method with empty text list."""
+        result = await mock_litellm_adapter.generate_embeddings([])
+        assert result.embeddings == []
+        assert result.usage is None
+
+    async def test_generate_embeddings_success(self, mock_litellm_adapter):
+        """Test successful embedding generation."""
+        # mock the response type
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
+            {"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]},
+        ]
+        mock_response.usage = Usage(prompt_tokens=10, total_tokens=10)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ):
+            result = await mock_litellm_adapter._generate_embeddings(["text1", "text2"])
+
+        assert len(result.embeddings) == 2
+        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]
+        assert result.embeddings[1].vector == [0.4, 0.5, 0.6]
+        assert result.usage == mock_response.usage
+
+    async def test_generate_embeddings_for_batch_success(self, mock_litellm_adapter):
+        """Test successful embedding generation for a single batch."""
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
+            {"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]},
+        ]
+        mock_response.usage = Usage(prompt_tokens=10, total_tokens=10)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ):
+            result = await mock_litellm_adapter._generate_embeddings_for_batch(
+                ["text1", "text2"]
+            )
+
+        assert len(result.embeddings) == 2
+        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]
+        assert result.embeddings[1].vector == [0.4, 0.5, 0.6]
+        assert result.usage == mock_response.usage
+
+    async def test_generate_embeddings_for_batch_with_completion_kwargs(
+        self, mock_embedding_config, mock_litellm_core_config
+    ):
+        """Test that completion_kwargs are properly passed to litellm.aembedding in batch method."""
+        # Set up litellm_core_config with additional options
+        mock_litellm_core_config.additional_body_options = {"custom_param": "value"}
+        mock_litellm_core_config.base_url = "https://custom-api.example.com"
+        mock_litellm_core_config.default_headers = {
+            "Authorization": "Bearer custom-token"
+        }
+
+        adapter = LitellmEmbeddingAdapter(
+            mock_embedding_config, litellm_core_config=mock_litellm_core_config
+        )
+
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
+        ]
+        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ) as mock_aembedding:
+            await adapter._generate_embeddings_for_batch(["test text"])
+
+        # Verify litellm.aembedding was called with completion_kwargs
+        call_args = mock_aembedding.call_args
+        assert call_args[1]["custom_param"] == "value"
+        assert call_args[1]["api_base"] == "https://custom-api.example.com"
+        assert call_args[1]["default_headers"] == {
+            "Authorization": "Bearer custom-token"
+        }
+
+    async def test_generate_embeddings_with_dimensions(
+        self, mock_embedding_config, mock_litellm_core_config
+    ):
+        """Test embedding with dimensions specified."""
+        mock_embedding_config.properties = {"dimensions": 1536}
+        adapter = LitellmEmbeddingAdapter(
+            mock_embedding_config, litellm_core_config=mock_litellm_core_config
+        )
+
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1] * 1536}
+        ]
+        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ) as mock_aembedding:
+            result = await adapter._generate_embeddings(["test text"])
+
+        # Verify litellm.aembedding was called with correct parameters
+        mock_aembedding.assert_called_once_with(
+            model="openai/text-embedding-3-small",
+            input=["test text"],
+            dimensions=1536,
+        )
+
+        assert len(result.embeddings) == 1
+        assert len(result.embeddings[0].vector) == 1536
+        assert result.usage == mock_response.usage
+
+    async def test_generate_embeddings_batch_size_exceeded(self, mock_litellm_adapter):
+        """Test that embedding fails when batch size is exceeded in individual batch."""
+        # This test now tests the _generate_embeddings_for_batch method directly
+        # since the main _generate_embeddings method now handles batching automatically
+        large_text_list = ["text"] * (MAX_BATCH_SIZE + 1)
+
+        with pytest.raises(
+            ValueError,
+            match=f"Too many input texts, max batch size is {MAX_BATCH_SIZE}, got {MAX_BATCH_SIZE + 1}",
+        ):
+            await mock_litellm_adapter._generate_embeddings_for_batch(large_text_list)
+
+    async def test_generate_embeddings_response_length_mismatch(
+        self, mock_litellm_adapter
+    ):
+        """Test that embedding fails when response data length doesn't match input."""
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
+        ]  # Only one embedding
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ):
+            with pytest.raises(
+                RuntimeError,
+                match=r"Expected the number of embeddings in the response to be 2, got 1.",
+            ):
+                await mock_litellm_adapter._generate_embeddings(["text1", "text2"])
+
+    async def test_generate_embeddings_litellm_exception(self, mock_litellm_adapter):
+        """Test that litellm exceptions are properly raised."""
+        with patch(
+            "litellm.aembedding",
+            new_callable=AsyncMock,
+            side_effect=Exception("litellm error"),
+        ):
+            with pytest.raises(Exception, match="litellm error"):
+                await mock_litellm_adapter._generate_embeddings(["test text"])
+
+    async def test_generate_embeddings_sorts_by_index(self, mock_litellm_adapter):
+        """Test that embeddings are sorted by index."""
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 2, "embedding": [0.3, 0.4, 0.5]},
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
+            {"object": "embedding", "index": 1, "embedding": [0.2, 0.3, 0.4]},
+        ]
+        mock_response.usage = Usage(prompt_tokens=15, total_tokens=15)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ):
+            result = await mock_litellm_adapter._generate_embeddings(
+                ["text1", "text2", "text3"]
+            )
+
+        # Verify embeddings are sorted by index
+        assert len(result.embeddings) == 3
+        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]  # index 0
+        assert result.embeddings[1].vector == [0.2, 0.3, 0.4]  # index 1
+        assert result.embeddings[2].vector == [0.3, 0.4, 0.5]  # index 2
+
+    async def test_generate_embeddings_single_text(self, mock_litellm_adapter):
+        """Test embedding a single text."""
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
+        ]
+        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ) as mock_aembedding:
+            result = await mock_litellm_adapter._generate_embeddings(["single text"])
+
+        # The call should not include dimensions since the fixture has empty properties
+        mock_aembedding.assert_called_once_with(
+            model="openai/text-embedding-3-small",
+            input=["single text"],
+        )
+
+        assert len(result.embeddings) == 1
+        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]
+        assert result.usage == mock_response.usage
+
+    async def test_generate_embeddings_max_batch_size(self, mock_litellm_adapter):
+        """Test embedding with exactly the maximum batch size."""
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": i, "embedding": [0.1, 0.2, 0.3]}
+            for i in range(MAX_BATCH_SIZE)
+        ]
+        mock_response.usage = Usage(
+            prompt_tokens=MAX_BATCH_SIZE * 5, total_tokens=MAX_BATCH_SIZE * 5
+        )
+
+        large_text_list = ["text"] * MAX_BATCH_SIZE
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ):
+            result = await mock_litellm_adapter._generate_embeddings(large_text_list)
+
+        assert len(result.embeddings) == MAX_BATCH_SIZE
+        assert result.usage == mock_response.usage
+
+    async def test_generate_embeddings_multiple_batches(self, mock_litellm_adapter):
+        """Test that embedding properly handles multiple batches."""
+        # Create a list that will require multiple batches
+        total_texts = MAX_BATCH_SIZE * 2 + 50  # 2 full batches + 50 more
+        text_list = [f"text_{i}" for i in range(total_texts)]
+
+        # Mock responses for each batch
+        batch1_response = AsyncMock(spec=EmbeddingResponse)
+        batch1_response.data = [
+            {"object": "embedding", "index": i, "embedding": [0.1, 0.2, 0.3]}
+            for i in range(MAX_BATCH_SIZE)
+        ]
+        batch1_response.usage = Usage(prompt_tokens=100, total_tokens=100)
+
+        batch2_response = AsyncMock(spec=EmbeddingResponse)
+        batch2_response.data = [
+            {"object": "embedding", "index": i, "embedding": [0.4, 0.5, 0.6]}
+            for i in range(MAX_BATCH_SIZE)
+        ]
+        batch2_response.usage = Usage(prompt_tokens=100, total_tokens=100)
+
+        batch3_response = AsyncMock(spec=EmbeddingResponse)
+        batch3_response.data = [
+            {"object": "embedding", "index": i, "embedding": [0.7, 0.8, 0.9]}
+            for i in range(50)
+        ]
+        batch3_response.usage = Usage(prompt_tokens=50, total_tokens=50)
+
+        # Mock litellm.aembedding to return different responses based on input size
+        async def mock_aembedding(*args, **kwargs):
+            input_size = len(kwargs.get("input", []))
+            if input_size == MAX_BATCH_SIZE:
+                if len(mock_aembedding.call_count) == 0:
+                    mock_aembedding.call_count.append(1)
+                    return batch1_response
+                else:
+                    mock_aembedding.call_count.append(1)
+                    return batch2_response
+            else:
+                return batch3_response
+
+        mock_aembedding.call_count = []
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, side_effect=mock_aembedding
+        ):
+            result = await mock_litellm_adapter._generate_embeddings(text_list)
+
+        # Should have all embeddings combined
+        assert len(result.embeddings) == total_texts
+
+        # Should have combined usage from all batches
+        assert result.usage is not None
+        assert result.usage.prompt_tokens == 250  # 100 + 100 + 50
+        assert result.usage.total_tokens == 250  # 100 + 100 + 50
+
+        # Verify embeddings are in the right order
+        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]  # First batch
+        assert result.embeddings[MAX_BATCH_SIZE].vector == [
+            0.4,
+            0.5,
+            0.6,
+        ]  # Second batch
+        assert result.embeddings[MAX_BATCH_SIZE * 2].vector == [
+            0.7,
+            0.8,
+            0.9,
+        ]  # Third batch
+
+    async def test_generate_embeddings_batching_edge_cases(self, mock_litellm_adapter):
+        """Test batching edge cases like empty lists and single items."""
+        # Test empty list
+        result = await mock_litellm_adapter._generate_embeddings([])
+        assert result.embeddings == []
+        assert result.usage is not None
+        assert result.usage.prompt_tokens == 0
+        assert result.usage.total_tokens == 0
+
+        # Test single item (should still go through batching logic)
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
+        ]
+        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
+
+        with patch("litellm.aembedding", return_value=mock_response):
+            result = await mock_litellm_adapter._generate_embeddings(["single text"])
+
+        assert len(result.embeddings) == 1
+        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]
+        assert result.usage == mock_response.usage
+
+    async def test_generate_embeddings_batching_with_mixed_usage(
+        self, mock_litellm_adapter
+    ):
+        """Test batching when some responses have usage and others don't."""
+        # Create a list that will require multiple batches
+        text_list = ["text"] * (MAX_BATCH_SIZE + 10)
+
+        # First batch with usage
+        batch1_response = AsyncMock(spec=EmbeddingResponse)
+        batch1_response.data = [
+            {"object": "embedding", "index": i, "embedding": [0.1, 0.2, 0.3]}
+            for i in range(MAX_BATCH_SIZE)
+        ]
+        batch1_response.usage = Usage(prompt_tokens=100, total_tokens=100)
+
+        # Second batch without usage
+        batch2_response = AsyncMock(spec=EmbeddingResponse)
+        batch2_response.data = [
+            {"object": "embedding", "index": i, "embedding": [0.4, 0.5, 0.6]}
+            for i in range(10)
+        ]
+        batch2_response.usage = None
+
+        # Mock litellm.aembedding to return different responses based on input size
+        async def mock_aembedding(*args, **kwargs):
+            input_size = len(kwargs.get("input", []))
+            if input_size == MAX_BATCH_SIZE:
+                return batch1_response
+            else:
+                return batch2_response
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, side_effect=mock_aembedding
+        ):
+            result = await mock_litellm_adapter._generate_embeddings(text_list)
+
+        # Should have all embeddings combined
+        assert len(result.embeddings) == MAX_BATCH_SIZE + 10
+
+        # Should have None usage since one batch has None usage
+        assert result.usage is None
+
+    async def test_generate_embeddings_batching_with_all_usage(
+        self, mock_litellm_adapter
+    ):
+        """Test batching when all responses have usage information."""
+        # Create a list that will require multiple batches
+        text_list = ["text"] * (MAX_BATCH_SIZE + 10)
+
+        # First batch with usage
+        batch1_response = AsyncMock(spec=EmbeddingResponse)
+        batch1_response.data = [
+            {"object": "embedding", "index": i, "embedding": [0.1, 0.2, 0.3]}
+            for i in range(MAX_BATCH_SIZE)
+        ]
+        batch1_response.usage = Usage(prompt_tokens=100, total_tokens=100)
+
+        # Second batch with usage
+        batch2_response = AsyncMock(spec=EmbeddingResponse)
+        batch2_response.data = [
+            {"object": "embedding", "index": i, "embedding": [0.4, 0.5, 0.6]}
+            for i in range(10)
+        ]
+        batch2_response.usage = Usage(prompt_tokens=50, total_tokens=50)
+
+        # Mock litellm.aembedding to return different responses based on input size
+        async def mock_aembedding(*args, **kwargs):
+            input_size = len(kwargs.get("input", []))
+            if input_size == MAX_BATCH_SIZE:
+                return batch1_response
+            else:
+                return batch2_response
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, side_effect=mock_aembedding
+        ):
+            result = await mock_litellm_adapter._generate_embeddings(text_list)
+
+        # Should have all embeddings combined
+        assert len(result.embeddings) == MAX_BATCH_SIZE + 10
+
+        # Should have combined usage since all batches have usage
+        assert result.usage is not None
+        assert result.usage.prompt_tokens == 150  # 100 + 50
+        assert result.usage.total_tokens == 150  # 100 + 50
+
+    def test_embedding_config_inheritance(
+        self, mock_embedding_config, mock_litellm_core_config
+    ):
+        """Test that the adapter properly inherits from BaseEmbeddingAdapter."""
+        adapter = LitellmEmbeddingAdapter(
+            mock_embedding_config, litellm_core_config=mock_litellm_core_config
+        )
+        assert adapter.embedding_config == mock_embedding_config
+
+    async def test_generate_embeddings_method_integration(self, mock_litellm_adapter):
+        """Test the public embed method integration."""
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
+        ]
+        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ):
+            result = await mock_litellm_adapter.generate_embeddings(["test text"])
+
+        assert len(result.embeddings) == 1
+        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]
+        assert result.usage == mock_response.usage
+
+
+class TestLitellmEmbeddingAdapterEdgeCases:
+    """Test edge cases and error conditions."""
+
+    async def test_generate_embeddings_with_none_usage(
+        self, mock_embedding_config, mock_litellm_core_config
+    ):
+        """Test embedding when litellm returns None usage."""
+        adapter = LitellmEmbeddingAdapter(
+            mock_embedding_config, litellm_core_config=mock_litellm_core_config
+        )
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
+        ]
+        mock_response.usage = None
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ):
+            result = await adapter._generate_embeddings(["test text"])
+
+        assert len(result.embeddings) == 1
+        # With the new logic, if any response has None usage, the result has None usage
+        assert result.usage is None
+
+    async def test_generate_embeddings_with_empty_embedding_vector(
+        self, mock_embedding_config, mock_litellm_core_config
+    ):
+        """Test embedding with empty vector."""
+        adapter = LitellmEmbeddingAdapter(
+            mock_embedding_config, litellm_core_config=mock_litellm_core_config
+        )
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [{"object": "embedding", "index": 0, "embedding": []}]
+        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ):
+            result = await adapter._generate_embeddings(["test text"])
+
+        assert len(result.embeddings) == 1
+        assert result.embeddings[0].vector == []
+
+    async def test_generate_embeddings_with_duplicate_indices(
+        self, mock_embedding_config, mock_litellm_core_config
+    ):
+        """Test embedding with duplicate indices (should still work due to sorting)."""
+        adapter = LitellmEmbeddingAdapter(
+            mock_embedding_config, litellm_core_config=mock_litellm_core_config
+        )
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
+            {
+                "object": "embedding",
+                "index": 0,
+                "embedding": [0.4, 0.5, 0.6],
+            },  # Duplicate index
+        ]
+        mock_response.usage = Usage(prompt_tokens=10, total_tokens=10)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ):
+            result = await adapter._generate_embeddings(["text1", "text2"])
+
+        # Both embeddings should be present and match the order in response.data
+        assert len(result.embeddings) == 2
+        assert result.embeddings[0].vector == [0.1, 0.2, 0.3]
+        assert result.embeddings[1].vector == [0.4, 0.5, 0.6]
+
+    async def test_generate_embeddings_with_complex_properties(
+        self, mock_embedding_config, mock_litellm_core_config
+    ):
+        """Test embedding with complex properties (only dimensions should be used)."""
+        mock_embedding_config.properties = {
+            "dimensions": 1536,
+            "custom_property": "value",
+            "numeric_property": 42,
+            "boolean_property": True,
+        }
+        adapter = LitellmEmbeddingAdapter(
+            mock_embedding_config, litellm_core_config=mock_litellm_core_config
+        )
+
+        mock_response = AsyncMock(spec=EmbeddingResponse)
+        mock_response.data = [
+            {"object": "embedding", "index": 0, "embedding": [0.1] * 1536}
+        ]
+        mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
+
+        with patch(
+            "litellm.aembedding", new_callable=AsyncMock, return_value=mock_response
+        ) as mock_aembedding:
+            await adapter._generate_embeddings(["test text"])
+
+        # Only dimensions should be passed to litellm
+        call_args = mock_aembedding.call_args
+        assert call_args[1]["dimensions"] == 1536
+        # Other properties should not be passed
+        assert "custom_property" not in call_args[1]
+        assert "numeric_property" not in call_args[1]
+        assert "boolean_property" not in call_args[1]
+
+
+@pytest.mark.paid
+@pytest.mark.parametrize(
+    "provider,model_name,expected_dim",
+    [
+        (ModelProviderName.openai, "openai_text_embedding_3_small", 1536),
+        (ModelProviderName.openai, "openai_text_embedding_3_large", 3072),
+        (ModelProviderName.gemini_api, "gemini_text_embedding_004", 768),
+    ],
+)
+@pytest.mark.asyncio
+async def test_paid_generate_embeddings_basic(
+    provider, model_name, expected_dim, mock_litellm_core_config
+):
+    openai_key = Config.shared().open_ai_api_key
+    if not openai_key:
+        pytest.skip("OPENAI_API_KEY not set")
+    # Set the API key for litellm
+    os.environ["OPENAI_API_KEY"] = openai_key
+
+    # gemini key
+    gemini_key = Config.shared().gemini_api_key
+    if not gemini_key:
+        pytest.skip("GEMINI_API_KEY not set")
+    os.environ["GEMINI_API_KEY"] = gemini_key
+
+    config = EmbeddingConfig(
+        name="paid-embedding",
+        model_provider_name=provider,
+        model_name=model_name,
+        properties={},
+    )
+    adapter = LitellmEmbeddingAdapter(
+        config, litellm_core_config=mock_litellm_core_config
+    )
+    text = ["Kiln is an open-source evaluation platform for LLMs."]
+    result = await adapter.generate_embeddings(text)
+    assert len(result.embeddings) == 1
+    assert isinstance(result.embeddings[0].vector, list)
+    assert len(result.embeddings[0].vector) == expected_dim
+    assert all(isinstance(x, float) for x in result.embeddings[0].vector)
+
+
+# test model_provider
+def test_model_provider(mock_litellm_core_config):
+    mock_embedding_config = EmbeddingConfig(
+        name="test",
+        model_provider_name=ModelProviderName.openai,
+        model_name="openai_text_embedding_3_small",
+        properties={},
+    )
+    adapter = LitellmEmbeddingAdapter(
+        mock_embedding_config, litellm_core_config=mock_litellm_core_config
+    )
+    assert adapter.model_provider.name == ModelProviderName.openai
+    assert adapter.model_provider.model_id == "text-embedding-3-small"
+
+
+def test_model_provider_gemini(mock_litellm_core_config):
+    config = EmbeddingConfig(
+        name="test",
+        model_provider_name=ModelProviderName.gemini_api,
+        model_name="gemini_text_embedding_004",
+        properties={},
+    )
+    adapter = LitellmEmbeddingAdapter(
+        config, litellm_core_config=mock_litellm_core_config
+    )
+    assert adapter.model_provider.name == ModelProviderName.gemini_api
+    assert adapter.model_provider.model_id == "text-embedding-004"
+
+
+@pytest.mark.parametrize(
+    "provider,model_name,expected_model_id",
+    [
+        (
+            ModelProviderName.gemini_api,
+            "gemini_text_embedding_004",
+            "gemini/text-embedding-004",
+        ),
+        (
+            ModelProviderName.openai,
+            "openai_text_embedding_3_small",
+            "openai/text-embedding-3-small",
+        ),
+    ],
+)
+def test_litellm_model_id(
+    provider, model_name, expected_model_id, mock_litellm_core_config
+):
+    config = EmbeddingConfig(
+        name="test",
+        model_provider_name=provider,
+        model_name=model_name,
+        properties={},
+    )
+    adapter = LitellmEmbeddingAdapter(
+        config, litellm_core_config=mock_litellm_core_config
+    )
+    assert adapter.litellm_model_id == expected_model_id
+
+
+def test_litellm_model_id_custom_provider_without_base_url(mock_litellm_core_config):
+    """Test that custom providers without base_url raise an error."""
+    config = EmbeddingConfig(
+        name="test",
+        model_provider_name=ModelProviderName.openai_compatible,
+        model_name="some-model",
+        properties={},
+    )
+    adapter = LitellmEmbeddingAdapter(
+        config, litellm_core_config=mock_litellm_core_config
+    )
+
+    with pytest.raises(
+        ValueError,
+        match="Embedding model some-model not found in the list of built-in models",
+    ):
+        adapter.model_provider
+
+
+def test_litellm_model_id_custom_provider_with_base_url(mock_litellm_core_config):
+    """Test that custom providers with base_url work correctly."""
+    # Set up a custom provider with base_url
+    mock_litellm_core_config.base_url = "https://custom-api.example.com"
+
+    config = EmbeddingConfig(
+        name="test",
+        model_provider_name=ModelProviderName.openai_compatible,
+        model_name="some-model",
+        properties={},
+    )
+    adapter = LitellmEmbeddingAdapter(
+        config, litellm_core_config=mock_litellm_core_config
+    )
+
+    with pytest.raises(
+        ValueError,
+        match="Embedding model some-model not found in the list of built-in models",
+    ):
+        adapter.model_provider
+
+
+def test_litellm_model_id_custom_provider_ollama_with_base_url():
+    """Test that ollama provider with base_url works correctly."""
+
+    # Create a mock provider that would be found in the built-in models
+    # We need to mock the built_in_embedding_models_from_provider function
+    with patch(
+        "kiln_ai.adapters.embedding.litellm_embedding_adapter.built_in_embedding_models_from_provider"
+    ) as mock_built_in:
+        mock_built_in.return_value = KilnEmbeddingModelProvider(
+            name=ModelProviderName.ollama,
+            model_id="test-model",
+            n_dimensions=768,
+        )
+
+        config = EmbeddingConfig(
+            name="test",
+            model_provider_name=ModelProviderName.ollama,
+            model_name="test-model",
+            properties={},
+        )
+
+        # With base_url - should work
+        litellm_core_config_with_url = LiteLlmCoreConfig(
+            base_url="http://localhost:11434"
+        )
+        adapter = LitellmEmbeddingAdapter(
+            config, litellm_core_config=litellm_core_config_with_url
+        )
+
+        # Should not raise an error
+        model_id = adapter.litellm_model_id
+        assert model_id == "openai/test-model"
+
+
+def test_litellm_model_id_custom_provider_ollama_without_base_url():
+    """Test that ollama provider without base_url raises an error."""
+    from kiln_ai.adapters.ml_embedding_model_list import KilnEmbeddingModelProvider
+
+    # Create a mock provider that would be found in the built-in models
+    with patch(
+        "kiln_ai.adapters.embedding.litellm_embedding_adapter.built_in_embedding_models_from_provider"
+    ) as mock_built_in:
+        mock_built_in.return_value = KilnEmbeddingModelProvider(
+            name=ModelProviderName.ollama,
+            model_id="test-model",
+            n_dimensions=768,
+        )
+
+        config = EmbeddingConfig(
+            name="test",
+            model_provider_name=ModelProviderName.ollama,
+            model_name="test-model",
+            properties={},
+        )
+
+        # Without base_url - should raise an error
+        litellm_core_config_without_url = LiteLlmCoreConfig(base_url=None)
+        adapter = LitellmEmbeddingAdapter(
+            config, litellm_core_config=litellm_core_config_without_url
+        )
+
+        with pytest.raises(
+            ValueError,
+            match="Provider ollama must have an explicit base URL",
+        ):
+            adapter.litellm_model_id
+
+
+def test_litellm_model_id_custom_provider_openai_compatible_with_base_url():
+    """Test that openai_compatible provider with base_url works correctly."""
+    from kiln_ai.adapters.ml_embedding_model_list import KilnEmbeddingModelProvider
+
+    # Create a mock provider that would be found in the built-in models
+    with patch(
+        "kiln_ai.adapters.embedding.litellm_embedding_adapter.built_in_embedding_models_from_provider"
+    ) as mock_built_in:
+        mock_built_in.return_value = KilnEmbeddingModelProvider(
+            name=ModelProviderName.openai_compatible,
+            model_id="test-model",
+            n_dimensions=768,
+        )
+
+        config = EmbeddingConfig(
+            name="test",
+            model_provider_name=ModelProviderName.openai_compatible,
+            model_name="test-model",
+            properties={},
+        )
+
+        # With base_url - should work
+        litellm_core_config_with_url = LiteLlmCoreConfig(
+            base_url="https://custom-api.example.com"
+        )
+        adapter = LitellmEmbeddingAdapter(
+            config, litellm_core_config=litellm_core_config_with_url
+        )
+
+        # Should not raise an error
+        model_id = adapter.litellm_model_id
+        assert model_id == "openai/test-model"
+
+
+def test_litellm_model_id_custom_provider_openai_compatible_without_base_url():
+    """Test that openai_compatible provider without base_url raises an error."""
+
+    # Create a mock provider that would be found in the built-in models
+    with patch(
+        "kiln_ai.adapters.embedding.litellm_embedding_adapter.built_in_embedding_models_from_provider"
+    ) as mock_built_in:
+        mock_built_in.return_value = KilnEmbeddingModelProvider(
+            name=ModelProviderName.openai_compatible,
+            model_id="test-model",
+            n_dimensions=768,
+        )
+
+        config = EmbeddingConfig(
+            name="test",
+            model_provider_name=ModelProviderName.openai_compatible,
+            model_name="test-model",
+            properties={},
+        )
+
+        # Without base_url - should raise an error
+        litellm_core_config_without_url = LiteLlmCoreConfig(base_url=None)
+        adapter = LitellmEmbeddingAdapter(
+            config, litellm_core_config=litellm_core_config_without_url
+        )
+
+        with pytest.raises(
+            ValueError,
+            match="Provider openai_compatible must have an explicit base URL",
+        ):
+            adapter.litellm_model_id
+
+
+@pytest.mark.paid
+@pytest.mark.parametrize(
+    "provider,model_name,expected_dim",
+    [
+        (ModelProviderName.openai, "openai_text_embedding_3_small", 256),
+        (ModelProviderName.openai, "openai_text_embedding_3_small", 512),
+        (ModelProviderName.openai, "openai_text_embedding_3_large", 256),
+        (ModelProviderName.openai, "openai_text_embedding_3_large", 512),
+        (ModelProviderName.openai, "openai_text_embedding_3_large", 1024),
+        (ModelProviderName.openai, "openai_text_embedding_3_large", 2048),
+    ],
+)
+@pytest.mark.asyncio
+async def test_paid_generate_embeddings_with_custom_dimensions_supported(
+    provider, model_name, expected_dim, mock_litellm_core_config
+):
+    """
+    Some models support custom dimensions - where the provider shortens the dimensions to match
+    the desired custom number of dimensions. Ref: https://openai.com/index/new-embedding-models-and-api-updates/
+    """
+    api_key = Config.shared().open_ai_api_key or os.environ.get("OPENAI_API_KEY")
+    if api_key:
+        os.environ["OPENAI_API_KEY"] = api_key
+
+    config = EmbeddingConfig(
+        name="paid-embedding",
+        model_provider_name=provider,
+        model_name=model_name,
+        properties={"dimensions": expected_dim},
+    )
+    adapter = LitellmEmbeddingAdapter(
+        config, litellm_core_config=mock_litellm_core_config
+    )
+    text = ["Kiln is an open-source evaluation platform for LLMs."]
+    result = await adapter.generate_embeddings(text)
+    assert len(result.embeddings) == 1
+    assert isinstance(result.embeddings[0].vector, list)
+    assert len(result.embeddings[0].vector) == expected_dim
+    assert all(isinstance(x, float) for x in result.embeddings[0].vector)
+
+
+def test_validate_map_to_embeddings():
+    mock_response = AsyncMock(spec=EmbeddingResponse)
+    mock_response.data = [
+        {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
+        {"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]},
+    ]
+    expected_embeddings = [
+        Embedding(vector=[0.1, 0.2, 0.3]),
+        Embedding(vector=[0.4, 0.5, 0.6]),
+    ]
+    result = validate_map_to_embeddings(mock_response, 2)
+    assert result == expected_embeddings
+
+
+def test_validate_map_to_embeddings_invalid_length():
+    mock_response = AsyncMock(spec=EmbeddingResponse)
+    mock_response.data = [
+        {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
+    ]
+    with pytest.raises(
+        RuntimeError,
+        match=r"Expected the number of embeddings in the response to be 2, got 1.",
+    ):
+        validate_map_to_embeddings(mock_response, 2)
+
+
+def test_validate_map_to_embeddings_invalid_object_type():
+    mock_response = AsyncMock(spec=EmbeddingResponse)
+    mock_response.data = [
+        {"object": "not_embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
+    ]
+    with pytest.raises(
+        RuntimeError,
+        match=r"Embedding response data has an unexpected shape. Property 'object' is not 'embedding'. Got not_embedding.",
+    ):
+        validate_map_to_embeddings(mock_response, 1)
+
+
+def test_validate_map_to_embeddings_invalid_embedding_type():
+    mock_response = AsyncMock(spec=EmbeddingResponse)
+    mock_response.data = [
+        {"object": "embedding", "index": 0, "embedding": "not_a_list"},
+    ]
+    mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
+    with pytest.raises(
+        RuntimeError,
+        match=r"Embedding response data has an unexpected shape. Property 'embedding' is not a list. Got <class 'str'>.",
+    ):
+        validate_map_to_embeddings(mock_response, 1)
+
+    # missing embedding
+    mock_response = AsyncMock(spec=EmbeddingResponse)
+    mock_response.data = [
+        {"object": "embedding", "index": 0},
+    ]
+    with pytest.raises(
+        RuntimeError,
+        match=r"Embedding response data has an unexpected shape. Property 'embedding' is None in response data item.",
+    ):
+        validate_map_to_embeddings(mock_response, 1)
+
+
+def test_validate_map_to_embeddings_invalid_index_type():
+    mock_response = AsyncMock(spec=EmbeddingResponse)
+    mock_response.data = [
+        {"object": "embedding", "index": "not_an_int", "embedding": [0.1, 0.2, 0.3]},
+    ]
+    mock_response.usage = Usage(prompt_tokens=5, total_tokens=5)
+    with pytest.raises(
+        RuntimeError,
+        match=r"Embedding response data has an unexpected shape. Property 'index' is not an integer. Got <class 'str'>.",
+    ):
+        validate_map_to_embeddings(mock_response, 1)
+
+    # missing index
+    mock_response = AsyncMock(spec=EmbeddingResponse)
+    mock_response.data = [
+        {"object": "embedding", "embedding": [0.1, 0.2, 0.3]},
+    ]
+    with pytest.raises(
+        RuntimeError,
+        match=r"Embedding response data has an unexpected shape. Property 'index' is None in response data item.",
+    ):
+        validate_map_to_embeddings(mock_response, 1)
+
+
+def test_validate_map_to_embeddings_sorting():
+    mock_response = AsyncMock(spec=EmbeddingResponse)
+    mock_response.data = [
+        {"object": "embedding", "index": 2, "embedding": [0.3, 0.4, 0.5]},
+        {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
+        {"object": "embedding", "index": 1, "embedding": [0.2, 0.3, 0.4]},
+    ]
+    expected_embeddings = [
+        Embedding(vector=[0.1, 0.2, 0.3]),
+        Embedding(vector=[0.2, 0.3, 0.4]),
+        Embedding(vector=[0.3, 0.4, 0.5]),
+    ]
+    result = validate_map_to_embeddings(mock_response, 3)
+    assert result == expected_embeddings
+
+
+def test_generate_embeddings_response_not_embedding_response():
+    response = AsyncMock()
+    response.data = [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}]
+    response.usage = Usage(prompt_tokens=5, total_tokens=5)
+    with pytest.raises(
+        RuntimeError,
+        match=r"Expected EmbeddingResponse, got <class 'unittest.mock.AsyncMock'>.",
+    ):
+        validate_map_to_embeddings(response, 1)