kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,2376 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
from kiln_ai.adapters.chunkers.base_chunker import BaseChunker, ChunkingResult
|
|
7
|
+
from kiln_ai.adapters.embedding.base_embedding_adapter import (
|
|
8
|
+
BaseEmbeddingAdapter,
|
|
9
|
+
EmbeddingResult,
|
|
10
|
+
)
|
|
11
|
+
from kiln_ai.adapters.extractors.base_extractor import BaseExtractor, ExtractionOutput
|
|
12
|
+
from kiln_ai.adapters.rag.progress import LogMessage, RagProgress
|
|
13
|
+
from kiln_ai.adapters.rag.rag_runners import (
|
|
14
|
+
ChunkerJob,
|
|
15
|
+
EmbeddingJob,
|
|
16
|
+
ExtractorJob,
|
|
17
|
+
GenericErrorCollector,
|
|
18
|
+
RagChunkingStepRunner,
|
|
19
|
+
RagEmbeddingStepRunner,
|
|
20
|
+
RagExtractionStepRunner,
|
|
21
|
+
RagIndexingStepRunner,
|
|
22
|
+
RagStepRunnerProgress,
|
|
23
|
+
RagWorkflowRunner,
|
|
24
|
+
RagWorkflowRunnerConfiguration,
|
|
25
|
+
RagWorkflowStepNames,
|
|
26
|
+
execute_chunker_job,
|
|
27
|
+
execute_embedding_job,
|
|
28
|
+
execute_extractor_job,
|
|
29
|
+
)
|
|
30
|
+
from kiln_ai.datamodel.chunk import ChunkedDocument, ChunkerConfig, ChunkerType
|
|
31
|
+
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
|
|
32
|
+
from kiln_ai.datamodel.embedding import EmbeddingConfig
|
|
33
|
+
from kiln_ai.datamodel.extraction import (
|
|
34
|
+
Document,
|
|
35
|
+
Extraction,
|
|
36
|
+
ExtractorConfig,
|
|
37
|
+
ExtractorType,
|
|
38
|
+
OutputFormat,
|
|
39
|
+
)
|
|
40
|
+
from kiln_ai.datamodel.project import Project
|
|
41
|
+
from kiln_ai.datamodel.rag import RagConfig
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Test fixtures
|
|
45
|
+
@pytest.fixture
def mock_project():
    """Provide a ``MagicMock`` restricted to the ``Project`` interface."""
    return MagicMock(spec=Project)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@pytest.fixture
def mock_document():
    """Provide a ``Document``-spec'd mock with a path and a plain-text original file.

    The attachment's ``resolve_path`` is stubbed to return ``"test_file_path"``.
    """
    attachment = MagicMock()
    attachment.resolve_path.return_value = "test_file_path"

    original_file = MagicMock()
    original_file.attachment = attachment
    original_file.mime_type = "text/plain"

    document = MagicMock(spec=Document)
    document.path = Path("test_doc.txt")
    document.original_file = original_file
    return document
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.fixture
def mock_extractor_config():
    """Provide an unconstrained mock extractor config with a fixed id and type."""
    extractor_config = MagicMock()
    extractor_config.extractor_type = "test_extractor"
    extractor_config.id = "extractor-123"
    return extractor_config
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@pytest.fixture
def mock_chunker_config():
    """Provide a ``ChunkerConfig``-spec'd mock with a fixed id and chunker type."""
    chunker_config = MagicMock(spec=ChunkerConfig)
    chunker_config.chunker_type = "test_chunker"
    chunker_config.id = "chunker-123"
    return chunker_config
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@pytest.fixture
def mock_embedding_config():
    """Provide an ``EmbeddingConfig``-spec'd mock whose id is ``"embedding-123"``."""
    embedding_config = MagicMock(spec=EmbeddingConfig)
    embedding_config.id = "embedding-123"
    return embedding_config
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@pytest.fixture
def real_extractor_config(mock_project):
    """Build a real ``ExtractorConfig`` (LiteLLM, markdown output) parented to the mock project."""
    # One prompt per supported input kind.
    prompts = {
        "prompt_document": "Transcribe the document.",
        "prompt_audio": "Transcribe the audio.",
        "prompt_video": "Transcribe the video.",
        "prompt_image": "Describe the image.",
    }
    return ExtractorConfig(
        name="test-extractor",
        model_provider_name="test",
        model_name="test-model",
        extractor_type=ExtractorType.LITELLM,
        output_format=OutputFormat.MARKDOWN,
        properties=prompts,
        parent=mock_project,
    )
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@pytest.fixture
def real_chunker_config(mock_project):
    """Build a real fixed-window ``ChunkerConfig`` parented to the mock project."""
    window_settings = {"chunk_size": 500, "chunk_overlap": 50}
    return ChunkerConfig(
        name="test-chunker",
        chunker_type=ChunkerType.FIXED_WINDOW,
        properties=window_settings,
        parent=mock_project,
    )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@pytest.fixture
def real_embedding_config(mock_project):
    """Build a real OpenAI ``EmbeddingConfig`` (1536 dimensions) parented to the mock project."""
    embedding_settings = {"dimensions": 1536}
    return EmbeddingConfig(
        name="test-embedding",
        model_provider_name=ModelProviderName.openai,
        model_name="text-embedding-3-small",
        properties=embedding_settings,
        parent=mock_project,
    )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@pytest.fixture
def real_rag_config(mock_project):
    """Build a real ``RagConfig`` that references the fixed test config ids."""
    # Ids of the sub-configurations this RAG config points at.
    linked_config_ids = {
        "extractor_config_id": "extractor-123",
        "chunker_config_id": "chunker-123",
        "embedding_config_id": "embedding-123",
        "vector_store_config_id": "vector-store-123",
    }
    return RagConfig(
        name="test-rag",
        tool_name="test_rag_tool",
        tool_description="A test RAG tool for searching documents",
        parent=mock_project,
        **linked_config_ids,
    )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@pytest.fixture
