kiln-ai 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic; see the advisory for this release on the package registry for more details.

Files changed (53):
  1. kiln_ai/adapters/extractors/litellm_extractor.py +52 -32
  2. kiln_ai/adapters/extractors/test_litellm_extractor.py +169 -71
  3. kiln_ai/adapters/ml_embedding_model_list.py +330 -28
  4. kiln_ai/adapters/ml_model_list.py +503 -23
  5. kiln_ai/adapters/model_adapters/litellm_adapter.py +39 -8
  6. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +78 -0
  7. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  8. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  9. kiln_ai/adapters/model_adapters/test_structured_output.py +6 -9
  10. kiln_ai/adapters/test_ml_embedding_model_list.py +89 -279
  11. kiln_ai/adapters/test_ml_model_list.py +0 -10
  12. kiln_ai/adapters/vector_store/lancedb_adapter.py +24 -70
  13. kiln_ai/adapters/vector_store/lancedb_helpers.py +101 -0
  14. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +9 -16
  15. kiln_ai/adapters/vector_store/test_lancedb_helpers.py +142 -0
  16. kiln_ai/adapters/vector_store_loaders/__init__.py +0 -0
  17. kiln_ai/adapters/vector_store_loaders/test_lancedb_loader.py +282 -0
  18. kiln_ai/adapters/vector_store_loaders/test_vector_store_loader.py +544 -0
  19. kiln_ai/adapters/vector_store_loaders/vector_store_loader.py +91 -0
  20. kiln_ai/datamodel/basemodel.py +31 -3
  21. kiln_ai/datamodel/external_tool_server.py +206 -54
  22. kiln_ai/datamodel/extraction.py +14 -0
  23. kiln_ai/datamodel/task.py +5 -0
  24. kiln_ai/datamodel/task_output.py +41 -11
  25. kiln_ai/datamodel/test_attachment.py +3 -3
  26. kiln_ai/datamodel/test_basemodel.py +269 -13
  27. kiln_ai/datamodel/test_datasource.py +50 -0
  28. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  29. kiln_ai/datamodel/test_extraction_model.py +31 -0
  30. kiln_ai/datamodel/test_task.py +35 -1
  31. kiln_ai/datamodel/test_tool_id.py +106 -1
  32. kiln_ai/datamodel/tool_id.py +49 -0
  33. kiln_ai/tools/base_tool.py +30 -6
  34. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  35. kiln_ai/tools/kiln_task_tool.py +162 -0
  36. kiln_ai/tools/mcp_server_tool.py +7 -5
  37. kiln_ai/tools/mcp_session_manager.py +50 -24
  38. kiln_ai/tools/rag_tools.py +17 -6
  39. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  40. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  41. kiln_ai/tools/test_mcp_session_manager.py +186 -226
  42. kiln_ai/tools/test_rag_tools.py +86 -5
  43. kiln_ai/tools/test_tool_registry.py +199 -5
  44. kiln_ai/tools/tool_registry.py +49 -17
  45. kiln_ai/utils/filesystem.py +4 -4
  46. kiln_ai/utils/open_ai_types.py +19 -2
  47. kiln_ai/utils/pdf_utils.py +21 -0
  48. kiln_ai/utils/test_open_ai_types.py +88 -12
  49. kiln_ai/utils/test_pdf_utils.py +14 -1
  50. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/METADATA +79 -1
  51. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/RECORD +53 -45
  52. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/WHEEL +0 -0
  53. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
@@ -1,5 +1,8 @@
1
+ from typing import List
2
+
1
3
  import pytest
2
4
 
