kiln-ai 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/extractors/litellm_extractor.py +52 -32
- kiln_ai/adapters/extractors/test_litellm_extractor.py +169 -71
- kiln_ai/adapters/ml_embedding_model_list.py +330 -28
- kiln_ai/adapters/ml_model_list.py +503 -23
- kiln_ai/adapters/model_adapters/litellm_adapter.py +39 -8
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +78 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +6 -9
- kiln_ai/adapters/test_ml_embedding_model_list.py +89 -279
- kiln_ai/adapters/test_ml_model_list.py +0 -10
- kiln_ai/adapters/vector_store/lancedb_adapter.py +24 -70
- kiln_ai/adapters/vector_store/lancedb_helpers.py +101 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +9 -16
- kiln_ai/adapters/vector_store/test_lancedb_helpers.py +142 -0
- kiln_ai/adapters/vector_store_loaders/__init__.py +0 -0
- kiln_ai/adapters/vector_store_loaders/test_lancedb_loader.py +282 -0
- kiln_ai/adapters/vector_store_loaders/test_vector_store_loader.py +544 -0
- kiln_ai/adapters/vector_store_loaders/vector_store_loader.py +91 -0
- kiln_ai/datamodel/basemodel.py +31 -3
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +14 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +3 -3
- kiln_ai/datamodel/test_basemodel.py +269 -13
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_model.py +31 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +106 -1
- kiln_ai/datamodel/tool_id.py +49 -0
- kiln_ai/tools/base_tool.py +30 -6
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +162 -0
- kiln_ai/tools/mcp_server_tool.py +7 -5
- kiln_ai/tools/mcp_session_manager.py +50 -24
- kiln_ai/tools/rag_tools.py +17 -6
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +186 -226
- kiln_ai/tools/test_rag_tools.py +86 -5
- kiln_ai/tools/test_tool_registry.py +199 -5
- kiln_ai/tools/tool_registry.py +49 -17
- kiln_ai/utils/filesystem.py +4 -4
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +21 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +14 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/METADATA +79 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/RECORD +53 -45
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -1,5 +1,8 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
1
3
|
import pytest
|
|
2
4
|
|
|
5
|
+
from kiln_ai.adapters.embedding.embedding_registry import embedding_adapter_from_type
|
|
3
6
|
from kiln_ai.adapters.ml_embedding_model_list import (
|
|
4
7
|
EmbeddingModelName,
|
|
5
8
|
KilnEmbeddingModel,
|
|
@@ -10,24 +13,28 @@ from kiln_ai.adapters.ml_embedding_model_list import (
|
|
|
10
13
|
get_model_by_name,
|
|
11
14
|
)
|
|
12
15
|
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
|
|
16
|
+
from kiln_ai.datamodel.embedding import EmbeddingConfig
|
|
13
17
|
|
|
14
18
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
EmbeddingModelName.openai_text_embedding_3_small
|
|
22
|
-
|
|
23
|
-
)
|
|
24
|
-
assert (
|
|
25
|
-
EmbeddingModelName.openai_text_embedding_3_large
|
|
26
|
-
== "openai_text_embedding_3_large"
|
|
27
|
-
)
|
|
28
|
-
assert (
|
|
29
|
-
EmbeddingModelName.gemini_text_embedding_004 == "gemini_text_embedding_004"
|
|
19
|
+
@pytest.fixture
|
|
20
|
+
def litellm_adapter():
|
|
21
|
+
adapter = embedding_adapter_from_type(
|
|
22
|
+
EmbeddingConfig(
|
|
23
|
+
name="test-embedding",
|
|
24
|
+
model_provider_name=ModelProviderName.openai,
|
|
25
|
+
model_name=EmbeddingModelName.openai_text_embedding_3_small,
|
|
26
|
+
properties={},
|
|
30
27
|
)
|
|
28
|
+
)
|
|
29
|
+
return adapter
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_all_embedding_models_and_providers() -> List[tuple[str, str]]:
|
|
33
|
+
return [
|
|
34
|
+
(model.name, provider.name)
|
|
35
|
+
for model in built_in_embedding_models
|
|
36
|
+
for provider in model.providers
|
|
37
|
+
]
|
|
31
38
|
|
|
32
39
|
|
|
33
40
|
class TestKilnEmbeddingModelProvider:
|
|
@@ -120,222 +127,40 @@ class TestKilnEmbeddingModel:
|
|
|
120
127
|
assert model.providers[1].name == ModelProviderName.anthropic
|
|
121
128
|
|
|
122
129
|
|
|
123
|
-
class TestEmbeddingModelsList:
|
|
124
|
-
"""Test cases for the embedding_models list"""
|
|
125
|
-
|
|
126
|
-
def test_embedding_models_not_empty(self):
|
|
127
|
-
"""Test that the embedding_models list is not empty"""
|
|
128
|
-
assert len(built_in_embedding_models) > 0
|
|
129
|
-
|
|
130
|
-
def test_all_models_have_required_fields(self):
|
|
131
|
-
"""Test that all models in the list have required fields"""
|
|
132
|
-
for model in built_in_embedding_models:
|
|
133
|
-
assert hasattr(model, "family")
|
|
134
|
-
assert hasattr(model, "name")
|
|
135
|
-
assert hasattr(model, "friendly_name")
|
|
136
|
-
assert hasattr(model, "providers")
|
|
137
|
-
assert isinstance(model.name, str)
|
|
138
|
-
assert isinstance(model.friendly_name, str)
|
|
139
|
-
assert isinstance(model.providers, list)
|
|
140
|
-
assert len(model.providers) > 0
|
|
141
|
-
|
|
142
|
-
def test_all_providers_have_required_fields(self):
|
|
143
|
-
"""Test that all providers in all models have required fields"""
|
|
144
|
-
for model in built_in_embedding_models:
|
|
145
|
-
for provider in model.providers:
|
|
146
|
-
assert hasattr(provider, "name")
|
|
147
|
-
assert isinstance(provider.name, ModelProviderName)
|
|
148
|
-
|
|
149
|
-
def test_model_names_are_unique(self):
|
|
150
|
-
"""Test that all model names in the list are unique"""
|
|
151
|
-
model_names = [model.name for model in built_in_embedding_models]
|
|
152
|
-
assert len(model_names) == len(set(model_names))
|
|
153
|
-
|
|
154
|
-
def test_specific_models_exist(self):
|
|
155
|
-
"""Test that specific expected models exist in the list"""
|
|
156
|
-
model_names = [model.name for model in built_in_embedding_models]
|
|
157
|
-
|
|
158
|
-
assert EmbeddingModelName.openai_text_embedding_3_small in model_names
|
|
159
|
-
assert EmbeddingModelName.openai_text_embedding_3_large in model_names
|
|
160
|
-
assert EmbeddingModelName.gemini_text_embedding_004 in model_names
|
|
161
|
-
|
|
162
|
-
def test_openai_embedding_models(self):
|
|
163
|
-
"""Test specific OpenAI embedding models"""
|
|
164
|
-
openai_models = [
|
|
165
|
-
model
|
|
166
|
-
for model in built_in_embedding_models
|
|
167
|
-
if model.family == KilnEmbeddingModelFamily.openai
|
|
168
|
-
]
|
|
169
|
-
|
|
170
|
-
assert len(openai_models) >= 2 # Should have at least 2 OpenAI models
|
|
171
|
-
|
|
172
|
-
# Check for specific OpenAI models
|
|
173
|
-
openai_model_names = [model.name for model in openai_models]
|
|
174
|
-
assert EmbeddingModelName.openai_text_embedding_3_small in openai_model_names
|
|
175
|
-
assert EmbeddingModelName.openai_text_embedding_3_large in openai_model_names
|
|
176
|
-
|
|
177
|
-
def test_gemini_embedding_models(self):
|
|
178
|
-
"""Test specific Gemini embedding models"""
|
|
179
|
-
gemini_models = [
|
|
180
|
-
model
|
|
181
|
-
for model in built_in_embedding_models
|
|
182
|
-
if model.family == KilnEmbeddingModelFamily.gemini
|
|
183
|
-
]
|
|
184
|
-
|
|
185
|
-
assert len(gemini_models) >= 1 # Should have at least 1 Gemini model
|
|
186
|
-
|
|
187
|
-
# Check for specific Gemini model
|
|
188
|
-
gemini_model_names = [model.name for model in gemini_models]
|
|
189
|
-
assert EmbeddingModelName.gemini_text_embedding_004 in gemini_model_names
|
|
190
|
-
|
|
191
|
-
def test_openai_text_embedding_3_small_details(self):
|
|
192
|
-
"""Test specific details of OpenAI text-embedding-3-small model"""
|
|
193
|
-
model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_small)
|
|
194
|
-
|
|
195
|
-
assert model.family == KilnEmbeddingModelFamily.openai
|
|
196
|
-
assert model.friendly_name == "Text Embedding 3 Small"
|
|
197
|
-
assert len(model.providers) == 1
|
|
198
|
-
|
|
199
|
-
provider = model.providers[0]
|
|
200
|
-
assert provider.name == ModelProviderName.openai
|
|
201
|
-
assert provider.model_id == "text-embedding-3-small"
|
|
202
|
-
assert provider.n_dimensions == 1536
|
|
203
|
-
assert provider.max_input_tokens == 8192
|
|
204
|
-
assert provider.supports_custom_dimensions is True
|
|
205
|
-
|
|
206
|
-
def test_openai_text_embedding_3_large_details(self):
|
|
207
|
-
"""Test specific details of OpenAI text-embedding-3-large model"""
|
|
208
|
-
model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_large)
|
|
209
|
-
|
|
210
|
-
assert model.family == KilnEmbeddingModelFamily.openai
|
|
211
|
-
assert model.friendly_name == "Text Embedding 3 Large"
|
|
212
|
-
assert len(model.providers) == 1
|
|
213
|
-
|
|
214
|
-
provider = model.providers[0]
|
|
215
|
-
assert provider.name == ModelProviderName.openai
|
|
216
|
-
assert provider.model_id == "text-embedding-3-large"
|
|
217
|
-
assert provider.n_dimensions == 3072
|
|
218
|
-
assert provider.max_input_tokens == 8192
|
|
219
|
-
assert provider.supports_custom_dimensions is True
|
|
220
|
-
|
|
221
|
-
def test_gemini_text_embedding_004_details(self):
|
|
222
|
-
"""Test specific details of Gemini text-embedding-004 model"""
|
|
223
|
-
model = get_model_by_name(EmbeddingModelName.gemini_text_embedding_004)
|
|
224
|
-
|
|
225
|
-
assert model.family == KilnEmbeddingModelFamily.gemini
|
|
226
|
-
assert model.friendly_name == "Text Embedding 004"
|
|
227
|
-
assert len(model.providers) == 1
|
|
228
|
-
|
|
229
|
-
provider = model.providers[0]
|
|
230
|
-
assert provider.name == ModelProviderName.gemini_api
|
|
231
|
-
assert provider.model_id == "text-embedding-004"
|
|
232
|
-
assert provider.n_dimensions == 768
|
|
233
|
-
assert provider.max_input_tokens == 2048
|
|
234
|
-
assert provider.supports_custom_dimensions is False
|
|
235
|
-
|
|
236
|
-
|
|
237
130
|
class TestGetModelByName:
|
|
238
|
-
"""Test cases for get_model_by_name function"""
|
|
239
|
-
|
|
240
|
-
def test_get_existing_model(self):
|
|
241
|
-
"""Test getting an existing model by name"""
|
|
242
|
-
model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_small)
|
|
243
|
-
assert model.name == EmbeddingModelName.openai_text_embedding_3_small
|
|
244
|
-
assert model.family == KilnEmbeddingModelFamily.openai
|
|
245
|
-
|
|
246
|
-
def test_get_all_existing_models(self):
|
|
247
|
-
"""Test getting all existing models by name"""
|
|
248
|
-
for model_name in EmbeddingModelName:
|
|
249
|
-
model = get_model_by_name(model_name)
|
|
250
|
-
assert model.name == model_name
|
|
251
|
-
|
|
252
131
|
def test_get_nonexistent_model_raises_error(self):
|
|
253
132
|
"""Test that getting a nonexistent model raises ValueError"""
|
|
254
133
|
with pytest.raises(
|
|
255
134
|
ValueError, match="Embedding model nonexistent_model not found"
|
|
256
135
|
):
|
|
257
|
-
get_model_by_name("nonexistent_model")
|
|
258
|
-
|
|
259
|
-
def test_get_model_with_invalid_enum_value(self):
|
|
260
|
-
"""Test that getting a model with invalid enum value raises ValueError"""
|
|
261
|
-
with pytest.raises(ValueError, match="Embedding model invalid_enum not found"):
|
|
262
|
-
get_model_by_name("invalid_enum")
|
|
136
|
+
get_model_by_name("nonexistent_model") # type: ignore
|
|
263
137
|
|
|
264
138
|
@pytest.mark.parametrize(
|
|
265
|
-
"model_name
|
|
266
|
-
[
|
|
267
|
-
(
|
|
268
|
-
EmbeddingModelName.openai_text_embedding_3_small,
|
|
269
|
-
KilnEmbeddingModelFamily.openai,
|
|
270
|
-
"Text Embedding 3 Small",
|
|
271
|
-
),
|
|
272
|
-
(
|
|
273
|
-
EmbeddingModelName.openai_text_embedding_3_large,
|
|
274
|
-
KilnEmbeddingModelFamily.openai,
|
|
275
|
-
"Text Embedding 3 Large",
|
|
276
|
-
),
|
|
277
|
-
(
|
|
278
|
-
EmbeddingModelName.gemini_text_embedding_004,
|
|
279
|
-
KilnEmbeddingModelFamily.gemini,
|
|
280
|
-
"Text Embedding 004",
|
|
281
|
-
),
|
|
282
|
-
],
|
|
139
|
+
"model_name",
|
|
140
|
+
[model.name for model in built_in_embedding_models],
|
|
283
141
|
)
|
|
284
|
-
def
|
|
285
|
-
self, model_name, expected_family, expected_friendly_name
|
|
286
|
-
):
|
|
142
|
+
def test_model_retrieval(self, model_name):
|
|
287
143
|
"""Test retrieving models with parametrized test cases"""
|
|
288
144
|
model = get_model_by_name(model_name)
|
|
289
|
-
assert model.family ==
|
|
290
|
-
assert model.friendly_name ==
|
|
145
|
+
assert model.family == model.family
|
|
146
|
+
assert model.friendly_name == model.friendly_name
|
|
291
147
|
|
|
292
148
|
|
|
293
149
|
class TestBuiltInEmbeddingModelsFromProvider:
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
provider = built_in_embedding_models_from_provider(
|
|
299
|
-
provider_name=ModelProviderName.openai,
|
|
300
|
-
model_name=EmbeddingModelName.openai_text_embedding_3_small,
|
|
301
|
-
)
|
|
150
|
+
@pytest.mark.parametrize(
|
|
151
|
+
"model_name,provider_name", get_all_embedding_models_and_providers()
|
|
152
|
+
)
|
|
153
|
+
def test_get_all_existing_models_and_providers(self, model_name, provider_name):
|
|
154
|
+
provider = built_in_embedding_models_from_provider(provider_name, model_name)
|
|
302
155
|
|
|
303
156
|
assert provider is not None
|
|
304
|
-
assert provider.name ==
|
|
305
|
-
assert provider.model_id ==
|
|
306
|
-
assert provider.n_dimensions ==
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
combinations = [
|
|
311
|
-
(
|
|
312
|
-
ModelProviderName.openai,
|
|
313
|
-
EmbeddingModelName.openai_text_embedding_3_small,
|
|
314
|
-
),
|
|
315
|
-
(
|
|
316
|
-
ModelProviderName.openai,
|
|
317
|
-
EmbeddingModelName.openai_text_embedding_3_large,
|
|
318
|
-
),
|
|
319
|
-
(
|
|
320
|
-
ModelProviderName.gemini_api,
|
|
321
|
-
EmbeddingModelName.gemini_text_embedding_004,
|
|
322
|
-
),
|
|
323
|
-
]
|
|
324
|
-
|
|
325
|
-
for provider_name, model_name in combinations:
|
|
326
|
-
provider = built_in_embedding_models_from_provider(
|
|
327
|
-
provider_name, model_name
|
|
328
|
-
)
|
|
329
|
-
assert provider is not None
|
|
330
|
-
assert provider.name == provider_name
|
|
331
|
-
|
|
332
|
-
def test_get_nonexistent_provider_returns_none(self):
|
|
333
|
-
"""Test that getting a nonexistent provider returns None"""
|
|
334
|
-
provider = built_in_embedding_models_from_provider(
|
|
335
|
-
provider_name=ModelProviderName.anthropic, # Not used for embeddings
|
|
336
|
-
model_name=EmbeddingModelName.openai_text_embedding_3_small,
|
|
157
|
+
assert provider.name == provider_name
|
|
158
|
+
assert provider.model_id == provider.model_id
|
|
159
|
+
assert provider.n_dimensions == provider.n_dimensions
|
|
160
|
+
assert provider.max_input_tokens == provider.max_input_tokens
|
|
161
|
+
assert (
|
|
162
|
+
provider.supports_custom_dimensions == provider.supports_custom_dimensions
|
|
337
163
|
)
|
|
338
|
-
assert provider is None
|
|
339
164
|
|
|
340
165
|
def test_get_nonexistent_model_returns_none(self):
|
|
341
166
|
"""Test that getting a nonexistent model returns None"""
|
|
@@ -353,77 +178,62 @@ class TestBuiltInEmbeddingModelsFromProvider:
|
|
|
353
178
|
)
|
|
354
179
|
assert provider is None
|
|
355
180
|
|
|
356
|
-
def test_get_openai_text_embedding_3_small_provider_details(self):
|
|
357
|
-
"""Test specific details of OpenAI text-embedding-3-small provider"""
|
|
358
|
-
provider = built_in_embedding_models_from_provider(
|
|
359
|
-
provider_name=ModelProviderName.openai,
|
|
360
|
-
model_name=EmbeddingModelName.openai_text_embedding_3_small,
|
|
361
|
-
)
|
|
362
181
|
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
assert provider.model_id == "text-embedding-3-small"
|
|
366
|
-
assert provider.n_dimensions == 1536
|
|
367
|
-
assert provider.max_input_tokens == 8192
|
|
368
|
-
assert provider.supports_custom_dimensions is True
|
|
182
|
+
class TestGenerateEmbedding:
|
|
183
|
+
"""Test cases for generate_embedding function"""
|
|
369
184
|
|
|
370
|
-
|
|
371
|
-
""
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
185
|
+
@pytest.mark.parametrize(
|
|
186
|
+
"model_name,provider_name", get_all_embedding_models_and_providers()
|
|
187
|
+
)
|
|
188
|
+
@pytest.mark.paid
|
|
189
|
+
async def test_generate_embedding(self, model_name, provider_name):
|
|
190
|
+
"""Test generating an embedding"""
|
|
191
|
+
model_provider = built_in_embedding_models_from_provider(
|
|
192
|
+
provider_name, model_name
|
|
375
193
|
)
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
"""Test specific details of Gemini text-embedding-004 provider"""
|
|
386
|
-
provider = built_in_embedding_models_from_provider(
|
|
387
|
-
provider_name=ModelProviderName.gemini_api,
|
|
388
|
-
model_name=EmbeddingModelName.gemini_text_embedding_004,
|
|
194
|
+
assert model_provider is not None
|
|
195
|
+
|
|
196
|
+
embedding = embedding_adapter_from_type(
|
|
197
|
+
EmbeddingConfig(
|
|
198
|
+
name="test-embedding",
|
|
199
|
+
model_provider_name=provider_name,
|
|
200
|
+
model_name=model_name,
|
|
201
|
+
properties={},
|
|
202
|
+
)
|
|
389
203
|
)
|
|
390
|
-
|
|
391
|
-
assert
|
|
392
|
-
assert
|
|
393
|
-
assert provider.model_id == "text-embedding-004"
|
|
394
|
-
assert provider.n_dimensions == 768
|
|
395
|
-
assert provider.max_input_tokens == 2048
|
|
396
|
-
assert provider.supports_custom_dimensions is False
|
|
204
|
+
embedding = await embedding.generate_embeddings(["Hello, world!"])
|
|
205
|
+
assert len(embedding.embeddings) == 1
|
|
206
|
+
assert len(embedding.embeddings[0].vector) == model_provider.n_dimensions
|
|
397
207
|
|
|
398
208
|
@pytest.mark.parametrize(
|
|
399
|
-
"
|
|
400
|
-
[
|
|
401
|
-
(
|
|
402
|
-
ModelProviderName.openai,
|
|
403
|
-
EmbeddingModelName.openai_text_embedding_3_small,
|
|
404
|
-
"text-embedding-3-small",
|
|
405
|
-
1536,
|
|
406
|
-
),
|
|
407
|
-
(
|
|
408
|
-
ModelProviderName.openai,
|
|
409
|
-
EmbeddingModelName.openai_text_embedding_3_large,
|
|
410
|
-
"text-embedding-3-large",
|
|
411
|
-
3072,
|
|
412
|
-
),
|
|
413
|
-
(
|
|
414
|
-
ModelProviderName.gemini_api,
|
|
415
|
-
EmbeddingModelName.gemini_text_embedding_004,
|
|
416
|
-
"text-embedding-004",
|
|
417
|
-
768,
|
|
418
|
-
),
|
|
419
|
-
],
|
|
209
|
+
"model_name,provider_name", get_all_embedding_models_and_providers()
|
|
420
210
|
)
|
|
421
|
-
|
|
422
|
-
|
|
211
|
+
@pytest.mark.paid
|
|
212
|
+
async def test_generate_embedding_with_user_supplied_dimensions(
|
|
213
|
+
self, model_name, provider_name
|
|
423
214
|
):
|
|
424
|
-
"""Test
|
|
425
|
-
|
|
215
|
+
"""Test generating an embedding with user supplied dimensions"""
|
|
216
|
+
model_provider = built_in_embedding_models_from_provider(
|
|
217
|
+
provider_name=provider_name,
|
|
218
|
+
model_name=model_name,
|
|
219
|
+
)
|
|
220
|
+
assert model_provider is not None
|
|
426
221
|
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
222
|
+
if not model_provider.supports_custom_dimensions:
|
|
223
|
+
pytest.skip("Model does not support custom dimensions")
|
|
224
|
+
|
|
225
|
+
# max dim
|
|
226
|
+
max_dimensions = model_provider.n_dimensions
|
|
227
|
+
dimensions_target = max_dimensions // 2
|
|
228
|
+
|
|
229
|
+
embedding = embedding_adapter_from_type(
|
|
230
|
+
EmbeddingConfig(
|
|
231
|
+
name="test-embedding",
|
|
232
|
+
model_provider_name=provider_name,
|
|
233
|
+
model_name=model_name,
|
|
234
|
+
properties={"dimensions": dimensions_target},
|
|
235
|
+
)
|
|
236
|
+
)
|
|
237
|
+
embedding = await embedding.generate_embeddings(["Hello, world!"])
|
|
238
|
+
assert len(embedding.embeddings) == 1
|
|
239
|
+
assert len(embedding.embeddings[0].vector) == dimensions_target
|
|
@@ -360,16 +360,6 @@ def test_uncensored():
|
|
|
360
360
|
assert provider.suggested_for_uncensored_data_gen
|
|
361
361
|
|
|
362
362
|
|
|
363
|
-
def test_multimodal_capable():
|
|
364
|
-
"""Test that multimodal_capable is set correctly"""
|
|
365
|
-
model = get_model_by_name(ModelName.gpt_4_1)
|
|
366
|
-
for provider in model.providers:
|
|
367
|
-
assert provider.multimodal_capable
|
|
368
|
-
assert provider.supports_doc_extraction
|
|
369
|
-
assert provider.multimodal_mime_types is not None
|
|
370
|
-
assert len(provider.multimodal_mime_types) > 0
|
|
371
|
-
|
|
372
|
-
|
|
373
363
|
def test_no_empty_multimodal_mime_types():
|
|
374
364
|
"""Ensure that multimodal fields are self-consistent as they are interdependent"""
|
|
375
365
|
for model in built_in_models:
|
|
@@ -5,12 +5,7 @@ from pathlib import Path
|
|
|
5
5
|
from typing import Any, Dict, List, Literal, Optional, Set, TypedDict
|
|
6
6
|
|
|
7
7
|
from llama_index.core import StorageContext, VectorStoreIndex
|
|
8
|
-
from llama_index.core.schema import
|
|
9
|
-
BaseNode,
|
|
10
|
-
NodeRelationship,
|
|
11
|
-
RelatedNodeInfo,
|
|
12
|
-
TextNode,
|
|
13
|
-
)
|
|
8
|
+
from llama_index.core.schema import BaseNode, TextNode
|
|
14
9
|
from llama_index.core.vector_stores.types import (
|
|
15
10
|
VectorStoreQuery as LlamaIndexVectorStoreQuery,
|
|
16
11
|
)
|
|
@@ -24,15 +19,19 @@ from kiln_ai.adapters.vector_store.base_vector_store_adapter import (
|
|
|
24
19
|
SearchResult,
|
|
25
20
|
VectorStoreQuery,
|
|
26
21
|
)
|
|
22
|
+
from kiln_ai.adapters.vector_store.lancedb_helpers import (
|
|
23
|
+
convert_to_llama_index_node,
|
|
24
|
+
deterministic_chunk_id,
|
|
25
|
+
lancedb_construct_from_config,
|
|
26
|
+
store_type_to_lancedb_query_type,
|
|
27
|
+
)
|
|
27
28
|
from kiln_ai.datamodel.rag import RagConfig
|
|
28
29
|
from kiln_ai.datamodel.vector_store import (
|
|
29
30
|
VectorStoreConfig,
|
|
30
|
-
VectorStoreType,
|
|
31
31
|
raise_exhaustive_enum_error,
|
|
32
32
|
)
|
|
33
33
|
from kiln_ai.utils.config import Config
|
|
34
34
|
from kiln_ai.utils.env import temporary_env
|
|
35
|
-
from kiln_ai.utils.uuid import string_to_uuid
|
|
36
35
|
|
|
37
36
|
logger = logging.getLogger(__name__)
|
|
38
37
|
|
|
@@ -48,6 +47,7 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
|
|
|
48
47
|
self,
|
|
49
48
|
rag_config: RagConfig,
|
|
50
49
|
vector_store_config: VectorStoreConfig,
|
|
50
|
+
lancedb_vector_store: LanceDBVectorStore | None = None,
|
|
51
51
|
):
|
|
52
52
|
super().__init__(rag_config, vector_store_config)
|
|
53
53
|
self.config_properties = self.vector_store_config.lancedb_properties
|
|
@@ -56,17 +56,15 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
|
|
|
56
56
|
if vector_store_config.lancedb_properties.nprobes is not None:
|
|
57
57
|
kwargs["nprobes"] = vector_store_config.lancedb_properties.nprobes
|
|
58
58
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
**kwargs,
|
|
59
|
+
# allow overriding the vector store with a custom one, useful for user loading into an arbitrary
|
|
60
|
+
# deployment
|
|
61
|
+
self.lancedb_vector_store = (
|
|
62
|
+
lancedb_vector_store
|
|
63
|
+
or lancedb_construct_from_config(
|
|
64
|
+
vector_store_config,
|
|
65
|
+
uri=LanceDBAdapter.lancedb_path_for_config(rag_config),
|
|
66
|
+
)
|
|
68
67
|
)
|
|
69
|
-
|
|
70
68
|
self._index = None
|
|
71
69
|
|
|
72
70
|
@property
|
|
@@ -149,7 +147,7 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
|
|
|
149
147
|
|
|
150
148
|
chunk_count_for_document = len(chunks)
|
|
151
149
|
deterministic_chunk_ids = [
|
|
152
|
-
|
|
150
|
+
deterministic_chunk_id(document_id, chunk_idx)
|
|
153
151
|
for chunk_idx in range(chunk_count_for_document)
|
|
154
152
|
]
|
|
155
153
|
|
|
@@ -176,42 +174,12 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
|
|
|
176
174
|
zip(chunks_text, embeddings)
|
|
177
175
|
):
|
|
178
176
|
node_batch.append(
|
|
179
|
-
|
|
180
|
-
|
|
177
|
+
convert_to_llama_index_node(
|
|
178
|
+
document_id=document_id,
|
|
179
|
+
chunk_idx=chunk_idx,
|
|
180
|
+
node_id=deterministic_chunk_id(document_id, chunk_idx),
|
|
181
181
|
text=chunk_text,
|
|
182
|
-
|
|
183
|
-
metadata={
|
|
184
|
-
# metadata is populated by some internal llama_index logic
|
|
185
|
-
# that uses for example the source_node relationship
|
|
186
|
-
"kiln_doc_id": document_id,
|
|
187
|
-
"kiln_chunk_idx": chunk_idx,
|
|
188
|
-
#
|
|
189
|
-
# llama_index lancedb vector store automatically sets these metadata:
|
|
190
|
-
# "doc_id": "UUID node_id of the Source Node relationship",
|
|
191
|
-
# "document_id": "UUID node_id of the Source Node relationship",
|
|
192
|
-
# "ref_doc_id": "UUID node_id of the Source Node relationship"
|
|
193
|
-
#
|
|
194
|
-
# llama_index file loaders set these metadata, which would be useful to also support:
|
|
195
|
-
# "creation_date": "2025-09-03",
|
|
196
|
-
# "file_name": "file.pdf",
|
|
197
|
-
# "file_path": "/absolute/path/to/the/file.pdf",
|
|
198
|
-
# "file_size": 395154,
|
|
199
|
-
# "file_type": "application\/pdf",
|
|
200
|
-
# "last_modified_date": "2025-09-03",
|
|
201
|
-
# "page_label": "1",
|
|
202
|
-
},
|
|
203
|
-
relationships={
|
|
204
|
-
# when using the llama_index loaders, llama_index groups Nodes under Documents
|
|
205
|
-
# and relationships point to the Document (which is also a Node), which confusingly
|
|
206
|
-
# enough does not map to an actual file (for a PDF, a Document is a page of the PDF)
|
|
207
|
-
# the Document structure is not something that is persisted, so it is fine here
|
|
208
|
-
# if we have a relationship to a node_id that does not exist in the db
|
|
209
|
-
NodeRelationship.SOURCE: RelatedNodeInfo(
|
|
210
|
-
node_id=document_id,
|
|
211
|
-
node_type="1",
|
|
212
|
-
metadata={},
|
|
213
|
-
),
|
|
214
|
-
},
|
|
182
|
+
vector=embedding.vector,
|
|
215
183
|
)
|
|
216
184
|
)
|
|
217
185
|
|
|
@@ -330,10 +298,6 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
|
|
|
330
298
|
return []
|
|
331
299
|
raise
|
|
332
300
|
|
|
333
|
-
def compute_deterministic_chunk_id(self, document_id: str, chunk_idx: int) -> str:
|
|
334
|
-
# the id_ of the Node must be a UUID string, otherwise llama_index / LanceDB fails downstream
|
|
335
|
-
return str(string_to_uuid(f"{document_id}::{chunk_idx}"))
|
|
336
|
-
|
|
337
301
|
async def count_records(self) -> int:
|
|
338
302
|
try:
|
|
339
303
|
table = self.lancedb_vector_store.table
|
|
@@ -346,15 +310,7 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
|
|
|
346
310
|
|
|
347
311
|
@property
|
|
348
312
|
def query_type(self) -> Literal["fts", "hybrid", "vector"]:
|
|
349
|
-
|
|
350
|
-
case VectorStoreType.LANCE_DB_FTS:
|
|
351
|
-
return "fts"
|
|
352
|
-
case VectorStoreType.LANCE_DB_HYBRID:
|
|
353
|
-
return "hybrid"
|
|
354
|
-
case VectorStoreType.LANCE_DB_VECTOR:
|
|
355
|
-
return "vector"
|
|
356
|
-
case _:
|
|
357
|
-
raise_exhaustive_enum_error(self.vector_store_config.store_type)
|
|
313
|
+
return store_type_to_lancedb_query_type(self.vector_store_config.store_type)
|
|
358
314
|
|
|
359
315
|
@staticmethod
|
|
360
316
|
def lancedb_path_for_config(rag_config: RagConfig) -> str:
|
|
@@ -380,9 +336,7 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
|
|
|
380
336
|
kiln_doc_id = row["metadata"]["kiln_doc_id"]
|
|
381
337
|
if kiln_doc_id not in document_ids:
|
|
382
338
|
kiln_chunk_idx = row["metadata"]["kiln_chunk_idx"]
|
|
383
|
-
record_id =
|
|
384
|
-
kiln_doc_id, kiln_chunk_idx
|
|
385
|
-
)
|
|
339
|
+
record_id = deterministic_chunk_id(kiln_doc_id, kiln_chunk_idx)
|
|
386
340
|
rows_to_delete.append(record_id)
|
|
387
341
|
|
|
388
342
|
if rows_to_delete:
|