def mock_extraction():
    """Provide an ``Extraction``-spec'd mock whose async ``output_content`` yields ``"test content"``."""
    extraction_mock = MagicMock(spec=Extraction)
    extraction_mock.output_content = AsyncMock(return_value="test content")
    extraction_mock.path = Path("test_extraction.txt")
    extraction_mock.extractor_config_id = "extractor-123"
    return extraction_mock
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
@pytest.fixture
def mock_chunked_document():
    """Provide a ``ChunkedDocument``-spec'd mock whose async ``load_chunks_text`` yields two chunks."""
    chunk_texts = ["chunk 1", "chunk 2"]
    chunked = MagicMock(spec=ChunkedDocument)
    chunked.load_chunks_text = AsyncMock(return_value=chunk_texts)
    chunked.path = Path("test_chunked.txt")
    chunked.chunker_config_id = "chunker-123"
    return chunked
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@pytest.fixture
def mock_rag_config():
    """Provide a ``RagConfig``-spec'd mock with id ``"rag-123"`` and no tags."""
    rag_config = MagicMock(spec=RagConfig)
    rag_config.tags = None
    rag_config.id = "rag-123"
    return rag_config
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# Tests for dataclasses
|
|
177
|
+
class TestExtractorJob:
    """Construction checks for ``ExtractorJob``."""

    def test_extractor_job_creation(self, mock_document, mock_extractor_config):
        """The job exposes the document and extractor config it was built with."""
        extractor_job = ExtractorJob(
            doc=mock_document,
            extractor_config=mock_extractor_config,
        )
        assert extractor_job.doc == mock_document
        assert extractor_job.extractor_config == mock_extractor_config
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class TestChunkerJob:
    """Construction checks for ``ChunkerJob``."""

    def test_chunker_job_creation(self, mock_extraction, mock_chunker_config):
        """The job exposes the extraction and chunker config it was built with."""
        chunker_job = ChunkerJob(
            extraction=mock_extraction,
            chunker_config=mock_chunker_config,
        )
        assert chunker_job.extraction == mock_extraction
        assert chunker_job.chunker_config == mock_chunker_config
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class TestEmbeddingJob:
    """Construction checks for ``EmbeddingJob``."""

    def test_embedding_job_creation(self, mock_chunked_document, mock_embedding_config):
        """The job exposes the chunked document and embedding config it was built with."""
        job_kwargs = {
            "chunked_document": mock_chunked_document,
            "embedding_config": mock_embedding_config,
        }
        embedding_job = EmbeddingJob(**job_kwargs)
        assert embedding_job.chunked_document == mock_chunked_document
        assert embedding_job.embedding_config == mock_embedding_config
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class TestRagStepRunnerProgress:
    """Construction checks for ``RagStepRunnerProgress``."""

    def test_progress_creation_with_defaults(self):
        """With no arguments both counts are ``None`` and ``logs`` is an empty list."""
        default_progress = RagStepRunnerProgress()
        assert default_progress.success_count is None
        assert default_progress.error_count is None
        assert default_progress.logs == []

    def test_progress_creation_with_values(self):
        """Explicit counts and log messages are stored as given."""
        log_entries = [LogMessage(level="info", message="test")]
        populated_progress = RagStepRunnerProgress(
            success_count=5,
            error_count=2,
            logs=log_entries,
        )
        assert populated_progress.success_count == 5
        assert populated_progress.error_count == 2
        assert populated_progress.logs == log_entries
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# Tests for GenericErrorCollector
|
|
217
|
+
class TestGenericErrorCollector:
    """Exercises ``GenericErrorCollector``: callbacks, ``get_errors`` and ``get_error_count``."""

    @pytest.fixture
    def error_collector(self):
        """Provide a fresh collector for each test."""
        return GenericErrorCollector()

    @pytest.mark.asyncio
    async def test_on_success_does_nothing(self, error_collector):
        """A success callback leaves the error list empty."""
        finished_job = "test_job"
        await error_collector.on_success(finished_job)
        assert len(error_collector.errors) == 0

    @pytest.mark.asyncio
    async def test_on_error_collects_error(self, error_collector):
        """An error callback records one ``(job, error)`` pair."""
        failed_job = "test_job"
        failure = Exception("test error")
        await error_collector.on_error(failed_job, failure)

        recorded = error_collector.errors
        assert len(recorded) == 1
        assert recorded[0] == (failed_job, failure)

    def test_get_errors_returns_all_errors(self, error_collector):
        """Without ``start_idx`` every stored error is returned along with the list length."""
        # Seed the collector directly rather than going through on_error.
        first_error = Exception("error 1")
        second_error = Exception("error 2")
        error_collector.errors = [("job1", first_error), ("job2", second_error)]

        collected, next_idx = error_collector.get_errors()
        assert len(collected) == 2
        assert collected[0] == ("job1", first_error)
        assert collected[1] == ("job2", second_error)
        assert next_idx == 2

    def test_get_errors_with_start_idx(self, error_collector):
        """``start_idx=1`` skips the first stored error."""
        # Seed the collector directly rather than going through on_error.
        first_error = Exception("error 1")
        second_error = Exception("error 2")
        third_error = Exception("error 3")
        error_collector.errors = [
            ("job1", first_error),
            ("job2", second_error),
            ("job3", third_error),
        ]

        collected, next_idx = error_collector.get_errors(start_idx=1)
        assert len(collected) == 2
        assert collected[0] == ("job2", second_error)
        assert collected[1] == ("job3", third_error)
        assert next_idx == 3

    def test_get_errors_negative_start_idx_raises_error(self, error_collector):
        """A negative ``start_idx`` is rejected with ``ValueError``."""
        error_collector.errors = [("job1", Exception()), ("job2", Exception())]

        with pytest.raises(ValueError, match="start_idx must be non-negative"):
            error_collector.get_errors(start_idx=-1)

    def test_get_errors_with_start_idx_zero(self, error_collector):
        """``start_idx=0`` returns every stored error."""
        first_error = Exception("error 1")
        second_error = Exception("error 2")
        error_collector.errors = [("job1", first_error), ("job2", second_error)]

        collected, next_idx = error_collector.get_errors(start_idx=0)
        assert len(collected) == 2
        assert collected[0] == ("job1", first_error)
        assert collected[1] == ("job2", second_error)
        assert next_idx == 2

    def test_get_errors_start_idx_equal_to_length(self, error_collector):
        """``start_idx`` equal to the number of errors yields an empty result."""
        error_collector.errors = [("job1", Exception()), ("job2", Exception())]

        collected, next_idx = error_collector.get_errors(start_idx=2)
        assert len(collected) == 0
        assert next_idx == 2

    def test_get_errors_start_idx_greater_than_length(self, error_collector):
        """``start_idx`` past the end yields nothing and is echoed back as the index."""
        error_collector.errors = [("job1", Exception()), ("job2", Exception())]

        collected, next_idx = error_collector.get_errors(start_idx=5)
        assert len(collected) == 0
        assert next_idx == 5

    def test_get_errors_with_empty_error_list(self, error_collector):
        """With no stored errors, the result is empty for ``start_idx`` 0 and 1."""
        collected, next_idx = error_collector.get_errors(start_idx=0)
        assert len(collected) == 0
        assert next_idx == 0

        collected, next_idx = error_collector.get_errors(start_idx=1)
        assert len(collected) == 0
        assert next_idx == 1

    def test_get_errors_boundary_conditions(self, error_collector):
        """Single stored error: read it from index 0, then read nothing from index 1."""
        only_error = Exception("single error")
        error_collector.errors = [("job1", only_error)]

        # Reading from the beginning returns the one error.
        collected, next_idx = error_collector.get_errors(start_idx=0)
        assert len(collected) == 1
        assert collected[0] == ("job1", only_error)
        assert next_idx == 1

        # Index 1 equals the list length, so nothing is left to return.
        collected, next_idx = error_collector.get_errors(start_idx=1)
        assert len(collected) == 0
        assert next_idx == 1

    def test_get_error_count(self, error_collector):
        """The count is 0 initially and 2 after two errors are stored."""
        assert error_collector.get_error_count() == 0

        error_collector.errors = [("job1", Exception()), ("job2", Exception())]
        assert error_collector.get_error_count() == 2
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
# Tests for job execution functions
|
|
333
|
+
class TestExecuteExtractorJob:
    """Tests for the ``execute_extractor_job`` coroutine."""

    @pytest.mark.asyncio
    async def test_execute_extractor_job_success(
        self, mock_document, mock_extractor_config
    ):
        """A successful extraction returns ``True`` and saves the created extraction once."""
        extractor_job = ExtractorJob(
            doc=mock_document, extractor_config=mock_extractor_config
        )

        extraction_output = ExtractionOutput(
            content="extracted content", content_format=OutputFormat.TEXT
        )
        extractor = MagicMock(spec=BaseExtractor)
        extractor.extract = AsyncMock(return_value=extraction_output)

        # Replace the Extraction class seen by rag_runners so the save call can be observed.
        with patch("kiln_ai.adapters.rag.rag_runners.Extraction") as extraction_cls:
            created_extraction = MagicMock()
            extraction_cls.return_value = created_extraction

            outcome = await execute_extractor_job(extractor_job, extractor)

            assert outcome is True
            extractor.extract.assert_called_once()
            created_extraction.save_to_file.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_extractor_job_no_path_raises_error(
        self, mock_extractor_config
    ):
        """A document whose ``path`` is ``None`` causes a ``ValueError``."""
        pathless_document = MagicMock(spec=Document)
        pathless_document.path = None

        extractor = MagicMock(spec=BaseExtractor)
        extractor_job = ExtractorJob(
            doc=pathless_document, extractor_config=mock_extractor_config
        )

        with pytest.raises(ValueError, match="Document path is not set"):
            await execute_extractor_job(extractor_job, extractor)
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
class TestExecuteChunkerJob:
    """Tests for the ``execute_chunker_job`` coroutine."""

    @pytest.mark.asyncio
    async def test_execute_chunker_job_success(
        self, mock_extraction, mock_chunker_config
    ):
        """Chunking the extraction content returns ``True`` and saves the chunked document once."""
        chunker_job = ChunkerJob(
            extraction=mock_extraction, chunker_config=mock_chunker_config
        )

        # A chunking result holding a single chunk.
        single_chunk = MagicMock()
        single_chunk.text = "chunk text"
        chunking_result = MagicMock(spec=ChunkingResult)
        chunking_result.chunks = [single_chunk]

        chunker = MagicMock(spec=BaseChunker)
        chunker.chunk = AsyncMock(return_value=chunking_result)

        # Replace the ChunkedDocument class seen by rag_runners so the save call can be observed.
        with patch(
            "kiln_ai.adapters.rag.rag_runners.ChunkedDocument"
        ) as chunked_document_cls:
            created_chunked_document = MagicMock()
            chunked_document_cls.return_value = created_chunked_document

            outcome = await execute_chunker_job(chunker_job, chunker)

            assert outcome is True
            chunker.chunk.assert_called_once_with("test content")
            created_chunked_document.save_to_file.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_chunker_job_no_content_raises_error(
        self, mock_chunker_config
    ):
        """An extraction whose ``output_content`` resolves to ``None`` causes a ``ValueError``."""
        empty_extraction = MagicMock(spec=Extraction)
        empty_extraction.output_content = AsyncMock(return_value=None)

        chunker = MagicMock(spec=BaseChunker)
        chunker_job = ChunkerJob(
            extraction=empty_extraction, chunker_config=mock_chunker_config
        )

        with pytest.raises(ValueError, match="Extraction output content is not set"):
            await execute_chunker_job(chunker_job, chunker)

    @pytest.mark.asyncio
    async def test_execute_chunker_job_no_chunking_result_raises_error(
        self, mock_extraction, mock_chunker_config
    ):
        """A chunker that returns ``None`` causes a ``ValueError``."""
        chunker_job = ChunkerJob(
            extraction=mock_extraction, chunker_config=mock_chunker_config
        )

        chunker = MagicMock(spec=BaseChunker)
        chunker.chunk = AsyncMock(return_value=None)

        with pytest.raises(ValueError, match="Chunking result is not set"):
            await execute_chunker_job(chunker_job, chunker)
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
class TestExecuteEmbeddingJob:
    """Tests for the ``execute_embedding_job`` coroutine."""

    @pytest.mark.asyncio
    async def test_execute_embedding_job_success(
        self, mock_chunked_document, mock_embedding_config
    ):
        """Embedding the chunk texts returns ``True`` and saves the chunk embeddings once."""
        embedding_job = EmbeddingJob(
            chunked_document=mock_chunked_document,
            embedding_config=mock_embedding_config,
        )

        # An embedding result holding a single three-component vector.
        single_embedding = MagicMock()
        single_embedding.vector = [0.1, 0.2, 0.3]
        embedding_result = MagicMock(spec=EmbeddingResult)
        embedding_result.embeddings = [single_embedding]

        adapter = MagicMock(spec=BaseEmbeddingAdapter)
        adapter.generate_embeddings = AsyncMock(return_value=embedding_result)

        # Replace the ChunkEmbeddings class seen by rag_runners so the save call can be observed.
        with patch(
            "kiln_ai.adapters.rag.rag_runners.ChunkEmbeddings"
        ) as chunk_embeddings_cls:
            created_chunk_embeddings = MagicMock()
            chunk_embeddings_cls.return_value = created_chunk_embeddings

            outcome = await execute_embedding_job(embedding_job, adapter)

            assert outcome is True
            adapter.generate_embeddings.assert_called_once_with(
                input_texts=["chunk 1", "chunk 2"]
            )
            created_chunk_embeddings.save_to_file.assert_called_once()

    @pytest.mark.asyncio
    @pytest.mark.parametrize("return_value", [None, []])
    async def test_execute_embedding_job_no_chunks_raises_error(
        self, mock_embedding_config, return_value
    ):
        """``load_chunks_text`` resolving to ``None`` or ``[]`` causes a ``ValueError``."""
        chunkless_document = MagicMock(spec=ChunkedDocument, id="123")
        chunkless_document.load_chunks_text = AsyncMock(return_value=return_value)

        adapter = MagicMock(spec=BaseEmbeddingAdapter)
        embedding_job = EmbeddingJob(
            chunked_document=chunkless_document,
            embedding_config=mock_embedding_config,
        )

        with pytest.raises(
            ValueError, match="Failed to load chunks for chunked document: 123"
        ):
            await execute_embedding_job(embedding_job, adapter)

    @pytest.mark.asyncio
    async def test_execute_embedding_job_no_embedding_result_raises_error(
        self, mock_chunked_document, mock_embedding_config
    ):
        """An adapter that returns ``None`` causes a ``ValueError``."""
        # The expected error message embeds the chunked document id.
        mock_chunked_document.id = "123"
        embedding_job = EmbeddingJob(
            chunked_document=mock_chunked_document,
            embedding_config=mock_embedding_config,
        )

        adapter = MagicMock(spec=BaseEmbeddingAdapter)
        adapter.generate_embeddings = AsyncMock(return_value=None)

        with pytest.raises(
            ValueError, match="Failed to generate embeddings for chunked document: 123"
        ):
            await execute_embedding_job(embedding_job, adapter)
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
# Tests for step runners
|
|
502
|
+
class TestRagExtractionStepRunner:
|
|
503
|
+
@pytest.fixture
def extraction_runner(self, mock_project, mock_extractor_config):
    """Provide a ``RagExtractionStepRunner`` over the mock project with concurrency 2."""
    runner_kwargs = {
        "project": mock_project,
        "extractor_config": mock_extractor_config,
        "concurrency": 2,
    }
    return RagExtractionStepRunner(**runner_kwargs)
|
|
508
|
+
|
|
509
|
+
def test_stage_returns_extracting(self, extraction_runner):
    """The extraction runner reports the ``EXTRACTING`` workflow step."""
    reported_stage = extraction_runner.stage()
    assert reported_stage == RagWorkflowStepNames.EXTRACTING