5
+ from kiln_ai.adapters.embedding.embedding_registry import embedding_adapter_from_type
3
6
  from kiln_ai.adapters.ml_embedding_model_list import (
4
7
  EmbeddingModelName,
5
8
  KilnEmbeddingModel,
@@ -10,24 +13,28 @@ from kiln_ai.adapters.ml_embedding_model_list import (
10
13
  get_model_by_name,
11
14
  )
12
15
  from kiln_ai.datamodel.datamodel_enums import ModelProviderName
16
+ from kiln_ai.datamodel.embedding import EmbeddingConfig
13
17
 
14
18
 
15
- class TestEmbeddingModelName:
16
- """Test cases for EmbeddingModelName enum"""
17
-
18
- def test_enum_values(self):
19
- """Test that enum values are correctly defined"""
20
- assert (
21
- EmbeddingModelName.openai_text_embedding_3_small
22
- == "openai_text_embedding_3_small"
23
- )
24
- assert (
25
- EmbeddingModelName.openai_text_embedding_3_large
26
- == "openai_text_embedding_3_large"
27
- )
28
- assert (
29
- EmbeddingModelName.gemini_text_embedding_004 == "gemini_text_embedding_004"
19
+ @pytest.fixture
20
+ def litellm_adapter():
21
+ adapter = embedding_adapter_from_type(
22
+ EmbeddingConfig(
23
+ name="test-embedding",
24
+ model_provider_name=ModelProviderName.openai,
25
+ model_name=EmbeddingModelName.openai_text_embedding_3_small,
26
+ properties={},
30
27
  )
28
+ )
29
+ return adapter
30
+
31
+
32
+ def get_all_embedding_models_and_providers() -> List[tuple[str, str]]:
33
+ return [
34
+ (model.name, provider.name)
35
+ for model in built_in_embedding_models
36
+ for provider in model.providers
37
+ ]
31
38
 
32
39
 
33
40
  class TestKilnEmbeddingModelProvider:
@@ -120,222 +127,40 @@ class TestKilnEmbeddingModel:
120
127
  assert model.providers[1].name == ModelProviderName.anthropic
121
128
 
122
129
 
123
- class TestEmbeddingModelsList:
124
- """Test cases for the embedding_models list"""
125
-
126
- def test_embedding_models_not_empty(self):
127
- """Test that the embedding_models list is not empty"""
128
- assert len(built_in_embedding_models) > 0
129
-
130
- def test_all_models_have_required_fields(self):
131
- """Test that all models in the list have required fields"""
132
- for model in built_in_embedding_models:
133
- assert hasattr(model, "family")
134
- assert hasattr(model, "name")
135
- assert hasattr(model, "friendly_name")
136
- assert hasattr(model, "providers")
137
- assert isinstance(model.name, str)
138
- assert isinstance(model.friendly_name, str)
139
- assert isinstance(model.providers, list)
140
- assert len(model.providers) > 0
141
-
142
- def test_all_providers_have_required_fields(self):
143
- """Test that all providers in all models have required fields"""
144
- for model in built_in_embedding_models:
145
- for provider in model.providers:
146
- assert hasattr(provider, "name")
147
- assert isinstance(provider.name, ModelProviderName)
148
-
149
- def test_model_names_are_unique(self):
150
- """Test that all model names in the list are unique"""
151
- model_names = [model.name for model in built_in_embedding_models]
152
- assert len(model_names) == len(set(model_names))
153
-
154
- def test_specific_models_exist(self):
155
- """Test that specific expected models exist in the list"""
156
- model_names = [model.name for model in built_in_embedding_models]
157
-
158
- assert EmbeddingModelName.openai_text_embedding_3_small in model_names
159
- assert EmbeddingModelName.openai_text_embedding_3_large in model_names
160
- assert EmbeddingModelName.gemini_text_embedding_004 in model_names
161
-
162
- def test_openai_embedding_models(self):
163
- """Test specific OpenAI embedding models"""
164
- openai_models = [
165
- model
166
- for model in built_in_embedding_models
167
- if model.family == KilnEmbeddingModelFamily.openai
168
- ]
169
-
170
- assert len(openai_models) >= 2 # Should have at least 2 OpenAI models
171
-
172
- # Check for specific OpenAI models
173
- openai_model_names = [model.name for model in openai_models]
174
- assert EmbeddingModelName.openai_text_embedding_3_small in openai_model_names
175
- assert EmbeddingModelName.openai_text_embedding_3_large in openai_model_names
176
-
177
- def test_gemini_embedding_models(self):
178
- """Test specific Gemini embedding models"""
179
- gemini_models = [
180
- model
181
- for model in built_in_embedding_models
182
- if model.family == KilnEmbeddingModelFamily.gemini
183
- ]
184
-
185
- assert len(gemini_models) >= 1 # Should have at least 1 Gemini model
186
-
187
- # Check for specific Gemini model
188
- gemini_model_names = [model.name for model in gemini_models]
189
- assert EmbeddingModelName.gemini_text_embedding_004 in gemini_model_names
190
-
191
- def test_openai_text_embedding_3_small_details(self):
192
- """Test specific details of OpenAI text-embedding-3-small model"""
193
- model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_small)
194
-
195
- assert model.family == KilnEmbeddingModelFamily.openai
196
- assert model.friendly_name == "Text Embedding 3 Small"
197
- assert len(model.providers) == 1
198
-
199
- provider = model.providers[0]
200
- assert provider.name == ModelProviderName.openai
201
- assert provider.model_id == "text-embedding-3-small"
202
- assert provider.n_dimensions == 1536
203
- assert provider.max_input_tokens == 8192
204
- assert provider.supports_custom_dimensions is True
205
-
206
- def test_openai_text_embedding_3_large_details(self):
207
- """Test specific details of OpenAI text-embedding-3-large model"""
208
- model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_large)
209
-
210
- assert model.family == KilnEmbeddingModelFamily.openai
211
- assert model.friendly_name == "Text Embedding 3 Large"
212
- assert len(model.providers) == 1
213
-
214
- provider = model.providers[0]
215
- assert provider.name == ModelProviderName.openai
216
- assert provider.model_id == "text-embedding-3-large"
217
- assert provider.n_dimensions == 3072
218
- assert provider.max_input_tokens == 8192
219
- assert provider.supports_custom_dimensions is True
220
-
221
- def test_gemini_text_embedding_004_details(self):
222
- """Test specific details of Gemini text-embedding-004 model"""
223
- model = get_model_by_name(EmbeddingModelName.gemini_text_embedding_004)
224
-
225
- assert model.family == KilnEmbeddingModelFamily.gemini
226
- assert model.friendly_name == "Text Embedding 004"
227
- assert len(model.providers) == 1
228
-
229
- provider = model.providers[0]
230
- assert provider.name == ModelProviderName.gemini_api
231
- assert provider.model_id == "text-embedding-004"
232
- assert provider.n_dimensions == 768
233
- assert provider.max_input_tokens == 2048
234
- assert provider.supports_custom_dimensions is False
235
-
236
-
237
130
  class TestGetModelByName:
238
- """Test cases for get_model_by_name function"""
239
-
240
- def test_get_existing_model(self):
241
- """Test getting an existing model by name"""
242
- model = get_model_by_name(EmbeddingModelName.openai_text_embedding_3_small)
243
- assert model.name == EmbeddingModelName.openai_text_embedding_3_small
244
- assert model.family == KilnEmbeddingModelFamily.openai
245
-
246
- def test_get_all_existing_models(self):
247
- """Test getting all existing models by name"""
248
- for model_name in EmbeddingModelName:
249
- model = get_model_by_name(model_name)
250
- assert model.name == model_name
251
-
252
131
  def test_get_nonexistent_model_raises_error(self):
253
132
  """Test that getting a nonexistent model raises ValueError"""
254
133
  with pytest.raises(
255
134
  ValueError, match="Embedding model nonexistent_model not found"
256
135
  ):
257
- get_model_by_name("nonexistent_model")
258
-
259
- def test_get_model_with_invalid_enum_value(self):
260
- """Test that getting a model with invalid enum value raises ValueError"""
261
- with pytest.raises(ValueError, match="Embedding model invalid_enum not found"):
262
- get_model_by_name("invalid_enum")
136
+ get_model_by_name("nonexistent_model") # type: ignore
263
137
 
264
138
  @pytest.mark.parametrize(
265
- "model_name,expected_family,expected_friendly_name",
266
- [
267
- (
268
- EmbeddingModelName.openai_text_embedding_3_small,
269
- KilnEmbeddingModelFamily.openai,
270
- "Text Embedding 3 Small",
271
- ),
272
- (
273
- EmbeddingModelName.openai_text_embedding_3_large,
274
- KilnEmbeddingModelFamily.openai,
275
- "Text Embedding 3 Large",
276
- ),
277
- (
278
- EmbeddingModelName.gemini_text_embedding_004,
279
- KilnEmbeddingModelFamily.gemini,
280
- "Text Embedding 004",
281
- ),
282
- ],
139
+ "model_name",
140
+ [model.name for model in built_in_embedding_models],
283
141
  )
