kiln-ai 0.22.0__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/model_adapters/litellm_adapter.py +6 -2
- kiln_ai/adapters/vector_store/lancedb_adapter.py +24 -70
- kiln_ai/adapters/vector_store/lancedb_helpers.py +101 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +9 -16
- kiln_ai/adapters/vector_store/test_lancedb_helpers.py +142 -0
- kiln_ai/adapters/vector_store_loaders/__init__.py +0 -0
- kiln_ai/adapters/vector_store_loaders/test_lancedb_loader.py +282 -0
- kiln_ai/adapters/vector_store_loaders/test_vector_store_loader.py +544 -0
- kiln_ai/adapters/vector_store_loaders/vector_store_loader.py +91 -0
- kiln_ai/datamodel/tool_id.py +13 -0
- kiln_ai/tools/base_tool.py +18 -3
- kiln_ai/tools/kiln_task_tool.py +6 -2
- kiln_ai/tools/mcp_server_tool.py +6 -4
- kiln_ai/tools/rag_tools.py +7 -3
- {kiln_ai-0.22.0.dist-info → kiln_ai-0.22.1.dist-info}/METADATA +77 -1
- {kiln_ai-0.22.0.dist-info → kiln_ai-0.22.1.dist-info}/RECORD +18 -12
- {kiln_ai-0.22.0.dist-info → kiln_ai-0.22.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.22.0.dist-info → kiln_ai-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
import time
|
|
4
|
+
import uuid
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
from kiln_ai.adapters.vector_store.lancedb_adapter import lancedb_construct_from_config
|
|
11
|
+
from kiln_ai.adapters.vector_store_loaders.vector_store_loader import VectorStoreLoader
|
|
12
|
+
from kiln_ai.datamodel.chunk import Chunk, ChunkedDocument
|
|
13
|
+
from kiln_ai.datamodel.datamodel_enums import KilnMimeType
|
|
14
|
+
from kiln_ai.datamodel.embedding import ChunkEmbeddings, Embedding
|
|
15
|
+
from kiln_ai.datamodel.extraction import (
|
|
16
|
+
Document,
|
|
17
|
+
Extraction,
|
|
18
|
+
ExtractionSource,
|
|
19
|
+
FileInfo,
|
|
20
|
+
Kind,
|
|
21
|
+
)
|
|
22
|
+
from kiln_ai.datamodel.project import Project
|
|
23
|
+
from kiln_ai.datamodel.rag import RagConfig
|
|
24
|
+
from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class DocWithChunks:
    """Bundle of the related records produced for a single mock document.

    Groups a document with its extraction, its chunked document and the
    embeddings for those chunks so a test can hold on to all of them at once.
    """

    # The document record itself.
    document: Document
    # The extraction associated with the document.
    extraction: Extraction
    # The chunked form of the extraction.
    chunked_document: ChunkedDocument
    # The embeddings associated with the chunked document.
    chunked_embeddings: ChunkEmbeddings
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def lorem_ipsum(n: int) -> str:
    """Return the placeholder sentence repeated ``n`` times, joined by spaces.

    Args:
        n: Number of repetitions; zero or a negative value yields "".

    Returns:
        The space-separated filler text.
    """
    sentence = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
    return " ".join([sentence] * n)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@pytest.fixture
def mock_chunks_factory(mock_attachment_factory):
    """Provide a factory that persists one document and its chunk hierarchy.

    The returned callable creates and saves, in this order, a Document, an
    Extraction, a ChunkedDocument and a ChunkEmbeddings record, each one
    parented to the record created before it, and hands them back as a
    DocWithChunks.
    """

    def build_doc_with_chunks(
        project: Project,
        rag_config: RagConfig,
        num_chunks: int = 1,
        text: str | None = None,
    ) -> DocWithChunks:
        """Create and save a mock document with ``num_chunks`` chunks.

        Args:
            project: Parent of the new document.
            rag_config: Supplies the extractor, chunker and embedding config ids.
            num_chunks: How many chunks (and embeddings) to generate.
            text: Chunk text; filler text is used when this is None or empty.

        Returns:
            The four saved records wrapped in a DocWithChunks.
        """
        document_id = f"doc_{uuid.uuid4()}"
        source_file = FileInfo(
            filename="test.pdf",
            size=100,
            mime_type="application/pdf",
            attachment=mock_attachment_factory(KilnMimeType.PDF),
        )
        document = Document(
            id=document_id,
            name="Test Document",
            description="Test Document",
            original_file=source_file,
            kind=Kind.DOCUMENT,
            parent=project,
        )
        document.save_to_file()

        extraction_output = mock_attachment_factory(KilnMimeType.PDF)
        extraction = Extraction(
            source=ExtractionSource.PROCESSED,
            extractor_config_id=rag_config.extractor_config_id,
            output=extraction_output,
            parent=document,
        )
        extraction.save_to_file()

        # Every chunk shares the same body text, prefixed with its index.
        chunk_body = text or lorem_ipsum(10)
        chunk_list = []
        for index in range(num_chunks):
            attachment = mock_attachment_factory(
                KilnMimeType.TXT, text=f"text-{index}: {chunk_body}"
            )
            chunk_list.append(Chunk(content=attachment))

        chunked = ChunkedDocument(
            chunks=chunk_list,
            chunker_config_id=rag_config.chunker_config_id,
            parent=extraction,
        )
        chunked.save_to_file()

        # One 5-dimensional vector per chunk, derived from the chunk index.
        vectors = [
            Embedding(vector=[index + step for step in (0.1, 0.2, 0.3, 0.4, 0.5)])
            for index in range(num_chunks)
        ]
        embeddings = ChunkEmbeddings(
            embeddings=vectors,
            embedding_config_id=rag_config.embedding_config_id,
            parent=chunked,
        )
        embeddings.save_to_file()

        return DocWithChunks(
            document=document,
            extraction=extraction,
            chunked_document=chunked,
            chunked_embeddings=embeddings,
        )

    return build_doc_with_chunks
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@pytest.fixture
def mock_project(tmp_path):
    """Create and save a project stored under the test's temporary directory.

    Returns:
        The saved Project, located at ``tmp_path/test_project/project.kiln``.
    """
    project_file = tmp_path / "test_project" / "project.kiln"
    created = Project(name="Test Project", path=project_file)
    created.save_to_file()
    return created
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@pytest.fixture
def rag_config_factory(mock_project):
    """Provide a factory that saves a RagConfig parented to ``mock_project``.

    The returned callable takes the id of a vector store config; every other
    field of the RagConfig is filled with fixed test values.
    """

    def make_rag_config(vector_store_config_id: str) -> RagConfig:
        """Create, save and return a RagConfig for the given vector store id."""
        config = RagConfig(
            name="Test Rag Config",
            parent=mock_project,
            vector_store_config_id=vector_store_config_id,
            # Fixed tool metadata and placeholder config ids.
            tool_name="test_tool",
            tool_description="test_description",
            extractor_config_id="test_extractor",
            chunker_config_id="test_chunker",
            embedding_config_id="test_embedding",
        )
        config.save_to_file()
        return config

    return make_rag_config
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@pytest.fixture
def vector_store_config_factory(mock_project):
    """Provide a factory that saves a VectorStoreConfig for a LanceDB store type.

    The returned callable accepts a VectorStoreType and supports the FTS,
    vector and hybrid LanceDB variants; any other value raises ValueError.
    """

    def _persist(
        name: str, store_type: VectorStoreType, properties: dict
    ) -> VectorStoreConfig:
        """Build the config under ``mock_project``, save it and return it."""
        config = VectorStoreConfig(
            name=name,
            parent=mock_project,
            store_type=store_type,
            properties=properties,
        )
        config.save_to_file()
        return config

    def make_vector_store_config(
        vector_store_type: VectorStoreType,
    ) -> VectorStoreConfig:
        """Create and save the config matching ``vector_store_type``.

        Raises:
            ValueError: If the store type is not one of the three LanceDB types.
        """
        # Properties common to every store type, after ``similarity_top_k``.
        shared = {
            "overfetch_factor": 20,
            "vector_column_name": "vector",
            "text_key": "text",
            "doc_id_key": "doc_id",
        }

        if vector_store_type == VectorStoreType.LANCE_DB_FTS:
            return _persist(
                "Test Vector Store Config FTS",
                VectorStoreType.LANCE_DB_FTS,
                {"similarity_top_k": 10, **shared},
            )

        if vector_store_type == VectorStoreType.LANCE_DB_VECTOR:
            return _persist(
                "Test Vector Store Config KNN",
                VectorStoreType.LANCE_DB_VECTOR,
                {"similarity_top_k": 10, **shared, "nprobes": 10},
            )

        if vector_store_type == VectorStoreType.LANCE_DB_HYBRID:
            return _persist(
                "Test Vector Store Config Hybrid",
                VectorStoreType.LANCE_DB_HYBRID,
                {"similarity_top_k": 10, "nprobes": 10, **shared},
            )

        raise ValueError(f"Invalid vector store type: {vector_store_type}")

    return make_vector_store_config
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class LanceDBCloudEnvVars(BaseModel):
    """Connection settings for LanceDB Cloud: URI, API key and region.

    Each field's default is the name of the environment variable that
    ``lancedb_cloud_env_vars`` reads to populate it.
    """

    uri: str = Field(default="LANCE_DB_URI")
    api_key: str = Field(default="LANCE_DB_API_KEY")
    region: str = Field(default="LANCE_DB_REGION")
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def lancedb_cloud_env_vars() -> LanceDBCloudEnvVars:
    """Read the LanceDB Cloud settings from the environment.

    Reads LANCE_DB_URI, LANCE_DB_API_KEY and LANCE_DB_REGION, in that order.

    Returns:
        A LanceDBCloudEnvVars populated from the three variables.

    Raises:
        AssertionError: If any of the variables is not set; the message names
            the first missing variable.
    """

    def _required(variable: str) -> str:
        """Return the variable's value, failing the test when it is unset."""
        value = os.getenv(variable)
        assert value is not None, (
            f"{variable} is not set - test requires lancedb cloud"
        )
        return value

    # Keyword arguments are evaluated left to right, so the variables are
    # checked in the order uri, api_key, region.
    return LanceDBCloudEnvVars(
        uri=_required("LANCE_DB_URI"),
        api_key=_required("LANCE_DB_API_KEY"),
        region=_required("LANCE_DB_REGION"),
    )
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@pytest.mark.parametrize(
    "vector_store_type",
    [
        VectorStoreType.LANCE_DB_FTS,
        VectorStoreType.LANCE_DB_VECTOR,
        VectorStoreType.LANCE_DB_HYBRID,
    ],
)
@pytest.mark.paid
async def test_lancedb_loader_insert_nodes_lancedb_cloud(
    mock_project,
    mock_chunks_factory,
    rag_config_factory,
    vector_store_type,
    vector_store_config_factory,
):
    """Insert loader-produced nodes into LanceDB Cloud and verify the row count.

    Persists ten mock documents with a random number of chunks each, streams
    the nodes yielded by VectorStoreLoader into a store built from the
    environment's LanceDB Cloud settings, then checks that the table holds one
    row per chunk.
    """
    cloud = lancedb_cloud_env_vars()

    store_config = vector_store_config_factory(vector_store_type)
    rag_config = rag_config_factory(store_config.id)

    # Build the store; the table name is suffixed with the current timestamp.
    started_at = time.time()
    store = lancedb_construct_from_config(
        vector_store_config=store_config,
        uri=cloud.uri,
        api_key=cloud.api_key,
        region=cloud.region,
        table_name=(
            f"test_lancedb_loader_insert_nodes_{vector_store_type.value}_{started_at}"
        ),
    )

    loader = VectorStoreLoader(project=mock_project, rag_config=rag_config)

    # Persist the mock documents and keep a running total of their chunks.
    expected_rows = 0
    for doc_index in range(10):
        chunk_total = random.randint(1, 20)
        mock_chunks_factory(
            mock_project,
            rag_config,
            num_chunks=chunk_total,
            text=f"Document {doc_index}",
        )
        expected_rows += chunk_total

    assert expected_rows > 0, "No mock nodes were created"

    # Feed the store batch by batch.
    async for node_batch in loader.iter_llama_index_nodes(batch_size=100):
        await store.async_add(node_batch)

    # Every node should have become exactly one row.
    table = store.table
    assert table is not None
    actual_rows = table.count_rows()
    assert actual_rows == expected_rows, (
        f"Expected {expected_rows} rows (one for each node), got {actual_rows} instead"
    )
|