|
|
511
|
+
|
|
512
|
+
def test_has_extraction_returns_true_when_found(
    self, extraction_runner, mock_document
):
    """``has_extraction`` is ``True`` when an extraction carries the requested config id."""
    matching_extraction = MagicMock()
    matching_extraction.extractor_config_id = "extractor-123"
    mock_document.extractions.return_value = [matching_extraction]

    found = extraction_runner.has_extraction(mock_document, "extractor-123")
    assert found is True
|
|
522
|
+
|
|
523
|
+
def test_has_extraction_returns_false_when_not_found(
    self, extraction_runner, mock_document
):
    """``has_extraction`` is ``False`` when the only extraction has a different config id."""
    unrelated_extraction = MagicMock()
    unrelated_extraction.extractor_config_id = "different-extractor"
    mock_document.extractions.return_value = [unrelated_extraction]

    found = extraction_runner.has_extraction(mock_document, "extractor-123")
    assert found is False
|
|
533
|
+
|
|
534
|
+
@pytest.mark.asyncio
async def test_collect_jobs_returns_jobs_for_documents_without_extractions(
    self, extraction_runner
):
    """Only the document lacking a matching extraction gets a job."""
    # First document: no extractions at all.
    unextracted_doc = MagicMock(spec=Document)
    unextracted_doc.extractions.return_value = []

    # Second document: already has an extraction for the runner's config id.
    existing_extraction = MagicMock()
    existing_extraction.extractor_config_id = "extractor-123"
    extracted_doc = MagicMock(spec=Document)
    extracted_doc.extractions.return_value = [existing_extraction]

    extraction_runner.project.documents.return_value = [
        unextracted_doc,
        extracted_doc,
    ]

    collected_jobs = await extraction_runner.collect_jobs()

    assert len(collected_jobs) == 1
    only_job = collected_jobs[0]
    assert only_job.doc == unextracted_doc
    assert only_job.extractor_config == extraction_runner.extractor_config
|
|
557
|
+
|
|
558
|
+
@pytest.mark.asyncio
async def test_collect_jobs_with_document_ids_filters_documents(
    self, extraction_runner
):
    """Passing ``document_ids`` limits the jobs to the listed documents."""
    # Three documents, none of which has an extraction yet.
    project_documents = []
    for doc_id in ("doc-1", "doc-2", "doc-3"):
        document = MagicMock(spec=Document)
        document.id = doc_id
        document.extractions.return_value = []
        project_documents.append(document)

    extraction_runner.project.documents.return_value = project_documents

    # Request jobs for the first and third documents only.
    collected_jobs = await extraction_runner.collect_jobs(
        document_ids=["doc-1", "doc-3"]
    )

    assert len(collected_jobs) == 2
    selected_ids = {job.doc.id for job in collected_jobs}
    assert selected_ids == {"doc-1", "doc-3"}
|
|
588
|
+
|
|
589
|
+
@pytest.mark.asyncio
async def test_collect_jobs_with_empty_document_ids_processes_all_documents(
    self, extraction_runner
):
    """An empty ``document_ids`` list selects the same documents as ``None``: all of them."""
    # Two documents, neither of which has an extraction yet.
    project_documents = []
    for doc_id in ("doc-1", "doc-2"):
        document = MagicMock(spec=Document)
        document.id = doc_id
        document.extractions.return_value = []
        project_documents.append(document)

    extraction_runner.project.documents.return_value = project_documents

    jobs_empty = await extraction_runner.collect_jobs(document_ids=[])
    jobs_none = await extraction_runner.collect_jobs(document_ids=None)

    # Both calls cover every document.
    assert len(jobs_empty) == 2
    assert len(jobs_none) == 2

    # Both calls cover the same set of document ids.
    ids_from_empty = {job.doc.id for job in jobs_empty}
    ids_from_none = {job.doc.id for job in jobs_none}
    assert ids_from_empty == ids_from_none == {"doc-1", "doc-2"}
|
|
616
|
+
|
|
617
|
+
@pytest.mark.asyncio
async def test_run_with_document_ids_filters_documents(self, extraction_runner):
    """run(document_ids=[...]) hands AsyncJobRunner jobs for the listed documents only."""

    def build_document(doc_id, file_name, resolved_path):
        # Document mock with no extractions and a plain-text original file.
        document = MagicMock(spec=Document)
        document.id = doc_id
        document.extractions.return_value = []
        document.path = Path(file_name)
        document.original_file = MagicMock()
        document.original_file.attachment = MagicMock()
        document.original_file.attachment.resolve_path.return_value = resolved_path
        document.original_file.mime_type = "text/plain"
        return document

    async def single_progress_update():
        yield MagicMock(complete=1)

    extraction_runner.project.documents.return_value = [
        build_document("doc-1", "doc1.txt", "doc1_path"),
        build_document("doc-2", "doc2.txt", "doc2_path"),
    ]

    adapter_patch = patch(
        "kiln_ai.adapters.rag.rag_runners.extractor_adapter_from_type"
    )
    job_runner_patch = patch("kiln_ai.adapters.rag.rag_runners.AsyncJobRunner")
    lock_patch = patch("kiln_ai.utils.lock.shared_async_lock_manager")

    with adapter_patch as adapter_factory, job_runner_patch as job_runner_cls, lock_patch:
        adapter_factory.return_value = MagicMock(spec=BaseExtractor)

        job_runner = MagicMock()
        job_runner_cls.return_value = job_runner
        job_runner.run.return_value = single_progress_update()

        # Drain the run restricted to doc-1.
        progress_values = [
            progress
            async for progress in extraction_runner.run(document_ids=["doc-1"])
        ]

        # The job runner must have been built exactly once, with a single doc-1 job.
        job_runner_cls.assert_called_once()
        scheduled_jobs = job_runner_cls.call_args.kwargs["jobs"]
        assert len(scheduled_jobs) == 1
        assert scheduled_jobs[0].doc.id == "doc-1"
|
|
671
|
+
|
|
672
|
+
@pytest.mark.asyncio
async def test_run_with_empty_document_ids_behaves_like_none(
    self, extraction_runner
):
    """run(document_ids=[]) schedules the same number of jobs as run(document_ids=None).

    The runner is executed twice against the same single-document project and
    the ``jobs`` keyword passed to the patched AsyncJobRunner is compared
    between the two executions.
    """
    # One document without extractions, so exactly one job is expected.
    mock_doc1 = MagicMock(spec=Document)
    mock_doc1.id = "doc-1"
    mock_doc1.extractions.return_value = []
    mock_doc1.path = Path("doc1.txt")
    mock_doc1.original_file = MagicMock()
    mock_doc1.original_file.attachment = MagicMock()
    mock_doc1.original_file.attachment.resolve_path.return_value = "doc1_path"
    mock_doc1.original_file.mime_type = "text/plain"

    extraction_runner.project.documents.return_value = [mock_doc1]

    with (
        patch(
            "kiln_ai.adapters.rag.rag_runners.extractor_adapter_from_type"
        ) as mock_adapter_factory,
        patch(
            "kiln_ai.adapters.rag.rag_runners.AsyncJobRunner"
        ) as mock_job_runner_class,
        patch("kiln_ai.utils.lock.shared_async_lock_manager"),
    ):
        mock_extractor = MagicMock(spec=BaseExtractor)
        mock_adapter_factory.return_value = mock_extractor

        mock_job_runner = MagicMock()
        mock_job_runner_class.return_value = mock_job_runner

        async def mock_runner_progress():
            yield MagicMock(complete=1)

        # An async generator object can only be consumed once. run() is
        # exercised twice below, so build a fresh generator per call instead
        # of sharing one (already exhausted) instance via return_value.
        mock_job_runner.run.side_effect = (
            lambda *args, **kwargs: mock_runner_progress()
        )

        # First pass: empty list.
        async for _ in extraction_runner.run(document_ids=[]):
            pass
        jobs_with_empty = mock_job_runner_class.call_args.kwargs["jobs"]

        # Forget the recorded constructor call before the second pass.
        mock_job_runner_class.reset_mock()

        # Second pass: None.
        async for _ in extraction_runner.run(document_ids=None):
            pass
        jobs_with_none = mock_job_runner_class.call_args.kwargs["jobs"]

        # Both passes must schedule the same (single) job.
        assert len(jobs_with_empty) == len(jobs_with_none) == 1
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
class TestRagChunkingStepRunner:
    """Tests for RagChunkingStepRunner: stage name, has_chunks and collect_jobs."""

    @staticmethod
    def _make_extraction(extractor_config_id, created_at, chunked_documents):
        """Return an Extraction mock whose chunked_documents() yields the given list."""
        extraction = MagicMock(spec=Extraction)
        extraction.extractor_config_id = extractor_config_id
        extraction.created_at = created_at
        extraction.chunked_documents.return_value = chunked_documents
        return extraction

    @classmethod
    def _make_document(cls, doc_id, created_at):
        """Return a Document mock holding one un-chunked "extractor-123" extraction."""
        document = MagicMock(spec=Document)
        document.id = doc_id
        document.extractions.return_value = [
            cls._make_extraction("extractor-123", created_at, [])
        ]
        return document

    @pytest.fixture
    def chunking_runner(self, mock_project, mock_extractor_config, mock_chunker_config):
        """Runner under test, wired to the shared mock fixtures."""
        runner = RagChunkingStepRunner(
            project=mock_project,
            extractor_config=mock_extractor_config,
            chunker_config=mock_chunker_config,
            concurrency=2,
        )
        return runner

    def test_stage_returns_chunking(self, chunking_runner):
        stage = chunking_runner.stage()
        assert stage == RagWorkflowStepNames.CHUNKING

    def test_has_chunks_returns_true_when_found(self, chunking_runner, mock_extraction):
        # The extraction already owns a chunked document for the requested config.
        matching = MagicMock()
        matching.chunker_config_id = "chunker-123"
        mock_extraction.chunked_documents.return_value = [matching]

        assert chunking_runner.has_chunks(mock_extraction, "chunker-123") is True

    def test_has_chunks_returns_false_when_not_found(
        self, chunking_runner, mock_extraction
    ):
        # The only chunked document belongs to another chunker config.
        unrelated = MagicMock()
        unrelated.chunker_config_id = "different-chunker"
        mock_extraction.chunked_documents.return_value = [unrelated]

        assert chunking_runner.has_chunks(mock_extraction, "chunker-123") is False

    @pytest.mark.asyncio
    async def test_collect_jobs_returns_jobs_for_extractions_without_chunks(
        self, chunking_runner
    ):
        existing_chunks = MagicMock()
        existing_chunks.chunker_config_id = "chunker-123"

        # Matching extractor config, nothing chunked yet -> needs a job.
        unchunked = self._make_extraction("extractor-123", datetime(2023, 1, 1), [])
        # Matching extractor config, already chunked with our chunker.
        chunked = self._make_extraction(
            "extractor-123", datetime(2023, 1, 2), [existing_chunks]
        )
        # Produced by a different extractor config.
        foreign = self._make_extraction(
            "different-extractor", datetime(2023, 1, 3), []
        )

        document = MagicMock(spec=Document)
        document.extractions.return_value = [unchunked, chunked, foreign]
        chunking_runner.project.documents.return_value = [document]

        jobs = await chunking_runner.collect_jobs()

        # Only the un-chunked, matching extraction produces a job.
        assert len(jobs) == 1
        assert jobs[0].extraction == unchunked
        assert jobs[0].chunker_config == chunking_runner.chunker_config

    @pytest.mark.asyncio
    async def test_collect_jobs_with_document_ids_filters_documents(
        self, chunking_runner
    ):
        # doc-1, doc-2, doc-3 with extractions dated Jan 1st, 2nd and 3rd.
        chunking_runner.project.documents.return_value = [
            self._make_document(f"doc-{day}", datetime(2023, 1, day))
            for day in (1, 2, 3)
        ]

        jobs = await chunking_runner.collect_jobs(document_ids=["doc-1", "doc-3"])

        assert len(jobs) == 2

        # Every job carries the matching extractor config.
        extractor_config_ids = {job.extraction.extractor_config_id for job in jobs}
        assert extractor_config_ids == {"extractor-123"}

        # The extraction timestamps identify which documents were picked.
        picked_times = {job.extraction.created_at for job in jobs}
        assert picked_times == {datetime(2023, 1, 1), datetime(2023, 1, 3)}

    @pytest.mark.asyncio
    async def test_collect_jobs_with_empty_document_ids_processes_all_documents(
        self, chunking_runner
    ):
        chunking_runner.project.documents.return_value = [
            self._make_document(f"doc-{day}", datetime(2023, 1, day))
            for day in (1, 2)
        ]

        # An empty list is expected to behave exactly like None.
        jobs_for_empty_list = await chunking_runner.collect_jobs(document_ids=[])
        jobs_for_none = await chunking_runner.collect_jobs(document_ids=None)

        assert len(jobs_for_empty_list) == 2
        assert len(jobs_for_none) == 2

        times_for_empty_list = {job.extraction.created_at for job in jobs_for_empty_list}
        times_for_none = {job.extraction.created_at for job in jobs_for_none}
        assert (
            times_for_empty_list
            == times_for_none
            == {datetime(2023, 1, 1), datetime(2023, 1, 2)}
        )