284
- def test_parametrized_model_retrieval(
285
- self, model_name, expected_family, expected_friendly_name
286
- ):
142
+ def test_model_retrieval(self, model_name):
287
143
  """Test retrieving models with parametrized test cases"""
288
144
  model = get_model_by_name(model_name)
289
- assert model.family == expected_family
290
- assert model.friendly_name == expected_friendly_name
145
+ assert model.family == model.family
146
+ assert model.friendly_name == model.friendly_name
291
147
 
292
148
 
293
149
  class TestBuiltInEmbeddingModelsFromProvider:
294
- """Test cases for built_in_embedding_models_from_provider function"""
295
-
296
- def test_get_existing_provider_for_model(self):
297
- """Test getting an existing provider for a model"""
298
- provider = built_in_embedding_models_from_provider(
299
- provider_name=ModelProviderName.openai,
300
- model_name=EmbeddingModelName.openai_text_embedding_3_small,
301
- )
150
+ @pytest.mark.parametrize(
151
+ "model_name,provider_name", get_all_embedding_models_and_providers()
152
+ )
153
+ def test_get_all_existing_models_and_providers(self, model_name, provider_name):
154
+ provider = built_in_embedding_models_from_provider(provider_name, model_name)
302
155
 
303
156
  assert provider is not None
304
- assert provider.name == ModelProviderName.openai
305
- assert provider.model_id == "text-embedding-3-small"
306
- assert provider.n_dimensions == 1536
307
-
308
- def test_get_all_existing_provider_model_combinations(self):
309
- """Test getting all existing provider-model combinations"""
310
- combinations = [
311
- (
312
- ModelProviderName.openai,
313
- EmbeddingModelName.openai_text_embedding_3_small,
314
- ),
315
- (
316
- ModelProviderName.openai,
317
- EmbeddingModelName.openai_text_embedding_3_large,
318
- ),
319
- (
320
- ModelProviderName.gemini_api,
321
- EmbeddingModelName.gemini_text_embedding_004,
322
- ),
323
- ]
324
-
325
- for provider_name, model_name in combinations:
326
- provider = built_in_embedding_models_from_provider(
327
- provider_name, model_name
328
- )
329
- assert provider is not None
330
- assert provider.name == provider_name
331
-
332
- def test_get_nonexistent_provider_returns_none(self):
333
- """Test that getting a nonexistent provider returns None"""
334
- provider = built_in_embedding_models_from_provider(
335
- provider_name=ModelProviderName.anthropic, # Not used for embeddings
336
- model_name=EmbeddingModelName.openai_text_embedding_3_small,
157
+ assert provider.name == provider_name
158
+ assert provider.model_id == provider.model_id
159
+ assert provider.n_dimensions == provider.n_dimensions
160
+ assert provider.max_input_tokens == provider.max_input_tokens
161
+ assert (
162
+ provider.supports_custom_dimensions == provider.supports_custom_dimensions
337
163
  )
338
- assert provider is None
339
164
 
340
165
  def test_get_nonexistent_model_returns_none(self):
341
166
  """Test that getting a nonexistent model returns None"""
@@ -353,77 +178,62 @@ class TestBuiltInEmbeddingModelsFromProvider:
353
178
  )
354
179
  assert provider is None
355
180
 
356
- def test_get_openai_text_embedding_3_small_provider_details(self):
357
- """Test specific details of OpenAI text-embedding-3-small provider"""
358
- provider = built_in_embedding_models_from_provider(
359
- provider_name=ModelProviderName.openai,
360
- model_name=EmbeddingModelName.openai_text_embedding_3_small,
361
- )
362
181
 
363
- assert provider is not None
364
- assert provider.name == ModelProviderName.openai
365
- assert provider.model_id == "text-embedding-3-small"
366
- assert provider.n_dimensions == 1536
367
- assert provider.max_input_tokens == 8192
368
- assert provider.supports_custom_dimensions is True
182
+ class TestGenerateEmbedding:
183
+ """Test cases for generate_embedding function"""
369
184
 
