kiln-ai 0.22.0__py3-none-any.whl → 0.22.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

@@ -0,0 +1,282 @@
1
+ import os
2
+ import random
3
+ import time
4
+ import uuid
5
+ from dataclasses import dataclass
6
+
7
+ import pytest
8
+ from pydantic import BaseModel, Field
9
+
10
+ from kiln_ai.adapters.vector_store.lancedb_adapter import lancedb_construct_from_config
11
+ from kiln_ai.adapters.vector_store_loaders.vector_store_loader import VectorStoreLoader
12
+ from kiln_ai.datamodel.chunk import Chunk, ChunkedDocument
13
+ from kiln_ai.datamodel.datamodel_enums import KilnMimeType
14
+ from kiln_ai.datamodel.embedding import ChunkEmbeddings, Embedding
15
+ from kiln_ai.datamodel.extraction import (
16
+ Document,
17
+ Extraction,
18
+ ExtractionSource,
19
+ FileInfo,
20
+ Kind,
21
+ )
22
+ from kiln_ai.datamodel.project import Project
23
+ from kiln_ai.datamodel.rag import RagConfig
24
+ from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
25
+
26
+
27
@dataclass
class DocWithChunks:
    """Bundle of the records created for one mock document.

    Groups a document with the extraction, chunked document and chunk
    embeddings that were created for it, so a test can reach every level
    of the hierarchy from a single return value.
    """

    # The top-level document record.
    document: Document
    # The extraction created with ``document`` as its parent.
    extraction: Extraction
    # The chunked document created with ``extraction`` as its parent.
    chunked_document: ChunkedDocument
    # The embeddings created with ``chunked_document`` as their parent.
    chunked_embeddings: ChunkEmbeddings
33
+
34
+
35
def lorem_ipsum(n: int) -> str:
    """Return the stock lorem-ipsum sentence repeated ``n`` times.

    The repetitions are joined with single spaces. A non-positive ``n``
    produces an empty string.
    """
    sentence = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
    return " ".join([sentence] * n)
39
+
40
+
41
@pytest.fixture
def mock_chunks_factory(mock_attachment_factory):
    """Fixture returning a builder for a fully persisted mock document tree.

    The returned callable creates, in order, a ``Document``, an
    ``Extraction``, a ``ChunkedDocument`` and a ``ChunkEmbeddings`` record,
    each parented to the previous one, and calls ``save_to_file()`` on each
    before creating the next.
    """

    def build_doc_with_chunks(
        project: Project,
        rag_config: RagConfig,
        num_chunks: int = 1,
        text: str | None = None,
    ) -> DocWithChunks:
        """Create and save one document with ``num_chunks`` chunks.

        Args:
            project: Parent of the created document.
            rag_config: Supplies the extractor, chunker and embedding
                config ids stamped on the created records.
            num_chunks: Number of chunks (and embeddings) to create.
            text: Chunk body text; when falsy, ten lorem-ipsum sentences
                are used instead.

        Returns:
            The four created records wrapped in a ``DocWithChunks``.
        """
        # The id is generated before the attachment so side effects happen in
        # the same order as building everything inline in the constructor.
        document_id = f"doc_{uuid.uuid4()}"
        source_file = FileInfo(
            filename="test.pdf",
            size=100,
            mime_type="application/pdf",
            attachment=mock_attachment_factory(KilnMimeType.PDF),
        )
        doc = Document(
            id=document_id,
            name="Test Document",
            description="Test Document",
            original_file=source_file,
            kind=Kind.DOCUMENT,
            parent=project,
        )
        doc.save_to_file()

        extraction = Extraction(
            source=ExtractionSource.PROCESSED,
            extractor_config_id=rag_config.extractor_config_id,
            output=mock_attachment_factory(KilnMimeType.PDF),
            parent=doc,
        )
        extraction.save_to_file()

        # One text attachment per chunk, prefixed with its index.
        chunks = []
        for i in range(num_chunks):
            chunk_content = mock_attachment_factory(
                KilnMimeType.TXT, text=f"text-{i}: {text or lorem_ipsum(10)}"
            )
            chunks.append(Chunk(content=chunk_content))

        chunked_document = ChunkedDocument(
            chunks=chunks,
            chunker_config_id=rag_config.chunker_config_id,
            parent=extraction,
        )
        chunked_document.save_to_file()

        # Five-component vector per chunk, derived from the chunk index.
        embeddings = []
        for i in range(num_chunks):
            vector = [i + offset for offset in (0.1, 0.2, 0.3, 0.4, 0.5)]
            embeddings.append(Embedding(vector=vector))

        chunked_embeddings = ChunkEmbeddings(
            embeddings=embeddings,
            embedding_config_id=rag_config.embedding_config_id,
            parent=chunked_document,
        )
        chunked_embeddings.save_to_file()

        return DocWithChunks(
            document=doc,
            extraction=extraction,
            chunked_document=chunked_document,
            chunked_embeddings=chunked_embeddings,
        )

    return build_doc_with_chunks
103
+
104
+
105
@pytest.fixture
def mock_project(tmp_path):
    """Fixture providing a saved ``Project`` located under pytest's ``tmp_path``."""
    project_file = tmp_path / "test_project" / "project.kiln"
    project = Project(name="Test Project", path=project_file)
    project.save_to_file()
    return project
112
+
113
+
114
@pytest.fixture
def rag_config_factory(mock_project):
    """Fixture returning a builder for saved ``RagConfig`` records.

    Each created config is parented to ``mock_project`` and uses fixed
    placeholder ids for the extractor, chunker and embedding configs.
    """

    def build_rag_config(vector_store_config_id: str) -> RagConfig:
        """Create and save a ``RagConfig`` pointing at the given vector store config."""
        config = RagConfig(
            parent=mock_project,
            name="Test Rag Config",
            tool_name="test_tool",
            tool_description="test_description",
            vector_store_config_id=vector_store_config_id,
            extractor_config_id="test_extractor",
            chunker_config_id="test_chunker",
            embedding_config_id="test_embedding",
        )
        config.save_to_file()
        return config

    return build_rag_config
131
+
132
+
133
@pytest.fixture
def vector_store_config_factory(mock_project):
    """Fixture returning a builder for saved ``VectorStoreConfig`` records.

    The builder supports the FTS, vector (KNN) and hybrid LanceDB store
    types; any other value raises ``ValueError``.
    """

    def build_vector_store_config(
        vector_store_type: VectorStoreType,
    ) -> VectorStoreConfig:
        """Create and save the config matching ``vector_store_type``.

        Raises:
            ValueError: If ``vector_store_type`` is not one of the three
                supported LanceDB store types.
        """
        # (store type, config name, properties) for every supported type.
        # The table is rebuilt on each call so every config gets its own
        # properties dict; key order is kept distinct per store type.
        specs = (
            (
                VectorStoreType.LANCE_DB_FTS,
                "Test Vector Store Config FTS",
                {
                    "similarity_top_k": 10,
                    "overfetch_factor": 20,
                    "vector_column_name": "vector",
                    "text_key": "text",
                    "doc_id_key": "doc_id",
                },
            ),
            (
                VectorStoreType.LANCE_DB_VECTOR,
                "Test Vector Store Config KNN",
                {
                    "similarity_top_k": 10,
                    "overfetch_factor": 20,
                    "vector_column_name": "vector",
                    "text_key": "text",
                    "doc_id_key": "doc_id",
                    "nprobes": 10,
                },
            ),
            (
                VectorStoreType.LANCE_DB_HYBRID,
                "Test Vector Store Config Hybrid",
                {
                    "similarity_top_k": 10,
                    "nprobes": 10,
                    "overfetch_factor": 20,
                    "vector_column_name": "vector",
                    "text_key": "text",
                    "doc_id_key": "doc_id",
                },
            ),
        )

        for store_type, config_name, properties in specs:
            if vector_store_type == store_type:
                config = VectorStoreConfig(
                    name=config_name,
                    parent=mock_project,
                    store_type=store_type,
                    properties=properties,
                )
                config.save_to_file()
                return config

        raise ValueError(f"Invalid vector store type: {vector_store_type}")

    return build_vector_store_config
188
+
189
+
190
class LanceDBCloudEnvVars(BaseModel):
    """Connection settings for LanceDB Cloud.

    NOTE(review): each field's default value is the *name* of an environment
    variable, not a value read from the environment. This looks like it was
    meant as a label rather than a usable default; confirm the intent.
    """

    # Default is the literal string "LANCE_DB_URI".
    uri: str = Field(default="LANCE_DB_URI")
    # Default is the literal string "LANCE_DB_API_KEY".
    api_key: str = Field(default="LANCE_DB_API_KEY")
    # Default is the literal string "LANCE_DB_REGION".
    region: str = Field(default="LANCE_DB_REGION")
194
+
195
+
196
def lancedb_cloud_env_vars() -> LanceDBCloudEnvVars:
    """Read the LanceDB Cloud settings from the environment.

    Reads ``LANCE_DB_URI``, ``LANCE_DB_API_KEY`` and ``LANCE_DB_REGION`` in
    that order and fails with an ``AssertionError`` on the first one that is
    unset.

    Returns:
        A ``LanceDBCloudEnvVars`` populated from the three variables.
    """

    def read_required(env_name: str, failure_message: str) -> str:
        # An unset variable fails the calling test right away with a message
        # naming the missing variable.
        value = os.getenv(env_name)
        assert value is not None, failure_message
        return value

    # Keyword arguments evaluate left to right, so the variables are checked
    # in the order uri, api_key, region.
    return LanceDBCloudEnvVars(
        uri=read_required(
            "LANCE_DB_URI",
            "LANCE_DB_URI is not set - test requires lancedb cloud",
        ),
        api_key=read_required(
            "LANCE_DB_API_KEY",
            "LANCE_DB_API_KEY is not set - test requires lancedb cloud",
        ),
        region=read_required(
            "LANCE_DB_REGION",
            "LANCE_DB_REGION is not set - test requires lancedb cloud",
        ),
    )
216
+
217
+
218
@pytest.mark.parametrize(
    "vector_store_type",
    [
        VectorStoreType.LANCE_DB_FTS,
        VectorStoreType.LANCE_DB_VECTOR,
        VectorStoreType.LANCE_DB_HYBRID,
    ],
)
@pytest.mark.paid
async def test_lancedb_loader_insert_nodes_lancedb_cloud(
    mock_project,
    mock_chunks_factory,
    rag_config_factory,
    vector_store_type,
    vector_store_config_factory,
):
    """Insert mock nodes into a LanceDB Cloud table and check the row count.

    Runs once per LanceDB store type. Requires the ``LANCE_DB_URI``,
    ``LANCE_DB_API_KEY`` and ``LANCE_DB_REGION`` environment variables.
    """
    # --- configuration -----------------------------------------------------
    cloud = lancedb_cloud_env_vars()
    store_config = vector_store_config_factory(vector_store_type)
    rag_config = rag_config_factory(store_config.id)

    # --- store and loader --------------------------------------------------
    # The timestamp suffix gives each run its own table name.
    now = time.time()
    table_name = f"test_lancedb_loader_insert_nodes_{vector_store_type.value}_{now}"
    lancedb_store = lancedb_construct_from_config(
        vector_store_config=store_config,
        uri=cloud.uri,
        api_key=cloud.api_key,
        region=cloud.region,
        table_name=table_name,
    )
    loader = VectorStoreLoader(project=mock_project, rag_config=rag_config)

    # --- mock data ---------------------------------------------------------
    # Ten documents, each persisted with a random number (1-20) of chunks;
    # the running total is the number of rows expected in the table.
    node_count = 0
    for i in range(10):
        chunks_for_doc = random.randint(1, 20)
        mock_chunks_factory(
            mock_project,
            rag_config,
            num_chunks=chunks_for_doc,
            text=f"Document {i}",
        )
        node_count += chunks_for_doc

    assert node_count > 0, "No mock nodes were created"

    # --- insert ------------------------------------------------------------
    async for node_batch in loader.iter_llama_index_nodes(batch_size=100):
        await lancedb_store.async_add(node_batch)

    # --- verify ------------------------------------------------------------
    table = lancedb_store.table
    assert table is not None
    row_count = table.count_rows()
    assert row_count == node_count, (
        f"Expected {node_count} rows (one for each node), got {row_count} instead"
    )