|
|
887
|
+
|
|
888
|
+
|
|
889
|
+
class TestRagEmbeddingStepRunner:
    """Tests for RagEmbeddingStepRunner: stage name, has_embeddings and collect_jobs."""

    @staticmethod
    def _make_chunked_document(created_at, chunk_embeddings):
        """Return a "chunker-123" ChunkedDocument mock with the given embeddings."""
        chunked = MagicMock(spec=ChunkedDocument)
        chunked.chunker_config_id = "chunker-123"
        chunked.created_at = created_at
        chunked.chunk_embeddings.return_value = chunk_embeddings
        return chunked

    @classmethod
    def _make_document(cls, doc_id, created_at):
        """Return a Document mock: one extraction, one chunked doc, no embeddings."""
        extraction = MagicMock(spec=Extraction)
        extraction.extractor_config_id = "extractor-123"
        extraction.created_at = created_at
        extraction.chunked_documents.return_value = [
            cls._make_chunked_document(created_at, [])
        ]

        document = MagicMock(spec=Document)
        document.id = doc_id
        document.extractions.return_value = [extraction]
        return document

    @pytest.fixture
    def embedding_runner(
        self,
        mock_project,
        mock_extractor_config,
        mock_chunker_config,
        mock_embedding_config,
    ):
        """Runner under test, wired to the shared mock fixtures."""
        runner = RagEmbeddingStepRunner(
            project=mock_project,
            extractor_config=mock_extractor_config,
            chunker_config=mock_chunker_config,
            embedding_config=mock_embedding_config,
            concurrency=2,
        )
        return runner

    def test_stage_returns_embedding(self, embedding_runner):
        stage = embedding_runner.stage()
        assert stage == RagWorkflowStepNames.EMBEDDING

    def test_has_embeddings_returns_true_when_found(
        self, embedding_runner, mock_chunked_document
    ):
        # Embeddings for the requested config already exist.
        matching = MagicMock()
        matching.embedding_config_id = "embedding-123"
        mock_chunked_document.chunk_embeddings.return_value = [matching]

        found = embedding_runner.has_embeddings(mock_chunked_document, "embedding-123")
        assert found is True

    def test_has_embeddings_returns_false_when_not_found(
        self, embedding_runner, mock_chunked_document
    ):
        # The only embeddings present were made with another config.
        unrelated = MagicMock()
        unrelated.embedding_config_id = "different-embedding"
        mock_chunked_document.chunk_embeddings.return_value = [unrelated]

        found = embedding_runner.has_embeddings(mock_chunked_document, "embedding-123")
        assert found is False

    @pytest.mark.asyncio
    async def test_collect_jobs_returns_jobs_for_chunked_documents_without_embeddings(
        self, embedding_runner
    ):
        existing_embeddings = MagicMock()
        existing_embeddings.embedding_config_id = "embedding-123"

        # Matching chunker config, not embedded yet -> needs a job.
        pending = self._make_chunked_document(datetime(2023, 1, 1), [])
        # Matching chunker config, already embedded with our embedding config.
        embedded = self._make_chunked_document(
            datetime(2023, 1, 2), [existing_embeddings]
        )

        extraction = MagicMock(spec=Extraction)
        extraction.extractor_config_id = "extractor-123"
        extraction.created_at = datetime(2023, 1, 1)
        extraction.chunked_documents.return_value = [pending, embedded]

        document = MagicMock(spec=Document)
        document.extractions.return_value = [extraction]
        embedding_runner.project.documents.return_value = [document]

        jobs = await embedding_runner.collect_jobs()

        # Only the chunked document lacking embeddings produces a job.
        assert len(jobs) == 1
        assert jobs[0].chunked_document == pending
        assert jobs[0].embedding_config == embedding_runner.embedding_config

    @pytest.mark.asyncio
    async def test_collect_jobs_with_document_ids_filters_documents(
        self, embedding_runner
    ):
        # doc-1, doc-2, doc-3 with pipeline objects dated Jan 1st, 2nd and 3rd.
        embedding_runner.project.documents.return_value = [
            self._make_document(f"doc-{day}", datetime(2023, 1, day))
            for day in (1, 2, 3)
        ]

        jobs = await embedding_runner.collect_jobs(document_ids=["doc-1", "doc-3"])

        # The chunked-document timestamps identify which documents were picked.
        assert len(jobs) == 2
        picked_times = {job.chunked_document.created_at for job in jobs}
        assert picked_times == {datetime(2023, 1, 1), datetime(2023, 1, 3)}

    @pytest.mark.asyncio
    async def test_collect_jobs_with_empty_document_ids_processes_all_documents(
        self, embedding_runner
    ):
        embedding_runner.project.documents.return_value = [
            self._make_document(f"doc-{day}", datetime(2023, 1, day))
            for day in (1, 2)
        ]

        # An empty list is expected to behave exactly like None.
        jobs_for_empty_list = await embedding_runner.collect_jobs(document_ids=[])
        jobs_for_none = await embedding_runner.collect_jobs(document_ids=None)

        assert len(jobs_for_empty_list) == 2
        assert len(jobs_for_none) == 2

        times_for_empty_list = {
            job.chunked_document.created_at for job in jobs_for_empty_list
        }
        times_for_none = {job.chunked_document.created_at for job in jobs_for_none}
        assert (
            times_for_empty_list
            == times_for_none
            == {datetime(2023, 1, 1), datetime(2023, 1, 2)}
        )
|
|
1068
|
+
|
|
1069
|
+
|
|
1070
|
+
class TestRagIndexingStepRunner:
|
|
1071
|
+
@pytest.fixture
def indexing_runner(
    self,
    mock_project,
    mock_extractor_config,
    mock_chunker_config,
    mock_embedding_config,
    mock_rag_config,
):
    """Build the RagIndexingStepRunner under test from the shared mock fixtures."""
    from kiln_ai.adapters.rag.rag_runners import RagIndexingStepRunner
    from kiln_ai.datamodel.vector_store import VectorStoreConfig

    # Stand-in vector store config; only its id is set here.
    vector_store_config = MagicMock(spec=VectorStoreConfig)
    vector_store_config.id = "vector-store-123"

    runner = RagIndexingStepRunner(
        project=mock_project,
        extractor_config=mock_extractor_config,
        chunker_config=mock_chunker_config,
        embedding_config=mock_embedding_config,
        vector_store_config=vector_store_config,
        rag_config=mock_rag_config,
        concurrency=2,
        batch_size=5,
    )
    return runner
|
|
1097
|
+
|
|
1098
|
+
def test_stage_returns_indexing(self, indexing_runner):
    """The indexing runner reports the INDEXING workflow step."""
    stage = indexing_runner.stage()
    assert stage == RagWorkflowStepNames.INDEXING
|
|
1100
|
+
|
|
1101
|
+
def test_lock_key_property(self, indexing_runner):
    """lock_key is "rag:index:" followed by the vector store config id."""
    expected = f"rag:index:{indexing_runner.vector_store_config.id}"
    actual = indexing_runner.lock_key
    assert actual == expected
|
|
1104
|
+
|
|
1105
|
+
@pytest.mark.asyncio
async def test_collect_records_with_document_ids_filters_documents(
    self, indexing_runner
):
    """collect_records(document_ids=...) only yields records for the listed documents."""

    def build_document(day):
        # Full pipeline for "doc-<day>": extraction -> chunked doc -> embeddings,
        # all stamped with the same January 2023 date.
        stamp = datetime(2023, 1, day)

        embeddings = MagicMock()
        embeddings.embedding_config_id = "embedding-123"
        embeddings.created_at = stamp

        chunked = MagicMock(spec=ChunkedDocument)
        chunked.chunker_config_id = "chunker-123"
        chunked.created_at = stamp
        chunked.chunk_embeddings.return_value = [embeddings]

        extraction = MagicMock(spec=Extraction)
        extraction.extractor_config_id = "extractor-123"
        extraction.created_at = stamp
        extraction.chunked_documents.return_value = [chunked]

        document = MagicMock(spec=Document)
        document.id = f"doc-{day}"
        document.extractions.return_value = [extraction]
        return document

    indexing_runner.project.documents.return_value = [
        build_document(day) for day in (1, 2, 3)
    ]

    # Flatten every yielded batch for doc-1 and doc-3.
    collected_records = []
    async for batch in indexing_runner.collect_records(
        batch_size=10, document_ids=["doc-1", "doc-3"]
    ):
        collected_records.extend(batch)

    # doc-2 must not show up.
    assert len(collected_records) == 2
    assert {record.document_id for record in collected_records} == {"doc-1", "doc-3"}
|
|
1181
|
+
|
|
1182
|
+
@pytest.mark.asyncio
async def test_collect_records_with_empty_document_ids_processes_all_documents(
    self, indexing_runner
):
    """An empty document_ids list is treated the same as document_ids=None."""

    def build_document(day):
        # Full pipeline for "doc-<day>": extraction -> chunked doc -> embeddings,
        # all stamped with the same January 2023 date.
        stamp = datetime(2023, 1, day)

        embeddings = MagicMock()
        embeddings.embedding_config_id = "embedding-123"
        embeddings.created_at = stamp

        chunked = MagicMock(spec=ChunkedDocument)
        chunked.chunker_config_id = "chunker-123"
        chunked.created_at = stamp
        chunked.chunk_embeddings.return_value = [embeddings]

        extraction = MagicMock(spec=Extraction)
        extraction.extractor_config_id = "extractor-123"
        extraction.created_at = stamp
        extraction.chunked_documents.return_value = [chunked]

        document = MagicMock(spec=Document)
        document.id = f"doc-{day}"
        document.extractions.return_value = [extraction]
        return document

    indexing_runner.project.documents.return_value = [
        build_document(day) for day in (1, 2)
    ]

    # First with an empty list ...
    records_empty = []
    async for batch in indexing_runner.collect_records(
        batch_size=10, document_ids=[]
    ):
        records_empty.extend(batch)

    # ... then with None.
    records_none = []
    async for batch in indexing_runner.collect_records(
        batch_size=10, document_ids=None
    ):
        records_none.extend(batch)

    # Neither call filters anything out.
    assert len(records_empty) == 2
    assert len(records_none) == 2

    # Both calls cover exactly the same documents.
    ids_for_empty_list = {record.document_id for record in records_empty}
    ids_for_none = {record.document_id for record in records_none}
    assert ids_for_empty_list == ids_for_none == {"doc-1", "doc-2"}
|
|
1246
|
+
|
|
1247
|
+
@pytest.mark.asyncio
async def test_count_total_chunks(self, indexing_runner):
    """count_total_chunks returns 3 for one document whose chunked doc has 3 chunks."""
    stamp = datetime(2023, 1, 1)

    embeddings = MagicMock()
    embeddings.embedding_config_id = "embedding-123"
    embeddings.created_at = stamp

    # The chunked document carries exactly three chunks.
    chunked = MagicMock(spec=ChunkedDocument)
    chunked.chunker_config_id = "chunker-123"
    chunked.created_at = stamp
    chunked.chunks = [MagicMock(), MagicMock(), MagicMock()]
    chunked.chunk_embeddings.return_value = [embeddings]

    extraction = MagicMock(spec=Extraction)
    extraction.extractor_config_id = "extractor-123"
    extraction.created_at = stamp
    extraction.chunked_documents.return_value = [chunked]

    document = MagicMock(spec=Document)
    document.id = "doc-1"
    document.extractions.return_value = [extraction]

    indexing_runner.project.documents.return_value = [document]

    total_chunks = await indexing_runner.count_total_chunks()
    assert total_chunks == 3
|
|
1273
|
+
|
|
1274
|
+
@pytest.mark.asyncio
|
|
1275
|
+
async def test_run_vector_dimensions_inference(self, indexing_runner):
    """run() builds the vector store once when embeddings are available.

    A single document with one chunk and a five-element embedding vector is
    indexed against a mocked vector store. The store factory must be called
    exactly once with the runner's rag and vector store configs, and at
    least two progress updates must be yielded.
    """
    created = datetime(2023, 1, 1)

    # Embedding with a specific vector size.
    embedding = MagicMock()
    embedding.vector = [0.1, 0.2, 0.3, 0.4, 0.5]  # 5 dimensions
    embeddings_record = MagicMock()
    embeddings_record.embedding_config_id = "embedding-123"
    embeddings_record.embeddings = [embedding]

    chunked = MagicMock(spec=ChunkedDocument)
    chunked.chunker_config_id = "chunker-123"
    chunked.created_at = created
    chunked.chunks = [MagicMock()]
    chunked.chunk_embeddings.return_value = [embeddings_record]

    extraction = MagicMock(spec=Extraction)
    extraction.extractor_config_id = "extractor-123"
    extraction.created_at = created
    extraction.chunked_documents.return_value = [chunked]

    document = MagicMock(spec=Document)
    document.id = "doc-1"
    document.extractions.return_value = [extraction]

    indexing_runner.project.documents.return_value = [document]

    with (
        patch("kiln_ai.utils.lock.shared_async_lock_manager"),
        patch(
            "kiln_ai.adapters.rag.rag_runners.vector_store_adapter_for_config",
            new_callable=AsyncMock,
        ) as store_factory,
    ):
        store = MagicMock()
        store.add_chunks_with_embeddings = AsyncMock()
        store.delete_nodes_not_in_set = AsyncMock()
        store_factory.return_value = store

        updates = [update async for update in indexing_runner.run()]

        # The vector store is created from the runner's own configs.
        store_factory.assert_called_once_with(
            indexing_runner.rag_config, indexing_runner.vector_store_config
        )
        # Initial progress + at least one batch.
        assert len(updates) >= 2
|
|
1322
|
+
|
|
1323
|
+
@pytest.mark.asyncio
|
|
1324
|
+
async def test_run_successful_indexing_flow(self, indexing_runner):
    """run() reports a zeroed initial progress followed by the success count.

    One document carrying two chunks is indexed against a mocked vector
    store whose calls all succeed. The first update must have zero counts
    and the first update with successes must count both chunks.
    """
    created = datetime(2023, 1, 1)

    embedding = MagicMock()
    embedding.vector = [0.1, 0.2, 0.3]  # 3 dimensions
    embeddings_record = MagicMock()
    embeddings_record.embedding_config_id = "embedding-123"
    embeddings_record.embeddings = [embedding]

    chunked = MagicMock(spec=ChunkedDocument)
    chunked.chunker_config_id = "chunker-123"
    chunked.created_at = created
    chunked.chunks = [MagicMock() for _ in range(2)]  # 2 chunks
    chunked.chunk_embeddings.return_value = [embeddings_record]

    extraction = MagicMock(spec=Extraction)
    extraction.extractor_config_id = "extractor-123"
    extraction.created_at = created
    extraction.chunked_documents.return_value = [chunked]

    document = MagicMock(spec=Document)
    document.id = "doc-1"
    document.extractions.return_value = [extraction]

    indexing_runner.project.documents.return_value = [document]

    with (
        patch("kiln_ai.utils.lock.shared_async_lock_manager"),
        patch(
            "kiln_ai.adapters.rag.rag_runners.vector_store_adapter_for_config",
            new_callable=AsyncMock,
        ) as store_factory,
    ):
        store = MagicMock()
        store.add_chunks_with_embeddings = AsyncMock()
        store.delete_nodes_not_in_set = AsyncMock()
        store_factory.return_value = store

        updates = [update async for update in indexing_runner.run()]

        # Initial progress plus at least one success update.
        assert len(updates) >= 2

        # The initial update carries no counts yet.
        initial = updates[0]
        assert initial.success_count == 0
        assert initial.error_count == 0

        # At least one update must report successes.
        with_successes = [
            update
            for update in updates
            if update.success_count and update.success_count > 0
        ]
        assert len(with_successes) >= 1
        assert with_successes[0].success_count == 2  # 2 chunks
|
|
1376
|
+
|
|
1377
|
+
@pytest.mark.asyncio
|
|
1378
|
+
async def test_run_error_handling_during_indexing(self, indexing_runner):
    """run() turns a vector store failure into an error progress update.

    The mocked store raises from add_chunks_with_embeddings. The first
    update with errors must count both chunks as failed and carry an
    error-level log entry containing the exception message.
    """
    created = datetime(2023, 1, 1)

    embedding = MagicMock()
    embedding.vector = [0.1, 0.2, 0.3]  # 3 dimensions
    embeddings_record = MagicMock()
    embeddings_record.embedding_config_id = "embedding-123"
    embeddings_record.embeddings = [embedding]

    chunked = MagicMock(spec=ChunkedDocument)
    chunked.chunker_config_id = "chunker-123"
    chunked.created_at = created
    chunked.chunks = [MagicMock() for _ in range(2)]  # 2 chunks
    chunked.chunk_embeddings.return_value = [embeddings_record]

    extraction = MagicMock(spec=Extraction)
    extraction.extractor_config_id = "extractor-123"
    extraction.created_at = created
    extraction.chunked_documents.return_value = [chunked]

    document = MagicMock(spec=Document)
    document.id = "doc-1"
    document.extractions.return_value = [extraction]

    indexing_runner.project.documents.return_value = [document]

    with (
        patch("kiln_ai.utils.lock.shared_async_lock_manager"),
        patch(
            "kiln_ai.adapters.rag.rag_runners.vector_store_adapter_for_config",
            new_callable=AsyncMock,
        ) as store_factory,
    ):
        store = MagicMock()
        # Every attempt to add chunks to the store fails.
        store.add_chunks_with_embeddings = AsyncMock(
            side_effect=Exception("Vector store error")
        )
        store.delete_nodes_not_in_set = AsyncMock()
        store_factory.return_value = store

        updates = [update async for update in indexing_runner.run()]

        # Initial progress plus at least one error update.
        assert len(updates) >= 2

        with_errors = [
            update
            for update in updates
            if update.error_count and update.error_count > 0
        ]
        assert len(with_errors) >= 1

        first_failure = with_errors[0]
        assert first_failure.error_count == 2  # 2 chunks failed
        assert len(first_failure.logs) > 0

        first_log = first_failure.logs[0]
        assert "error" in first_log.level.lower()
        assert "Vector store error" in first_log.message
|
|
1433
|
+
|
|
1434
|
+
@pytest.mark.asyncio
|
|
1435
|
+
async def test_run_calls_delete_nodes_not_in_set_with_all_documents_no_tags(
    self, indexing_runner
):
    """With no tags on the rag config, every document ID is kept in the store.

    delete_nodes_not_in_set must be called once with the IDs of all three
    project documents, whatever their own tags are.
    """

    def make_document(doc_id, tags, extractions):
        doc = MagicMock(spec=Document)
        doc.id = doc_id
        doc.tags = tags
        doc.extractions.return_value = extractions
        return doc

    created = datetime(2023, 1, 1)

    # Only one document gets complete pipeline data, which is enough for
    # vector dimension inference.
    embedding = MagicMock()
    embedding.vector = [0.1, 0.2, 0.3]
    embeddings_record = MagicMock()
    embeddings_record.embedding_config_id = "embedding-123"
    embeddings_record.embeddings = [embedding]

    chunked = MagicMock(spec=ChunkedDocument)
    chunked.chunker_config_id = "chunker-123"
    chunked.created_at = created
    chunked.chunks = [MagicMock()]
    chunked.chunk_embeddings.return_value = [embeddings_record]

    extraction = MagicMock(spec=Extraction)
    extraction.extractor_config_id = "extractor-123"
    extraction.created_at = created
    extraction.chunked_documents.return_value = [chunked]

    indexing_runner.project.documents.return_value = [
        make_document("doc-1", ["tag1", "tag2"], [extraction]),
        make_document("doc-2", ["tag3"], []),
        make_document("doc-3", None, []),
    ]

    # No tag filter configured.
    indexing_runner.rag_config.tags = None

    with (
        patch("kiln_ai.utils.lock.shared_async_lock_manager"),
        patch(
            "kiln_ai.adapters.rag.rag_runners.vector_store_adapter_for_config",
            new_callable=AsyncMock,
        ) as store_factory,
    ):
        store = MagicMock()
        store.add_chunks_with_embeddings = AsyncMock()
        store.delete_nodes_not_in_set = AsyncMock()
        store_factory.return_value = store

        # Drain the run; the progress values are not inspected here.
        async for _ in indexing_runner.run():
            pass

        store.delete_nodes_not_in_set.assert_called_once_with(
            {"doc-1", "doc-2", "doc-3"}
        )
|
|
1501
|
+
|
|
1502
|
+
@pytest.mark.asyncio
|
|
1503
|
+
async def test_run_calls_delete_nodes_not_in_set_with_tagged_documents_only(
    self, indexing_runner
):
    """With a tag filter, only matching document IDs are kept in the store.

    The rag config is restricted to the "important" tag, so
    delete_nodes_not_in_set must be called once with doc-1 and doc-2 only.
    """

    def make_document(doc_id, tags, extractions):
        doc = MagicMock(spec=Document)
        doc.id = doc_id
        doc.tags = tags
        doc.extractions.return_value = extractions
        return doc

    created = datetime(2023, 1, 1)

    # Only one document gets complete pipeline data, which is enough for
    # vector dimension inference.
    embedding = MagicMock()
    embedding.vector = [0.1, 0.2, 0.3]
    embeddings_record = MagicMock()
    embeddings_record.embedding_config_id = "embedding-123"
    embeddings_record.embeddings = [embedding]

    chunked = MagicMock(spec=ChunkedDocument)
    chunked.chunker_config_id = "chunker-123"
    chunked.created_at = created
    chunked.chunks = [MagicMock()]
    chunked.chunk_embeddings.return_value = [embeddings_record]

    extraction = MagicMock(spec=Extraction)
    extraction.extractor_config_id = "extractor-123"
    extraction.created_at = created
    extraction.chunked_documents.return_value = [chunked]

    indexing_runner.project.documents.return_value = [
        make_document("doc-1", ["important", "data"], [extraction]),
        make_document("doc-2", ["important", "test"], []),
        make_document("doc-3", ["unrelated"], []),
        make_document("doc-4", None, []),
    ]

    # Restrict indexing to documents tagged "important".
    indexing_runner.rag_config.tags = ["important"]

    with (
        patch("kiln_ai.utils.lock.shared_async_lock_manager"),
        patch(
            "kiln_ai.adapters.rag.rag_runners.vector_store_adapter_for_config",
            new_callable=AsyncMock,
        ) as store_factory,
    ):
        store = MagicMock()
        store.add_chunks_with_embeddings = AsyncMock()
        store.delete_nodes_not_in_set = AsyncMock()
        store_factory.return_value = store

        # Drain the run; the progress values are not inspected here.
        async for _ in indexing_runner.run():
            pass

        store.delete_nodes_not_in_set.assert_called_once_with({"doc-1", "doc-2"})
|
|
1574
|
+
|
|
1575
|
+
@pytest.mark.asyncio
|
|
1576
|
+
async def test_run_raises_error_when_no_documents_match_tags(self, indexing_runner):
    """run() raises ValueError when the tag filter matches no document.

    Both documents have complete pipeline data but neither carries the
    configured tag. The expected message is "Vector dimensions are not
    set", and neither the store factory nor delete_nodes_not_in_set may be
    called.
    """

    def make_document(doc_id, tags, extractions):
        doc = MagicMock(spec=Document)
        doc.id = doc_id
        doc.tags = tags
        doc.extractions.return_value = extractions
        return doc

    created = datetime(2023, 1, 1)

    embedding = MagicMock()
    embedding.vector = [0.1, 0.2, 0.3]
    embeddings_record = MagicMock()
    embeddings_record.embedding_config_id = "embedding-123"
    embeddings_record.embeddings = [embedding]

    chunked = MagicMock(spec=ChunkedDocument)
    chunked.chunker_config_id = "chunker-123"
    chunked.created_at = created
    chunked.chunks = [MagicMock()]
    chunked.chunk_embeddings.return_value = [embeddings_record]

    extraction = MagicMock(spec=Extraction)
    extraction.extractor_config_id = "extractor-123"
    extraction.created_at = created
    extraction.chunked_documents.return_value = [chunked]

    # Both documents share the same extraction; their tags do not match the
    # filter configured below.
    indexing_runner.project.documents.return_value = [
        make_document("doc-1", ["tag1"], [extraction]),
        make_document("doc-2", ["tag2"], [extraction]),
    ]

    indexing_runner.rag_config.tags = ["nonexistent_tag"]

    with (
        patch("kiln_ai.utils.lock.shared_async_lock_manager"),
        patch(
            "kiln_ai.adapters.rag.rag_runners.vector_store_adapter_for_config",
            new_callable=AsyncMock,
        ) as store_factory,
    ):
        store = MagicMock()
        store.add_chunks_with_embeddings = AsyncMock()
        store.delete_nodes_not_in_set = AsyncMock()
        store_factory.return_value = store

        with pytest.raises(ValueError, match="Vector dimensions are not set"):
            async for _ in indexing_runner.run():
                pass

        # The failure happens before any vector store interaction.
        store_factory.assert_not_called()
        store.delete_nodes_not_in_set.assert_not_called()
|
|
1635
|
+
|
|
1636
|
+
@pytest.mark.asyncio
|
|
1637
|
+
async def test_run_calls_delete_nodes_not_in_set_with_multiple_tag_filters(
    self, indexing_runner
):
    """With several tags configured, a document matching any of them is kept.

    The rag config lists "important" and "urgent", so
    delete_nodes_not_in_set must be called once with doc-1 (important),
    doc-2 (urgent) and doc-4 (both); doc-3 ("archive") is left out.
    """

    def make_document(doc_id, tags, extractions):
        doc = MagicMock(spec=Document)
        doc.id = doc_id
        doc.tags = tags
        doc.extractions.return_value = extractions
        return doc

    created = datetime(2023, 1, 1)

    # Only one document gets complete pipeline data, which is enough for
    # vector dimension inference.
    embedding = MagicMock()
    embedding.vector = [0.1, 0.2, 0.3]
    embeddings_record = MagicMock()
    embeddings_record.embedding_config_id = "embedding-123"
    embeddings_record.embeddings = [embedding]

    chunked = MagicMock(spec=ChunkedDocument)
    chunked.chunker_config_id = "chunker-123"
    chunked.created_at = created
    chunked.chunks = [MagicMock()]
    chunked.chunk_embeddings.return_value = [embeddings_record]

    extraction = MagicMock(spec=Extraction)
    extraction.extractor_config_id = "extractor-123"
    extraction.created_at = created
    extraction.chunked_documents.return_value = [chunked]

    indexing_runner.project.documents.return_value = [
        make_document("doc-1", ["important", "data"], [extraction]),
        make_document("doc-2", ["urgent", "test"], []),
        make_document("doc-3", ["archive"], []),
        make_document("doc-4", ["important", "urgent", "critical"], []),
    ]

    indexing_runner.rag_config.tags = ["important", "urgent"]

    with (
        patch("kiln_ai.utils.lock.shared_async_lock_manager"),
        patch(
            "kiln_ai.adapters.rag.rag_runners.vector_store_adapter_for_config",
            new_callable=AsyncMock,
        ) as store_factory,
    ):
        store = MagicMock()
        store.add_chunks_with_embeddings = AsyncMock()
        store.delete_nodes_not_in_set = AsyncMock()
        store_factory.return_value = store

        # Drain the run; the progress values are not inspected here.
        async for _ in indexing_runner.run():
            pass

        store.delete_nodes_not_in_set.assert_called_once_with(
            {"doc-1", "doc-2", "doc-4"}
        )
|
|
1708
|
+
|
|
1709
|
+
|
|
1710
|
+
# Tests for workflow runner
|
|
1711
|
+
class TestRagWorkflowRunner:
    """Tests for RagWorkflowRunner: lock key, progress bookkeeping and run()."""

    @pytest.fixture
    def mock_step_runner(self):
        """Extraction step runner mock whose run() yields two progress updates."""
        step_runner = MagicMock(spec=RagExtractionStepRunner)
        step_runner.stage.return_value = RagWorkflowStepNames.EXTRACTING

        async def fake_progress_stream():
            for done in (1, 2):
                yield RagStepRunnerProgress(success_count=done, error_count=0)

        # A single generator object is installed, so it can be consumed once.
        step_runner.run.return_value = fake_progress_stream()
        return step_runner

    @pytest.fixture
    def workflow_config(
        self,
        mock_step_runner,
        real_rag_config,
        real_extractor_config,
        real_chunker_config,
        real_embedding_config,
    ):
        """Workflow configuration: one mocked step runner, zeroed progress."""
        starting_progress = RagProgress(
            total_document_count=10,
            total_document_extracted_count=0,
            total_document_chunked_count=0,
            total_document_embedded_count=0,
            total_document_completed_count=0,
            total_document_extracted_error_count=0,
            total_document_chunked_error_count=0,
            total_document_embedded_error_count=0,
            logs=[],
        )
        return RagWorkflowRunnerConfiguration(
            step_runners=[mock_step_runner],
            initial_progress=starting_progress,
            rag_config=real_rag_config,
            extractor_config=real_extractor_config,
            chunker_config=real_chunker_config,
            embedding_config=real_embedding_config,
        )

    @pytest.fixture
    def workflow_runner(self, mock_project, workflow_config):
        """RagWorkflowRunner built from the mocked project and configuration."""
        return RagWorkflowRunner(project=mock_project, configuration=workflow_config)

    def test_lock_key_generation(self, workflow_runner):
        """lock_key is "rag:run:" followed by the rag config ID."""
        rag_config = workflow_runner.configuration.rag_config
        assert workflow_runner.lock_key == f"rag:run:{rag_config.id}"

    def test_update_workflow_progress_extracting(self, workflow_runner):
        """EXTRACTING progress lands in the extracted counters."""
        updated = workflow_runner.update_workflow_progress(
            RagWorkflowStepNames.EXTRACTING,
            RagStepRunnerProgress(success_count=5, error_count=2),
        )

        assert updated.total_document_extracted_count == 5
        assert updated.total_document_extracted_error_count == 2

    def test_update_workflow_progress_chunking(self, workflow_runner):
        """CHUNKING progress lands in the chunked counters."""
        updated = workflow_runner.update_workflow_progress(
            RagWorkflowStepNames.CHUNKING,
            RagStepRunnerProgress(success_count=3, error_count=1),
        )

        assert updated.total_document_chunked_count == 3
        assert updated.total_document_chunked_error_count == 1

    def test_update_workflow_progress_embedding(self, workflow_runner):
        """EMBEDDING progress lands in the embedded counters."""
        updated = workflow_runner.update_workflow_progress(
            RagWorkflowStepNames.EMBEDDING,
            RagStepRunnerProgress(success_count=2, error_count=0),
        )

        assert updated.total_document_embedded_count == 2
        assert updated.total_document_embedded_error_count == 0

    def test_update_workflow_progress_indexing(self, workflow_runner):
        """INDEXING progress lands in the chunk-indexed counters."""
        updated = workflow_runner.update_workflow_progress(
            RagWorkflowStepNames.INDEXING,
            RagStepRunnerProgress(success_count=10, error_count=2),
        )

        # Indexing counts chunks rather than documents, so success_count is
        # added to the running total instead of being max-ed.
        assert updated.total_chunks_indexed_count == 10
        assert updated.total_chunks_indexed_error_count == 2

    def test_update_workflow_progress_indexing_accumulates_chunks(
        self, workflow_runner
    ):
        """Successive INDEXING updates add up their success counts."""
        after_first = workflow_runner.update_workflow_progress(
            RagWorkflowStepNames.INDEXING,
            RagStepRunnerProgress(success_count=5, error_count=0),
        )
        assert after_first.total_chunks_indexed_count == 5

        # The second batch is accumulated on top of the first.
        after_second = workflow_runner.update_workflow_progress(
            RagWorkflowStepNames.INDEXING,
            RagStepRunnerProgress(success_count=3, error_count=1),
        )
        assert after_second.total_chunks_indexed_count == 8  # 5 + 3
        # 0 errors in the first batch, 1 in the second.
        assert after_second.total_chunks_indexed_error_count == 1

    def test_update_workflow_progress_unknown_step_raises_error(self, workflow_runner):
        """A step name outside the enum is rejected with ValueError."""
        any_progress = RagStepRunnerProgress(success_count=1, error_count=0)

        with pytest.raises(ValueError, match="Unhandled enum value"):
            workflow_runner.update_workflow_progress("unknown_step", any_progress)

    def test_update_workflow_progress_calculates_completed_count(self, workflow_runner):
        """Completed counts are derived from the per-step counters."""
        # Give every step a different count before triggering an update.
        tracked = workflow_runner.current_progress
        tracked.total_document_extracted_count = 10
        tracked.total_document_chunked_count = 8
        tracked.total_document_embedded_count = 5
        tracked.total_chunks_indexed_count = 3

        updated = workflow_runner.update_workflow_progress(
            RagWorkflowStepNames.EXTRACTING,
            RagStepRunnerProgress(success_count=1, error_count=0),
        )

        # The document completed count is the minimum over the
        # document-related step counts.
        assert updated.total_document_completed_count == 5

        # Chunks are tracked separately so they can be compared against the
        # total chunk count to determine completion.
        assert updated.total_chunk_completed_count == 3

    @pytest.mark.asyncio
    async def test_run_yields_initial_progress_and_step_progress(self, workflow_runner):
        """run() yields the initial progress before anything else."""
        with patch("kiln_ai.utils.lock.shared_async_lock_manager"):
            updates = [update async for update in workflow_runner.run()]

            # Initial progress plus progress from the step runner.
            assert len(updates) >= 1
            assert updates[0] == workflow_runner.initial_progress

    @pytest.mark.asyncio
    async def test_run_with_stages_filter(self, workflow_runner):
        """Step runners whose stage is not requested are never run."""
        # Register a second step runner, this one for chunking.
        chunking_runner = MagicMock(spec=RagChunkingStepRunner)
        chunking_runner.stage.return_value = RagWorkflowStepNames.CHUNKING

        async def fake_chunking_stream():
            yield RagStepRunnerProgress(success_count=1, error_count=0)

        chunking_runner.run.return_value = fake_chunking_stream()
        workflow_runner.step_runners.append(chunking_runner)

        with patch("kiln_ai.utils.lock.shared_async_lock_manager"):
            # Request the extracting stage only.
            async for _ in workflow_runner.run(
                stages_to_run=[RagWorkflowStepNames.EXTRACTING]
            ):
                pass

            chunking_runner.run.assert_not_called()

    @pytest.mark.asyncio
    async def test_run_with_document_ids_passes_to_step_runners(self, workflow_runner):
        """document_ids given to run() are forwarded to the step runner."""
        received = {}

        async def recording_run(document_ids=None):
            # Remember what the workflow runner handed over.
            received["document_ids"] = document_ids
            yield RagStepRunnerProgress(success_count=1, error_count=0)

        workflow_runner.step_runners[0].run = recording_run

        with patch("kiln_ai.utils.lock.shared_async_lock_manager"):
            async for _ in workflow_runner.run(document_ids=["doc-1", "doc-2"]):
                pass

            assert received["document_ids"] == ["doc-1", "doc-2"]

    @pytest.mark.asyncio
    async def test_run_with_empty_document_ids_passes_empty_list_to_step_runners(
        self, workflow_runner
    ):
        """An empty document_ids list is forwarded to the step runner as-is."""
        received = {}

        async def recording_run(document_ids=None):
            # Remember what the workflow runner handed over.
            received["document_ids"] = document_ids
            yield RagStepRunnerProgress(success_count=1, error_count=0)

        workflow_runner.step_runners[0].run = recording_run

        with patch("kiln_ai.utils.lock.shared_async_lock_manager"):
            async for _ in workflow_runner.run(document_ids=[]):
                pass

            assert received["document_ids"] == []

    @pytest.mark.asyncio
    async def test_run_calls_count_total_chunks_for_indexing_step(
        self, workflow_runner
    ):
        """run() asks the indexing step for its chunk total and records it."""
        from kiln_ai.adapters.rag.rag_runners import RagIndexingStepRunner

        async def fake_indexing_stream():
            yield RagStepRunnerProgress(success_count=1, error_count=0)

        # Register an indexing step runner that reports 42 chunks in total.
        indexing_runner = MagicMock(spec=RagIndexingStepRunner)
        indexing_runner.stage.return_value = RagWorkflowStepNames.INDEXING
        indexing_runner.count_total_chunks = AsyncMock(return_value=42)
        indexing_runner.run.return_value = fake_indexing_stream()
        workflow_runner.step_runners.append(indexing_runner)

        with patch("kiln_ai.utils.lock.shared_async_lock_manager"):
            async for _ in workflow_runner.run():
                pass

            indexing_runner.count_total_chunks.assert_called_once()
            # The reported total ends up in the workflow progress.
            assert workflow_runner.current_progress.total_chunk_count == 42
|
|
1947
|
+
|
|
1948
|
+
|
|
1949
|
+
class TestRagWorkflowRunnerConfiguration:
    """Checks that RagWorkflowRunnerConfiguration exposes the values it was built with."""

    def test_configuration_creation(
        self,
        real_rag_config,
        real_extractor_config,
        real_chunker_config,
        real_embedding_config,
    ):
        """Each constructor argument can be read back from the configuration."""
        step_runner = MagicMock(spec=RagExtractionStepRunner)
        sub_configs = {
            "rag_config": real_rag_config,
            "extractor_config": real_extractor_config,
            "chunker_config": real_chunker_config,
            "embedding_config": real_embedding_config,
        }

        configuration = RagWorkflowRunnerConfiguration(
            step_runners=[step_runner],
            initial_progress=RagProgress(),
            **sub_configs,
        )

        assert configuration.step_runners == [step_runner]
        for field_name, expected in sub_configs.items():
            assert getattr(configuration, field_name) == expected
        assert isinstance(configuration.initial_progress, RagProgress)

    def test_configuration_with_initial_progress(
        self,
        real_rag_config,
        real_extractor_config,
        real_chunker_config,
        real_embedding_config,
    ):
        """A caller-supplied progress object is kept on the configuration."""
        starting_progress = RagProgress(
            total_document_count=5,
            total_document_extracted_count=1,
            total_document_chunked_count=0,
            total_document_embedded_count=0,
            total_document_completed_count=0,
            total_document_extracted_error_count=0,
            total_document_chunked_error_count=0,
            total_document_embedded_error_count=0,
            logs=[],
        )
        step_runner = MagicMock(spec=RagExtractionStepRunner)

        configuration = RagWorkflowRunnerConfiguration(
            step_runners=[step_runner],
            initial_progress=starting_progress,
            rag_config=real_rag_config,
            extractor_config=real_extractor_config,
            chunker_config=real_chunker_config,
            embedding_config=real_embedding_config,
        )

        assert configuration.initial_progress == starting_progress
|
|
2005
|
+
|
|
2006
|
+
|
|
2007
|
+
# Integration tests
|
|
2008
|
+
class TestRagWorkflowIntegration:
    """Integration-style tests exercising several RAG components together."""

    @pytest.mark.asyncio
    async def test_end_to_end_extraction_workflow(
        self, mock_project, mock_extractor_config
    ):
        """Run the extraction step over one un-extracted document with mocked adapters."""
        # A single plain-text document that has no extractions yet.
        document = MagicMock(spec=Document)
        document.path = Path("doc1.txt")
        document.original_file = MagicMock()
        document.original_file.attachment = MagicMock()
        document.original_file.attachment.resolve_path.return_value = "doc1_path"
        document.original_file.mime_type = "text/plain"
        document.extractions.return_value = []
        mock_project.documents.return_value = [document]

        extraction_step = RagExtractionStepRunner(
            project=mock_project, extractor_config=mock_extractor_config, concurrency=1
        )

        async def one_job_progress():
            yield MagicMock(complete=1)

        # Replace the extractor factory, the job runner and the shared lock manager.
        with (
            patch(
                "kiln_ai.adapters.rag.rag_runners.extractor_adapter_from_type"
            ) as adapter_factory,
            patch(
                "kiln_ai.adapters.rag.rag_runners.AsyncJobRunner"
            ) as job_runner_cls,
            patch("kiln_ai.utils.lock.shared_async_lock_manager"),
        ):
            # The factory hands back a mock extractor.
            adapter_factory.return_value = MagicMock(spec=BaseExtractor)

            # The job runner yields a single progress update.
            job_runner = MagicMock()
            job_runner.run.return_value = one_job_progress()
            job_runner_cls.return_value = job_runner

            # Drive the extraction step to completion.
            emitted = [update async for update in extraction_step.run()]

            # The extractor was built from the config and a job runner was created.
            adapter_factory.assert_called_once_with(
                mock_extractor_config.extractor_type,
                mock_extractor_config,
                None,
            )
            job_runner_cls.assert_called_once()
            assert len(emitted) > 0
|
|
2067
|
+
|
|
2068
|
+
|
|
2069
|
+
class TestRagStepRunnersWithTagFiltering:
|
|
2070
|
+
"""Test RAG step runners with document tag filtering"""
|
|
2071
|
+
|
|
2072
|
+
@pytest.mark.asyncio
async def test_extraction_runner_with_tag_filter(
    self, mock_project, mock_extractor_config
):
    """RagExtractionStepRunner only collects jobs for documents matching the tag filter."""
    # (document id, tags) pairs; the last document carries no tags at all.
    tagged_documents = [
        ("doc1", ["python", "ml"]),
        ("doc2", ["javascript", "web"]),
        ("doc3", ["python", "backend"]),
        ("doc4", None),
    ]

    documents = []
    for doc_id, doc_tags in tagged_documents:
        document = MagicMock(spec=Document)
        document.id = doc_id
        document.tags = doc_tags
        # None of the documents has been extracted yet.
        document.extractions.return_value = []
        documents.append(document)
    mock_project.documents.return_value = documents

    # RAG config restricted to the "python" tag.
    python_only_config = MagicMock(spec=RagConfig)
    python_only_config.tags = ["python"]

    extraction_step = RagExtractionStepRunner(
        mock_project,
        mock_extractor_config,
        concurrency=1,
        rag_config=python_only_config,
    )
    jobs = await extraction_step.collect_jobs()

    # Exactly doc1 and doc3 carry the "python" tag; doc2 (javascript) and
    # doc4 (untagged) must be left out.
    assert len(jobs) == 2
    selected_ids = {job.doc.id for job in jobs}
    assert selected_ids == {"doc1", "doc3"}
|
|
2117
|
+
|
|
2118
|
+
@pytest.mark.asyncio
async def test_chunking_runner_with_tag_filter(
    self, mock_project, mock_extractor_config, mock_chunker_config
):
    """Test RagChunkingStepRunner filters documents by tags.

    Three documents each own one un-chunked extraction produced by the same
    extractor config. Only doc1 and doc3 carry the "rust" tag, so only their
    extractions may be turned into chunking jobs.
    """
    # Create documents with extractions and different tags
    doc1 = MagicMock(spec=Document)
    doc1.id = "doc1"
    doc1.tags = ["rust", "systems"]
    extraction1 = MagicMock(spec=Extraction)
    extraction1.extractor_config_id = mock_extractor_config.id
    extraction1.created_at = "2024-01-01"
    extraction1.chunked_documents.return_value = []  # No chunks yet
    doc1.extractions.return_value = [extraction1]

    doc2 = MagicMock(spec=Document)
    doc2.id = "doc2"
    doc2.tags = ["python", "ml"]
    extraction2 = MagicMock(spec=Extraction)
    extraction2.extractor_config_id = mock_extractor_config.id
    extraction2.created_at = "2024-01-02"
    extraction2.chunked_documents.return_value = []  # No chunks yet
    doc2.extractions.return_value = [extraction2]

    doc3 = MagicMock(spec=Document)
    doc3.id = "doc3"
    doc3.tags = ["rust", "performance"]
    extraction3 = MagicMock(spec=Extraction)
    extraction3.extractor_config_id = mock_extractor_config.id
    extraction3.created_at = "2024-01-03"
    extraction3.chunked_documents.return_value = []  # No chunks yet
    doc3.extractions.return_value = [extraction3]

    mock_project.documents.return_value = [doc1, doc2, doc3]

    # Create RAG config that filters for "rust" tags
    rag_config = MagicMock(spec=RagConfig)
    rag_config.tags = ["rust"]

    runner = RagChunkingStepRunner(
        mock_project,
        mock_extractor_config,
        mock_chunker_config,
        concurrency=1,
        rag_config=rag_config,
    )

    jobs = await runner.collect_jobs()

    # Should only create jobs for doc1 and doc3 (have "rust" tag)
    assert len(jobs) == 2
    selected_extractions = [job.extraction for job in jobs]

    # All three extractions share the same extractor config id, so checking
    # that id alone cannot tell a correctly filtered result from a wrong one.
    # Assert on the extraction objects themselves to pin down which
    # documents were actually selected.
    assert extraction1 in selected_extractions
    assert extraction3 in selected_extractions
    assert extraction2 not in selected_extractions  # python-tagged document

    # Every selected extraction still belongs to the target extractor config.
    assert all(
        extraction.extractor_config_id == mock_extractor_config.id
        for extraction in selected_extractions
    )
|
|
2171
|
+
|
|
2172
|
+
@pytest.mark.asyncio
async def test_embedding_runner_with_tag_filter(
    self,
    mock_project,
    mock_extractor_config,
    mock_chunker_config,
    mock_embedding_config,
):
    """RagEmbeddingStepRunner only collects jobs for documents matching the tag filter."""

    def make_document(doc_id, tags, created_at):
        # Chunked document produced by the target chunker, not embedded yet.
        chunked = MagicMock(spec=ChunkedDocument)
        chunked.chunker_config_id = mock_chunker_config.id
        chunked.created_at = created_at
        chunked.chunk_embeddings.return_value = []

        # Extraction produced by the target extractor, owning that chunked document.
        extraction = MagicMock(spec=Extraction)
        extraction.extractor_config_id = mock_extractor_config.id
        extraction.created_at = created_at
        extraction.chunked_documents.return_value = [chunked]

        document = MagicMock(spec=Document)
        document.id = doc_id
        document.tags = tags
        document.extractions.return_value = [extraction]
        return document, chunked

    go_document, go_chunked = make_document("doc1", ["go", "backend"], "2024-01-01")
    python_document, _ = make_document("doc2", ["python", "web"], "2024-01-02")
    mock_project.documents.return_value = [go_document, python_document]

    # RAG config restricted to the "go" tag.
    go_only_config = MagicMock(spec=RagConfig)
    go_only_config.tags = ["go"]

    embedding_step = RagEmbeddingStepRunner(
        mock_project,
        mock_extractor_config,
        mock_chunker_config,
        mock_embedding_config,
        concurrency=1,
        rag_config=go_only_config,
    )
    jobs = await embedding_step.collect_jobs()

    # Only the "go"-tagged document yields an embedding job.
    assert len(jobs) == 1
    assert jobs[0].chunked_document == go_chunked
|
|
2233
|
+
|
|
2234
|
+
@pytest.mark.asyncio
async def test_indexing_runner_collect_records_with_tag_filter(
    self,
    mock_project,
    mock_extractor_config,
    mock_chunker_config,
    mock_embedding_config,
):
    """RagIndexingStepRunner only collects records for documents matching the tag filter."""

    def make_indexed_document(doc_id, tags, created_at):
        # Build the whole chain: embedding -> chunked document -> extraction -> document.
        embedding = MagicMock()
        embedding.embedding_config_id = mock_embedding_config.id
        embedding.created_at = created_at

        chunked = MagicMock(spec=ChunkedDocument)
        chunked.chunker_config_id = mock_chunker_config.id
        chunked.created_at = created_at
        chunked.chunk_embeddings.return_value = [embedding]

        extraction = MagicMock(spec=Extraction)
        extraction.extractor_config_id = mock_extractor_config.id
        extraction.created_at = created_at
        extraction.chunked_documents.return_value = [chunked]

        document = MagicMock(spec=Document)
        document.id = doc_id
        document.tags = tags
        document.extractions.return_value = [extraction]
        return document

    mock_project.documents.return_value = [
        make_indexed_document("doc1", ["typescript", "frontend"], "2024-01-01"),
        make_indexed_document("doc2", ["java", "enterprise"], "2024-01-02"),
    ]

    # RAG config restricted to the "typescript" tag.
    typescript_only_config = MagicMock(spec=RagConfig)
    typescript_only_config.tags = ["typescript"]

    # Minimal stand-in for the vector store config.
    vector_store_config = MagicMock()
    vector_store_config.id = "vector-store-123"

    indexing_step = RagIndexingStepRunner(
        mock_project,
        mock_extractor_config,
        mock_chunker_config,
        mock_embedding_config,
        vector_store_config,
        typescript_only_config,
    )

    # Flatten every yielded batch into one list of records.
    records = [
        record
        async for batch in indexing_step.collect_records(batch_size=10)
        for record in batch
    ]

    # Only the "typescript"-tagged document contributes a record.
    assert len(records) == 1
    assert records[0].document_id == "doc1"
|
|
2309
|
+
|
|
2310
|
+
@pytest.mark.asyncio
async def test_step_runners_with_no_tag_filter(
    self, mock_project, mock_extractor_config
):
    """With ``rag_config.tags`` set to None, every document gets an extraction job."""
    documents = []
    for doc_id, doc_tags in (
        ("doc1", ["python", "ml"]),
        ("doc2", ["javascript", "web"]),
    ):
        document = MagicMock(spec=Document)
        document.id = doc_id
        document.tags = doc_tags
        document.extractions.return_value = []
        documents.append(document)
    mock_project.documents.return_value = documents

    # RAG config without any tag filter.
    unfiltered_config = MagicMock(spec=RagConfig)
    unfiltered_config.tags = None

    extraction_step = RagExtractionStepRunner(
        mock_project,
        mock_extractor_config,
        concurrency=1,
        rag_config=unfiltered_config,
    )
    jobs = await extraction_step.collect_jobs()

    # Both documents are picked up regardless of their tags.
    assert len(jobs) == 2
    selected_ids = {job.doc.id for job in jobs}
    assert selected_ids == {"doc1", "doc2"}
|
|
2343
|
+
|
|
2344
|
+
@pytest.mark.asyncio
async def test_step_runners_with_empty_tag_filter(
    self, mock_project, mock_extractor_config
):
    """With ``rag_config.tags`` set to an empty list, every document gets an extraction job."""
    documents = []
    for doc_id, doc_tags in (
        ("doc1", ["python", "ml"]),
        ("doc2", ["javascript", "web"]),
    ):
        document = MagicMock(spec=Document)
        document.id = doc_id
        document.tags = doc_tags
        document.extractions.return_value = []
        documents.append(document)
    mock_project.documents.return_value = documents

    # RAG config whose tag filter is an empty list.
    empty_filter_config = MagicMock(spec=RagConfig)
    empty_filter_config.tags = []

    extraction_step = RagExtractionStepRunner(
        mock_project,
        mock_extractor_config,
        concurrency=1,
        rag_config=empty_filter_config,
    )
    jobs = await extraction_step.collect_jobs()

    # Both documents are picked up regardless of their tags.
    assert len(jobs) == 2
    selected_ids = {job.doc.id for job in jobs}
    assert selected_ids == {"doc1", "doc2"}
|