370
- def test_get_openai_text_embedding_3_large_provider_details(self):
371
- """Test specific details of OpenAI text-embedding-3-large provider"""
372
- provider = built_in_embedding_models_from_provider(
373
- provider_name=ModelProviderName.openai,
374
- model_name=EmbeddingModelName.openai_text_embedding_3_large,
185
+ @pytest.mark.parametrize(
186
+ "model_name,provider_name", get_all_embedding_models_and_providers()
187
+ )
188
+ @pytest.mark.paid
189
+ async def test_generate_embedding(self, model_name, provider_name):
190
+ """Test generating an embedding"""
191
+ model_provider = built_in_embedding_models_from_provider(
192
+ provider_name, model_name
375
193
  )
376
-
377
- assert provider is not None
378
- assert provider.name == ModelProviderName.openai
379
- assert provider.model_id == "text-embedding-3-large"
380
- assert provider.n_dimensions == 3072
381
- assert provider.max_input_tokens == 8192
382
- assert provider.supports_custom_dimensions is True
383
-
384
- def test_get_gemini_text_embedding_004_provider_details(self):
385
- """Test specific details of Gemini text-embedding-004 provider"""
386
- provider = built_in_embedding_models_from_provider(
387
- provider_name=ModelProviderName.gemini_api,
388
- model_name=EmbeddingModelName.gemini_text_embedding_004,
194
+ assert model_provider is not None
195
+
196
+ embedding = embedding_adapter_from_type(
197
+ EmbeddingConfig(
198
+ name="test-embedding",
199
+ model_provider_name=provider_name,
200
+ model_name=model_name,
201
+ properties={},
202
+ )
389
203
  )
390
-
391
- assert provider is not None
392
- assert provider.name == ModelProviderName.gemini_api
393
- assert provider.model_id == "text-embedding-004"
394
- assert provider.n_dimensions == 768
395
- assert provider.max_input_tokens == 2048
396
- assert provider.supports_custom_dimensions is False
204
+ embedding = await embedding.generate_embeddings(["Hello, world!"])
205
+ assert len(embedding.embeddings) == 1
206
+ assert len(embedding.embeddings[0].vector) == model_provider.n_dimensions
397
207
 
398
208
  @pytest.mark.parametrize(
399
- "provider_name,model_name,expected_model_id,expected_dimensions",
400
- [
401
- (
402
- ModelProviderName.openai,
403
- EmbeddingModelName.openai_text_embedding_3_small,
404
- "text-embedding-3-small",
405
- 1536,
406
- ),
407
- (
408
- ModelProviderName.openai,
409
- EmbeddingModelName.openai_text_embedding_3_large,
410
- "text-embedding-3-large",
411
- 3072,
412
- ),
413
- (
414
- ModelProviderName.gemini_api,
415
- EmbeddingModelName.gemini_text_embedding_004,
416
- "text-embedding-004",
417
- 768,
418
- ),
419
- ],
209
+ "model_name,provider_name", get_all_embedding_models_and_providers()
420
210
  )
421
- def test_parametrized_provider_retrieval(
422
- self, provider_name, model_name, expected_model_id, expected_dimensions
211
+ @pytest.mark.paid
212
+ async def test_generate_embedding_with_user_supplied_dimensions(
213
+ self, model_name, provider_name
423
214
  ):
424
- """Test retrieving providers with parametrized test cases"""
425
- provider = built_in_embedding_models_from_provider(provider_name, model_name)
215
+ """Test generating an embedding with user supplied dimensions"""
216
+ model_provider = built_in_embedding_models_from_provider(
217
+ provider_name=provider_name,
218
+ model_name=model_name,
219
+ )
220
+ assert model_provider is not None
426
221
 
427
- assert provider is not None
428
- assert provider.model_id == expected_model_id
429
- assert provider.n_dimensions == expected_dimensions
222
+ if not model_provider.supports_custom_dimensions:
223
+ pytest.skip("Model does not support custom dimensions")
224
+
225
+ # max dim
226
+ max_dimensions = model_provider.n_dimensions
227
+ dimensions_target = max_dimensions // 2
228
+
229
+ embedding = embedding_adapter_from_type(
230
+ EmbeddingConfig(
231
+ name="test-embedding",
232
+ model_provider_name=provider_name,
233
+ model_name=model_name,
234
+ properties={"dimensions": dimensions_target},
235
+ )
236
+ )
237
+ embedding = await embedding.generate_embeddings(["Hello, world!"])
238
+ assert len(embedding.embeddings) == 1
239
+ assert len(embedding.embeddings[0].vector) == dimensions_target
@@ -360,16 +360,6 @@ def test_uncensored():
360
360
  assert provider.suggested_for_uncensored_data_gen
361
361
 
362
362
 
363
- def test_multimodal_capable():
364
- """Test that multimodal_capable is set correctly"""
365
- model = get_model_by_name(ModelName.gpt_4_1)
366
- for provider in model.providers:
367
- assert provider.multimodal_capable
368
- assert provider.supports_doc_extraction
369
- assert provider.multimodal_mime_types is not None
370
- assert len(provider.multimodal_mime_types) > 0
371
-
372
-
373
363
  def test_no_empty_multimodal_mime_types():
374
364
  """Ensure that multimodal fields are self-consistent as they are interdependent"""
375
365
  for model in built_in_models:
@@ -5,12 +5,7 @@ from pathlib import Path
5
5
  from typing import Any, Dict, List, Literal, Optional, Set, TypedDict
6
6
 
7
7
  from llama_index.core import StorageContext, VectorStoreIndex
8
- from llama_index.core.schema import (
9
- BaseNode,
10
- NodeRelationship,
11
- RelatedNodeInfo,
12
- TextNode,
13
- )
8
+ from llama_index.core.schema import BaseNode, TextNode
14
9
  from llama_index.core.vector_stores.types import (
15
10
  VectorStoreQuery as LlamaIndexVectorStoreQuery,
16
11
  )
@@ -24,15 +19,19 @@ from kiln_ai.adapters.vector_store.base_vector_store_adapter import (
24
19
  SearchResult,
25
20
  VectorStoreQuery,
26
21
  )
22
+ from kiln_ai.adapters.vector_store.lancedb_helpers import (
23
+ convert_to_llama_index_node,
24
+ deterministic_chunk_id,
25
+ lancedb_construct_from_config,
26
+ store_type_to_lancedb_query_type,
27
+ )
27
28
  from kiln_ai.datamodel.rag import RagConfig
28
29
  from kiln_ai.datamodel.vector_store import (
29
30
  VectorStoreConfig,
30
- VectorStoreType,
31
31
  raise_exhaustive_enum_error,
32
32
  )
33
33
  from kiln_ai.utils.config import Config
34
34
  from kiln_ai.utils.env import temporary_env
35
- from kiln_ai.utils.uuid import string_to_uuid
36
35
 
37
36
  logger = logging.getLogger(__name__)
38
37
 
@@ -48,6 +47,7 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
48
47
  self,
49
48
  rag_config: RagConfig,
50
49
  vector_store_config: VectorStoreConfig,
50
+ lancedb_vector_store: LanceDBVectorStore | None = None,
51
51
  ):
52
52
  super().__init__(rag_config, vector_store_config)
53
53
  self.config_properties = self.vector_store_config.lancedb_properties
@@ -56,17 +56,15 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
56
56
  if vector_store_config.lancedb_properties.nprobes is not None:
57
57
  kwargs["nprobes"] = vector_store_config.lancedb_properties.nprobes
58
58
 
59
- self.lancedb_vector_store = LanceDBVectorStore(
60
- mode="create",
61
- uri=LanceDBAdapter.lancedb_path_for_config(rag_config),
62
- query_type=self.query_type,
63
- overfetch_factor=vector_store_config.lancedb_properties.overfetch_factor,
64
- vector_column_name=vector_store_config.lancedb_properties.vector_column_name,
65
- text_key=vector_store_config.lancedb_properties.text_key,
66
- doc_id_key=vector_store_config.lancedb_properties.doc_id_key,
67
- **kwargs,
59
+ # allow overriding the vector store with a custom one, useful for user loading into an arbitrary
60
+ # deployment
61
+ self.lancedb_vector_store = (
62
+ lancedb_vector_store
63
+ or lancedb_construct_from_config(
64
+ vector_store_config,
65
+ uri=LanceDBAdapter.lancedb_path_for_config(rag_config),
66
+ )
68
67
  )
69
-
70
68
  self._index = None
71
69
 
72
70
  @property
@@ -149,7 +147,7 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
149
147
 
150
148
  chunk_count_for_document = len(chunks)
151
149
  deterministic_chunk_ids = [
152
- self.compute_deterministic_chunk_id(document_id, chunk_idx)
150
+ deterministic_chunk_id(document_id, chunk_idx)
153
151
  for chunk_idx in range(chunk_count_for_document)
154
152
  ]
155
153
 
@@ -176,42 +174,12 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
176
174
  zip(chunks_text, embeddings)
177
175
  ):
178
176
  node_batch.append(
179
- TextNode(
180
- id_=deterministic_chunk_ids[chunk_idx],
177
+ convert_to_llama_index_node(
178
+ document_id=document_id,
179
+ chunk_idx=chunk_idx,
180
+ node_id=deterministic_chunk_id(document_id, chunk_idx),
181
181
  text=chunk_text,
182
- embedding=embedding.vector,
183
- metadata={
184
- # metadata is populated by some internal llama_index logic
185
- # that uses for example the source_node relationship
186
- "kiln_doc_id": document_id,
187
- "kiln_chunk_idx": chunk_idx,
188
- #
189
- # llama_index lancedb vector store automatically sets these metadata:
190
- # "doc_id": "UUID node_id of the Source Node relationship",
191
- # "document_id": "UUID node_id of the Source Node relationship",
192
- # "ref_doc_id": "UUID node_id of the Source Node relationship"
193
- #
194
- # llama_index file loaders set these metadata, which would be useful to also support:
195
- # "creation_date": "2025-09-03",
196
- # "file_name": "file.pdf",
197
- # "file_path": "/absolute/path/to/the/file.pdf",
198
- # "file_size": 395154,
199
- # "file_type": "application\/pdf",
200
- # "last_modified_date": "2025-09-03",
201
- # "page_label": "1",
202
- },
203
- relationships={
204
- # when using the llama_index loaders, llama_index groups Nodes under Documents
205
- # and relationships point to the Document (which is also a Node), which confusingly
206
- # enough does not map to an actual file (for a PDF, a Document is a page of the PDF)
207
- # the Document structure is not something that is persisted, so it is fine here
208
- # if we have a relationship to a node_id that does not exist in the db
209
- NodeRelationship.SOURCE: RelatedNodeInfo(
210
- node_id=document_id,
211
- node_type="1",
212
- metadata={},
213
- ),
214
- },
182
+ vector=embedding.vector,
215
183
  )
216
184
  )
217
185
 
@@ -330,10 +298,6 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
330
298
  return []
331
299
  raise
332
300
 
333
- def compute_deterministic_chunk_id(self, document_id: str, chunk_idx: int) -> str:
334
- # the id_ of the Node must be a UUID string, otherwise llama_index / LanceDB fails downstream
335
- return str(string_to_uuid(f"{document_id}::{chunk_idx}"))
336
-
337
301
  async def count_records(self) -> int:
338
302
  try:
339
303
  table = self.lancedb_vector_store.table
@@ -346,15 +310,7 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
346
310
 
347
311
  @property
348
312
  def query_type(self) -> Literal["fts", "hybrid", "vector"]:
349
- match self.vector_store_config.store_type:
350
- case VectorStoreType.LANCE_DB_FTS:
351
- return "fts"
352
- case VectorStoreType.LANCE_DB_HYBRID:
353
- return "hybrid"
354
- case VectorStoreType.LANCE_DB_VECTOR:
355
- return "vector"
356
- case _:
357
- raise_exhaustive_enum_error(self.vector_store_config.store_type)
313
+ return store_type_to_lancedb_query_type(self.vector_store_config.store_type)
358
314
 
359
315
  @staticmethod
360
316
  def lancedb_path_for_config(rag_config: RagConfig) -> str:
@@ -380,9 +336,7 @@ class LanceDBAdapter(BaseVectorStoreAdapter):
380
336
  kiln_doc_id = row["metadata"]["kiln_doc_id"]
381
337
  if kiln_doc_id not in document_ids:
382
338
  kiln_chunk_idx = row["metadata"]["kiln_chunk_idx"]
383
- record_id = self.compute_deterministic_chunk_id(
384
- kiln_doc_id, kiln_chunk_idx
385
- )
339
+ record_id = deterministic_chunk_id(kiln_doc_id, kiln_chunk_idx)
386
340
  rows_to_delete.append(record_id)
387
341
 
388
342
  if rows_to_delete: