kiln-ai 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/extractors/litellm_extractor.py +52 -32
- kiln_ai/adapters/extractors/test_litellm_extractor.py +169 -71
- kiln_ai/adapters/ml_embedding_model_list.py +330 -28
- kiln_ai/adapters/ml_model_list.py +503 -23
- kiln_ai/adapters/model_adapters/litellm_adapter.py +39 -8
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +78 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +6 -9
- kiln_ai/adapters/test_ml_embedding_model_list.py +89 -279
- kiln_ai/adapters/test_ml_model_list.py +0 -10
- kiln_ai/adapters/vector_store/lancedb_adapter.py +24 -70
- kiln_ai/adapters/vector_store/lancedb_helpers.py +101 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +9 -16
- kiln_ai/adapters/vector_store/test_lancedb_helpers.py +142 -0
- kiln_ai/adapters/vector_store_loaders/__init__.py +0 -0
- kiln_ai/adapters/vector_store_loaders/test_lancedb_loader.py +282 -0
- kiln_ai/adapters/vector_store_loaders/test_vector_store_loader.py +544 -0
- kiln_ai/adapters/vector_store_loaders/vector_store_loader.py +91 -0
- kiln_ai/datamodel/basemodel.py +31 -3
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +14 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +3 -3
- kiln_ai/datamodel/test_basemodel.py +269 -13
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_model.py +31 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +106 -1
- kiln_ai/datamodel/tool_id.py +49 -0
- kiln_ai/tools/base_tool.py +30 -6
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +162 -0
- kiln_ai/tools/mcp_server_tool.py +7 -5
- kiln_ai/tools/mcp_session_manager.py +50 -24
- kiln_ai/tools/rag_tools.py +17 -6
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +186 -226
- kiln_ai/tools/test_rag_tools.py +86 -5
- kiln_ai/tools/test_tool_registry.py +199 -5
- kiln_ai/tools/tool_registry.py +49 -17
- kiln_ai/utils/filesystem.py +4 -4
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +21 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +14 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/METADATA +79 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/RECORD +53 -45
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Literal
|
|
2
|
+
|
|
3
|
+
from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode
|
|
4
|
+
from llama_index.vector_stores.lancedb import LanceDBVectorStore
|
|
5
|
+
|
|
6
|
+
from kiln_ai.datamodel.vector_store import (
|
|
7
|
+
VectorStoreConfig,
|
|
8
|
+
VectorStoreType,
|
|
9
|
+
raise_exhaustive_enum_error,
|
|
10
|
+
)
|
|
11
|
+
from kiln_ai.utils.uuid import string_to_uuid
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def store_type_to_lancedb_query_type(
    store_type: VectorStoreType,
) -> Literal["fts", "hybrid", "vector"]:
    """Translate a Kiln vector store type into LanceDB's ``query_type`` string.

    Only the three LanceDB-backed store types have a mapping; any other value
    is passed to ``raise_exhaustive_enum_error``.
    """
    if store_type == VectorStoreType.LANCE_DB_FTS:
        return "fts"
    if store_type == VectorStoreType.LANCE_DB_HYBRID:
        return "hybrid"
    if store_type == VectorStoreType.LANCE_DB_VECTOR:
        return "vector"
    raise_exhaustive_enum_error(store_type)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def lancedb_construct_from_config(
    vector_store_config: VectorStoreConfig,
    uri: str,
    **extra_params: Any,
) -> LanceDBVectorStore:
    """Construct a LanceDBVectorStore from a VectorStoreConfig.

    Args:
        vector_store_config: Source of the store type and LanceDB properties
            (overfetch factor, column/key names, optional nprobes).
        uri: LanceDB connection URI.
        **extra_params: Forwarded verbatim to ``LanceDBVectorStore``
            (e.g. ``api_key``, ``region``, ``table_name``).

    The configured ``nprobes`` is forwarded only when it is set and the caller
    did not already pass ``nprobes`` in ``extra_params``.
    """
    lance_props = vector_store_config.lancedb_properties
    forwarded: Dict[str, Any] = dict(extra_params)
    if lance_props.nprobes is not None:
        # a caller-supplied nprobes takes precedence over the config value
        forwarded.setdefault("nprobes", lance_props.nprobes)

    query_type = store_type_to_lancedb_query_type(vector_store_config.store_type)
    return LanceDBVectorStore(
        mode="create",
        query_type=query_type,
        overfetch_factor=lance_props.overfetch_factor,
        vector_column_name=lance_props.vector_column_name,
        text_key=lance_props.text_key,
        doc_id_key=lance_props.doc_id_key,
        uri=uri,
        **forwarded,
    )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def convert_to_llama_index_node(
    document_id: str,
    chunk_idx: int,
    node_id: str,
    text: str,
    vector: List[float],
) -> TextNode:
    """Build a llama_index ``TextNode`` for one embedded chunk of a Kiln document.

    Args:
        document_id: Kiln document the chunk belongs to. It is stored as
            ``kiln_doc_id`` metadata and used as the SOURCE relationship target.
        chunk_idx: Position of the chunk within the document (``kiln_chunk_idx``).
        node_id: ID for the node. It must be a UUID string, e.g. from
            ``deterministic_chunk_id``.
        text: Chunk text.
        vector: Embedding vector for the chunk.
    """
    # metadata is populated by some internal llama_index logic
    # that uses for example the source_node relationship
    #
    # llama_index lancedb vector store automatically sets these metadata:
    # "doc_id": "UUID node_id of the Source Node relationship",
    # "document_id": "UUID node_id of the Source Node relationship",
    # "ref_doc_id": "UUID node_id of the Source Node relationship"
    #
    # llama_index file loaders set these metadata, which would be useful to also support:
    # "creation_date", "file_name", "file_path", "file_size", "file_type",
    # "last_modified_date", "page_label"
    kiln_metadata = {
        "kiln_doc_id": document_id,
        "kiln_chunk_idx": chunk_idx,
    }

    # when using the llama_index loaders, llama_index groups Nodes under Documents
    # and relationships point to the Document (which is also a Node), which confusingly
    # enough does not map to an actual file (for a PDF, a Document is a page of the PDF)
    # the Document structure is not something that is persisted, so it is fine here
    # if we have a relationship to a node_id that does not exist in the db
    source_info = RelatedNodeInfo(
        node_id=document_id,
        node_type="1",
        metadata={},
    )

    return TextNode(
        id_=node_id,
        text=text,
        embedding=vector,
        metadata=kiln_metadata,
        relationships={NodeRelationship.SOURCE: source_info},
    )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def deterministic_chunk_id(document_id: str, chunk_idx: int) -> str:
    """Return the node ID for chunk ``chunk_idx`` of ``document_id``.

    The ID is derived only from the two arguments, so identical inputs always
    give the same string.
    """
    # the id_ of the Node must be a UUID string, otherwise llama_index / LanceDB fails downstream
    seed = f"{document_id}::{chunk_idx}"
    chunk_uuid = string_to_uuid(seed)
    return str(chunk_uuid)
|
|
@@ -17,6 +17,7 @@ from kiln_ai.adapters.vector_store.base_vector_store_adapter import (
|
|
|
17
17
|
VectorStoreQuery,
|
|
18
18
|
)
|
|
19
19
|
from kiln_ai.adapters.vector_store.lancedb_adapter import LanceDBAdapter
|
|
20
|
+
from kiln_ai.adapters.vector_store.lancedb_helpers import deterministic_chunk_id
|
|
20
21
|
from kiln_ai.adapters.vector_store.vector_store_registry import (
|
|
21
22
|
vector_store_adapter_for_config,
|
|
22
23
|
)
|
|
@@ -925,9 +926,7 @@ async def test_get_nodes_by_ids_functionality(
|
|
|
925
926
|
await adapter.add_chunks_with_embeddings([mock_chunked_documents[0]]) # doc_001
|
|
926
927
|
|
|
927
928
|
# Test getting nodes by IDs - compute expected IDs
|
|
928
|
-
expected_ids = [
|
|
929
|
-
adapter.compute_deterministic_chunk_id("doc_001", i) for i in range(4)
|
|
930
|
-
]
|
|
929
|
+
expected_ids = [deterministic_chunk_id("doc_001", i) for i in range(4)]
|
|
931
930
|
|
|
932
931
|
# Get nodes by IDs
|
|
933
932
|
retrieved_nodes = await adapter.get_nodes_by_ids(expected_ids)
|
|
@@ -943,7 +942,7 @@ async def test_get_nodes_by_ids_functionality(
|
|
|
943
942
|
assert len(node.get_content()) > 0
|
|
944
943
|
|
|
945
944
|
# Test with non-existent IDs
|
|
946
|
-
fake_ids = [
|
|
945
|
+
fake_ids = [deterministic_chunk_id("fake_doc", i) for i in range(2)]
|
|
947
946
|
retrieved_fake = await adapter.get_nodes_by_ids(fake_ids)
|
|
948
947
|
assert len(retrieved_fake) == 0
|
|
949
948
|
|
|
@@ -1019,7 +1018,7 @@ async def test_uuid_scheme_retrieval_and_node_properties(
|
|
|
1019
1018
|
# Test the UUID scheme: document_id::chunk_idx
|
|
1020
1019
|
for chunk_idx in range(4):
|
|
1021
1020
|
# Compute expected ID using the same scheme as the adapter
|
|
1022
|
-
expected_id =
|
|
1021
|
+
expected_id = deterministic_chunk_id("doc_001", chunk_idx)
|
|
1023
1022
|
|
|
1024
1023
|
# Retrieve the specific node by ID
|
|
1025
1024
|
retrieved_nodes = await adapter.get_nodes_by_ids([expected_id])
|
|
@@ -1053,7 +1052,7 @@ async def test_uuid_scheme_retrieval_and_node_properties(
|
|
|
1053
1052
|
|
|
1054
1053
|
# Test retrieval of doc_002 chunks
|
|
1055
1054
|
for chunk_idx in range(4):
|
|
1056
|
-
expected_id =
|
|
1055
|
+
expected_id = deterministic_chunk_id("doc_002", chunk_idx)
|
|
1057
1056
|
retrieved_nodes = await adapter.get_nodes_by_ids([expected_id])
|
|
1058
1057
|
assert len(retrieved_nodes) == 1
|
|
1059
1058
|
|
|
@@ -1080,25 +1079,19 @@ async def test_deterministic_chunk_id_consistency(
|
|
|
1080
1079
|
create_rag_config_factory,
|
|
1081
1080
|
):
|
|
1082
1081
|
"""Test that the deterministic chunk ID generation is consistent."""
|
|
1083
|
-
rag_config = create_rag_config_factory(fts_vector_store_config, embedding_config)
|
|
1084
|
-
|
|
1085
|
-
adapter = LanceDBAdapter(
|
|
1086
|
-
rag_config,
|
|
1087
|
-
fts_vector_store_config,
|
|
1088
|
-
)
|
|
1089
1082
|
|
|
1090
1083
|
# Test that the same document_id and chunk_idx always produce the same UUID
|
|
1091
1084
|
doc_id = "test_doc_123"
|
|
1092
1085
|
chunk_idx = 5
|
|
1093
1086
|
|
|
1094
|
-
id1 =
|
|
1095
|
-
id2 =
|
|
1087
|
+
id1 = deterministic_chunk_id(doc_id, chunk_idx)
|
|
1088
|
+
id2 = deterministic_chunk_id(doc_id, chunk_idx)
|
|
1096
1089
|
|
|
1097
1090
|
assert id1 == id2
|
|
1098
1091
|
|
|
1099
1092
|
# Test that different inputs produce different UUIDs
|
|
1100
|
-
id3 =
|
|
1101
|
-
id4 =
|
|
1093
|
+
id3 = deterministic_chunk_id(doc_id, chunk_idx + 1)
|
|
1094
|
+
id4 = deterministic_chunk_id(doc_id + "_different", chunk_idx)
|
|
1102
1095
|
|
|
1103
1096
|
assert id1 != id3
|
|
1104
1097
|
assert id1 != id4
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
from unittest.mock import patch
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from kiln_ai.adapters.vector_store.lancedb_helpers import (
|
|
6
|
+
convert_to_llama_index_node,
|
|
7
|
+
deterministic_chunk_id,
|
|
8
|
+
lancedb_construct_from_config,
|
|
9
|
+
store_type_to_lancedb_query_type,
|
|
10
|
+
)
|
|
11
|
+
from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
|
|
12
|
+
from kiln_ai.utils.uuid import string_to_uuid
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class _FakeLanceDBVectorStore:
|
|
16
|
+
def __init__(self, **kwargs):
|
|
17
|
+
self.kwargs = kwargs
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _base_properties(nprobes: int | None = None) -> dict[str, str | int | float | None]:
|
|
21
|
+
props: dict[str, str | int | float | None] = {
|
|
22
|
+
"similarity_top_k": 5,
|
|
23
|
+
"overfetch_factor": 2,
|
|
24
|
+
"vector_column_name": "vec",
|
|
25
|
+
"text_key": "text",
|
|
26
|
+
"doc_id_key": "doc_id",
|
|
27
|
+
}
|
|
28
|
+
if nprobes is not None:
|
|
29
|
+
props["nprobes"] = nprobes
|
|
30
|
+
return props
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _make_config(
    store_type: VectorStoreType, nprobes: int | None = None
) -> VectorStoreConfig:
    """Build a ``test_store`` VectorStoreConfig for ``store_type`` without saving it."""
    properties = _base_properties(nprobes)
    return VectorStoreConfig(
        name="test_store",
        description=None,
        store_type=store_type,
        properties=properties,
    )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def test_store_type_to_lancedb_query_type_mapping():
    """Each LanceDB store type maps to its LanceDB query_type string."""
    expected_query_types = {
        VectorStoreType.LANCE_DB_FTS: "fts",
        VectorStoreType.LANCE_DB_HYBRID: "hybrid",
        VectorStoreType.LANCE_DB_VECTOR: "vector",
    }
    for store_type, query_type in expected_query_types.items():
        assert store_type_to_lancedb_query_type(store_type) == query_type
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def test_store_type_to_lancedb_query_type_unsupported_raises():
    """A value outside the supported store types must raise."""
    bogus_store_type = "unsupported"
    with pytest.raises(Exception):
        store_type_to_lancedb_query_type(bogus_store_type)  # type: ignore[arg-type]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def test_lancedb_construct_from_config_includes_nprobes():
    """Config properties, uri and extra kwargs reach the store, and nprobes is forwarded when set."""
    with patch(
        "kiln_ai.adapters.vector_store.lancedb_helpers.LanceDBVectorStore",
        new=_FakeLanceDBVectorStore,
    ):
        store = lancedb_construct_from_config(
            vector_store_config=_make_config(VectorStoreType.LANCE_DB_VECTOR, nprobes=7),
            uri="memory://",
            api_key="k",
            region="r",
            table_name="t",
        )

        assert isinstance(store, _FakeLanceDBVectorStore)

        expected_kwargs = {
            "mode": "create",
            "uri": "memory://",
            "query_type": "vector",
            "overfetch_factor": 2,
            "vector_column_name": "vec",
            "text_key": "text",
            "doc_id_key": "doc_id",
            "api_key": "k",
            "region": "r",
            "table_name": "t",
            # extra optional kwarg present when provided
            "nprobes": 7,
        }
        for key, value in expected_kwargs.items():
            assert store.kwargs[key] == value
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def test_lancedb_construct_from_config_omits_nprobes_when_none():
    """With no nprobes in the config, the helper must not pass an nprobes kwarg."""
    with patch(
        "kiln_ai.adapters.vector_store.lancedb_helpers.LanceDBVectorStore",
        new=_FakeLanceDBVectorStore,
    ):
        fts_config = _make_config(VectorStoreType.LANCE_DB_FTS, nprobes=None)
        store = lancedb_construct_from_config(
            vector_store_config=fts_config,
            uri="memory://",
            api_key=None,
            region=None,
            table_name=None,
        )

        assert isinstance(store, _FakeLanceDBVectorStore)
        captured = store.kwargs
        assert captured["query_type"] == "fts"
        assert "nprobes" not in captured
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def test_convert_to_llama_index_node_builds_expected_structure():
    """The node carries id/text/embedding, Kiln metadata, and a SOURCE link to the document."""
    from llama_index.core.schema import NodeRelationship, RelatedNodeInfo

    node_id = "11111111-1111-5111-8111-111111111111"
    node = convert_to_llama_index_node(
        document_id="doc-123",
        chunk_idx=0,
        node_id=node_id,
        text="hello",
        vector=[0.1, 0.2],
    )

    assert node.id_ == node_id
    assert node.text == "hello"
    assert node.embedding == [0.1, 0.2]
    assert (node.metadata["kiln_doc_id"], node.metadata["kiln_chunk_idx"]) == (
        "doc-123",
        0,
    )

    # relationship exists and points to the source document id
    assert NodeRelationship.SOURCE in node.relationships
    source = node.relationships[NodeRelationship.SOURCE]
    assert isinstance(source, RelatedNodeInfo)
    assert source.node_id == "doc-123"
    assert source.node_type == "1"
    assert isinstance(source.metadata, dict)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def test_deterministic_chunk_id_uses_uuid_v5_namespace():
    """The chunk ID equals string_to_uuid("<doc>::<idx>") and is stable across calls."""
    doc_id, idx = "doc-abc", 3
    expected = str(string_to_uuid(f"{doc_id}::{idx}"))

    # call twice to ensure the same value is returned each time
    for _ in range(2):
        assert deterministic_chunk_id(doc_id, idx) == expected
|
|
File without changes
|
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
import time
|
|
4
|
+
import uuid
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
from kiln_ai.adapters.vector_store.lancedb_adapter import lancedb_construct_from_config
|
|
11
|
+
from kiln_ai.adapters.vector_store_loaders.vector_store_loader import VectorStoreLoader
|
|
12
|
+
from kiln_ai.datamodel.chunk import Chunk, ChunkedDocument
|
|
13
|
+
from kiln_ai.datamodel.datamodel_enums import KilnMimeType
|
|
14
|
+
from kiln_ai.datamodel.embedding import ChunkEmbeddings, Embedding
|
|
15
|
+
from kiln_ai.datamodel.extraction import (
|
|
16
|
+
Document,
|
|
17
|
+
Extraction,
|
|
18
|
+
ExtractionSource,
|
|
19
|
+
FileInfo,
|
|
20
|
+
Kind,
|
|
21
|
+
)
|
|
22
|
+
from kiln_ai.datamodel.project import Project
|
|
23
|
+
from kiln_ai.datamodel.rag import RagConfig
|
|
24
|
+
from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class DocWithChunks:
    """Persisted test artifacts for one mock document.

    Holds the Document and the Extraction, ChunkedDocument and ChunkEmbeddings
    that ``mock_chunks_factory`` saves beneath it.
    """

    document: Document
    extraction: Extraction
    chunked_document: ChunkedDocument
    chunked_embeddings: ChunkEmbeddings
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def lorem_ipsum(n: int) -> str:
    """Return ``n`` copies of a fixed lorem-ipsum sentence joined by single spaces."""
    sentence = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
    return " ".join([sentence] * n)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@pytest.fixture
def mock_chunks_factory(mock_attachment_factory):
    """Return a factory that saves a Document -> Extraction -> ChunkedDocument -> ChunkEmbeddings chain.

    ``mock_attachment_factory`` is not defined in this module; presumably it
    comes from a shared conftest (verify).
    """

    def create_doc_with_chunks(
        project: Project,
        rag_config: RagConfig,
        num_chunks: int = 1,
        text: str | None = None,
    ) -> DocWithChunks:
        document = Document(
            id=f"doc_{uuid.uuid4()}",
            name="Test Document",
            description="Test Document",
            original_file=FileInfo(
                filename="test.pdf",
                size=100,
                mime_type="application/pdf",
                attachment=mock_attachment_factory(KilnMimeType.PDF),
            ),
            kind=Kind.DOCUMENT,
            parent=project,
        )
        document.save_to_file()

        extraction = Extraction(
            source=ExtractionSource.PROCESSED,
            extractor_config_id=rag_config.extractor_config_id,
            output=mock_attachment_factory(KilnMimeType.PDF),
            parent=document,
        )
        extraction.save_to_file()

        # every chunk shares the same body, prefixed with its index
        chunk_body = text or lorem_ipsum(10)
        chunked_document = ChunkedDocument(
            chunks=[
                Chunk(
                    content=mock_attachment_factory(
                        KilnMimeType.TXT, text=f"text-{i}: {chunk_body}"
                    )
                )
                for i in range(num_chunks)
            ],
            chunker_config_id=rag_config.chunker_config_id,
            parent=extraction,
        )
        chunked_document.save_to_file()

        # one 5-dim vector per chunk, offset by the chunk index
        chunked_embeddings = ChunkEmbeddings(
            embeddings=[
                Embedding(vector=[i + offset for offset in (0.1, 0.2, 0.3, 0.4, 0.5)])
                for i in range(num_chunks)
            ],
            embedding_config_id=rag_config.embedding_config_id,
            parent=chunked_document,
        )
        chunked_embeddings.save_to_file()

        return DocWithChunks(
            document=document,
            extraction=extraction,
            chunked_document=chunked_document,
            chunked_embeddings=chunked_embeddings,
        )

    return create_doc_with_chunks
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@pytest.fixture
def mock_project(tmp_path):
    """Create and save a Project at ``tmp_path/test_project/project.kiln``."""
    project_file = tmp_path / "test_project" / "project.kiln"
    project = Project(name="Test Project", path=project_file)
    project.save_to_file()
    return project
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@pytest.fixture
def rag_config_factory(mock_project):
    """Return a factory that saves a RagConfig under ``mock_project`` for a given vector store config ID."""

    def create_rag_config(vector_store_config_id: str) -> RagConfig:
        config = RagConfig(
            name="Test Rag Config",
            parent=mock_project,
            vector_store_config_id=vector_store_config_id,
            tool_name="test_tool",
            tool_description="test_description",
            extractor_config_id="test_extractor",
            chunker_config_id="test_chunker",
            embedding_config_id="test_embedding",
        )
        config.save_to_file()
        return config

    return create_rag_config
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@pytest.fixture
|
|
134
|
+
def vector_store_config_factory(mock_project):
|
|
135
|
+
def fn(vector_store_type: VectorStoreType) -> VectorStoreConfig:
|
|
136
|
+
match vector_store_type:
|
|
137
|
+
case VectorStoreType.LANCE_DB_FTS:
|
|
138
|
+
vector_store_config = VectorStoreConfig(
|
|
139
|
+
name="Test Vector Store Config FTS",
|
|
140
|
+
parent=mock_project,
|
|
141
|
+
store_type=VectorStoreType.LANCE_DB_FTS,
|
|
142
|
+
properties={
|
|
143
|
+
"similarity_top_k": 10,
|
|
144
|
+
"overfetch_factor": 20,
|
|
145
|
+
"vector_column_name": "vector",
|
|
146
|
+
"text_key": "text",
|
|
147
|
+
"doc_id_key": "doc_id",
|
|
148
|
+
},
|
|
149
|
+
)
|
|
150
|
+
vector_store_config.save_to_file()
|
|
151
|
+
return vector_store_config
|
|
152
|
+
case VectorStoreType.LANCE_DB_VECTOR:
|
|
153
|
+
vector_store_config = VectorStoreConfig(
|
|
154
|
+
name="Test Vector Store Config KNN",
|
|
155
|
+
parent=mock_project,
|
|
156
|
+
store_type=VectorStoreType.LANCE_DB_VECTOR,
|
|
157
|
+
properties={
|
|
158
|
+
"similarity_top_k": 10,
|
|
159
|
+
"overfetch_factor": 20,
|
|
160
|
+
"vector_column_name": "vector",
|
|
161
|
+
"text_key": "text",
|
|
162
|
+
"doc_id_key": "doc_id",
|
|
163
|
+
"nprobes": 10,
|
|
164
|
+
},
|
|
165
|
+
)
|
|
166
|
+
vector_store_config.save_to_file()
|
|
167
|
+
return vector_store_config
|
|
168
|
+
case VectorStoreType.LANCE_DB_HYBRID:
|
|
169
|
+
vector_store_config = VectorStoreConfig(
|
|
170
|
+
name="Test Vector Store Config Hybrid",
|
|
171
|
+
parent=mock_project,
|
|
172
|
+
store_type=VectorStoreType.LANCE_DB_HYBRID,
|
|
173
|
+
properties={
|
|
174
|
+
"similarity_top_k": 10,
|
|
175
|
+
"nprobes": 10,
|
|
176
|
+
"overfetch_factor": 20,
|
|
177
|
+
"vector_column_name": "vector",
|
|
178
|
+
"text_key": "text",
|
|
179
|
+
"doc_id_key": "doc_id",
|
|
180
|
+
},
|
|
181
|
+
)
|
|
182
|
+
vector_store_config.save_to_file()
|
|
183
|
+
return vector_store_config
|
|
184
|
+
case _:
|
|
185
|
+
raise ValueError(f"Invalid vector store type: {vector_store_type}")
|
|
186
|
+
|
|
187
|
+
return fn
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class LanceDBCloudEnvVars(BaseModel):
    """LanceDB Cloud connection settings, filled from the environment by ``lancedb_cloud_env_vars``."""

    # NOTE(review): each default is the *name* of the environment variable, not
    # its value. lancedb_cloud_env_vars always passes explicit values, so the
    # defaults are not used there; constructing this model without arguments
    # would yield placeholder strings.
    uri: str = Field(default="LANCE_DB_URI")
    api_key: str = Field(default="LANCE_DB_API_KEY")
    region: str = Field(default="LANCE_DB_REGION")
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def lancedb_cloud_env_vars() -> LanceDBCloudEnvVars:
    """Read LanceDB Cloud connection settings from the environment.

    Checks LANCE_DB_URI, LANCE_DB_API_KEY and LANCE_DB_REGION in that order.
    An ``assert`` fails on the first one that is unset.
    """
    required_vars = (
        ("uri", "LANCE_DB_URI", "LANCE_DB_URI is not set - test requires lancedb cloud"),
        (
            "api_key",
            "LANCE_DB_API_KEY",
            "LANCE_DB_API_KEY is not set - test requires lancedb cloud",
        ),
        (
            "region",
            "LANCE_DB_REGION",
            "LANCE_DB_REGION is not set - test requires lancedb cloud",
        ),
    )
    settings: dict[str, str] = {}
    for field_name, env_var, missing_message in required_vars:
        value = os.getenv(env_var)
        assert value is not None, missing_message
        settings[field_name] = value
    return LanceDBCloudEnvVars(**settings)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@pytest.mark.parametrize(
    "vector_store_type",
    [
        VectorStoreType.LANCE_DB_FTS,
        VectorStoreType.LANCE_DB_VECTOR,
        VectorStoreType.LANCE_DB_HYBRID,
    ],
)
@pytest.mark.paid
async def test_lancedb_loader_insert_nodes_lancedb_cloud(
    mock_project,
    mock_chunks_factory,
    rag_config_factory,
    vector_store_type,
    vector_store_config_factory,
):
    """Every chunk yielded by VectorStoreLoader becomes exactly one LanceDB Cloud row.

    Requires the LANCE_DB_* environment variables read by lancedb_cloud_env_vars.
    """
    cloud = lancedb_cloud_env_vars()

    vector_store_config = vector_store_config_factory(vector_store_type)
    rag_config = rag_config_factory(vector_store_config.id)

    # the store type plus a timestamp gives each run its own table name
    # NOTE(review): the table is never dropped afterwards, so cloud tables accumulate
    table_name = (
        f"test_lancedb_loader_insert_nodes_{vector_store_type.value}_{time.time()}"
    )
    lancedb_store = lancedb_construct_from_config(
        vector_store_config=vector_store_config,
        uri=cloud.uri,
        api_key=cloud.api_key,
        region=cloud.region,
        table_name=table_name,
    )

    loader = VectorStoreLoader(project=mock_project, rag_config=rag_config)

    # persist a random number of chunks for each mock document, tracking the total
    expected_rows = 0
    for doc_index in range(10):
        chunk_count = random.randint(1, 20)
        mock_chunks_factory(
            mock_project,
            rag_config,
            num_chunks=chunk_count,
            text=f"Document {doc_index}",
        )
        expected_rows += chunk_count

    assert expected_rows > 0, "No mock nodes were created"

    async for batch in loader.iter_llama_index_nodes(batch_size=100):
        await lancedb_store.async_add(batch)

    table = lancedb_store.table
    assert table is not None
    row_count = table.count_rows()
    assert row_count == expected_rows, (
        f"Expected {expected_rows} rows (one for each node), got {row_count} instead"
    )
|