kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/rag/rag_runners.py (new file)

@@ -0,0 +1,844 @@
+import logging
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import AsyncGenerator, Generic, Set, Tuple, TypeVar
+
+from kiln_ai.adapters.chunkers.base_chunker import BaseChunker
+from kiln_ai.adapters.chunkers.chunker_registry import chunker_adapter_from_type
+from kiln_ai.adapters.embedding.base_embedding_adapter import BaseEmbeddingAdapter
+from kiln_ai.adapters.embedding.embedding_registry import embedding_adapter_from_type
+from kiln_ai.adapters.extractors.base_extractor import BaseExtractor, ExtractionInput
+from kiln_ai.adapters.extractors.extractor_registry import extractor_adapter_from_type
+from kiln_ai.adapters.rag.deduplication import (
+    deduplicate_chunk_embeddings,
+    deduplicate_chunked_documents,
+    deduplicate_extractions,
+    filter_documents_by_tags,
+)
+from kiln_ai.adapters.rag.progress import LogMessage, RagProgress
+from kiln_ai.adapters.vector_store.base_vector_store_adapter import (
+    DocumentWithChunksAndEmbeddings,
+)
+from kiln_ai.adapters.vector_store.vector_store_registry import (
+    vector_store_adapter_for_config,
+)
+from kiln_ai.datamodel import Project
+from kiln_ai.datamodel.basemodel import ID_TYPE, KilnAttachmentModel
+from kiln_ai.datamodel.chunk import Chunk, ChunkedDocument, ChunkerConfig
+from kiln_ai.datamodel.embedding import ChunkEmbeddings, Embedding, EmbeddingConfig
+from kiln_ai.datamodel.extraction import (
+    Document,
+    Extraction,
+    ExtractionSource,
+    ExtractorConfig,
+)
+from kiln_ai.datamodel.rag import RagConfig
+from kiln_ai.datamodel.vector_store import VectorStoreConfig
+from kiln_ai.utils.async_job_runner import AsyncJobRunner, AsyncJobRunnerObserver
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+from kiln_ai.utils.filesystem_cache import FilesystemCache
+from kiln_ai.utils.lock import shared_async_lock_manager
+from pydantic import BaseModel, ConfigDict, Field
+
+# We set the timeout high because the current UX makes it likely that a user triggers
+# multiple RAG workflows whose subconfigs (e.g. the same extractor) are shared and take
+# a long time to complete, causing whichever runs are waiting on the lock to time out
+# before they even get a chance to start.
+LOCK_TIMEOUT_SECONDS = 60 * 60  # 1 hour
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ExtractorJob:
+    doc: Document
+    extractor_config: ExtractorConfig
+
+
+@dataclass
+class ChunkerJob:
+    extraction: Extraction
+    chunker_config: ChunkerConfig
+
+
+@dataclass
+class EmbeddingJob:
+    chunked_document: ChunkedDocument
+    embedding_config: EmbeddingConfig
+
+
+class RagStepRunnerProgress(BaseModel):
+    success_count: int | None = Field(
+        description="The number of items that have been processed",
+        default=None,
+    )
+    error_count: int | None = Field(
+        description="The number of items that have errored",
+        default=None,
+    )
+    logs: list[LogMessage] = Field(
+        description="A list of log messages to display to the user",
+        default_factory=list,
+    )
+
+
+T = TypeVar("T")
+
+
+class GenericErrorCollector(AsyncJobRunnerObserver[T], Generic[T]):
+    def __init__(
+        self,
+    ):
+        self.errors: list[Tuple[T, Exception]] = []
+
+    async def on_success(self, job: T):
+        pass
+
+    async def on_error(self, job: T, error: Exception):
+        self.errors.append((job, error))
+
+    def get_errors(
+        self,
+        start_idx: int = 0,
+    ) -> tuple[list[Tuple[T, Exception]], int]:
+        """Returns a tuple of: ((job, error), index of the last error)"""
+        if start_idx < 0:
+            raise ValueError("start_idx must be non-negative")
+        if start_idx >= len(self.errors):
+            return [], start_idx
+        if start_idx > 0:
+            return self.errors[start_idx : len(self.errors)], len(self.errors)
+        return self.errors, len(self.errors)
+
+    def get_error_count(self) -> int:
+        return len(self.errors)
+
+
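The collector above pairs with AsyncJobRunner's observer hooks so that failures can be drained incrementally using the index returned by get_errors. A minimal sketch of that flow, using only the AsyncJobRunner arguments this file itself uses (jobs, run_job_fn, concurrency, observers); the integer jobs and the failing job function are invented for illustration:

async def drain_errors_demo() -> None:
    observer: GenericErrorCollector[int] = GenericErrorCollector()

    async def run_job(job: int) -> bool:
        # Hypothetical job: odd inputs fail, even inputs succeed.
        if job % 2 == 1:
            raise ValueError(f"job {job} failed")
        return True

    runner = AsyncJobRunner(
        jobs=[0, 1, 2, 3],
        run_job_fn=run_job,
        concurrency=2,
        observers=[observer],
    )

    error_idx = 0
    async for _ in runner.run():
        # get_errors returns the errors recorded at or after error_idx plus the
        # new index, so each error is reported exactly once.
        new_errors, error_idx = observer.get_errors(error_idx)
        for job, error in new_errors:
            print(f"job {job} errored: {error}")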
+class RagWorkflowStepNames(str, Enum):
+    EXTRACTING = "extracting"
+    CHUNKING = "chunking"
+    EMBEDDING = "embedding"
+    INDEXING = "indexing"
+
+
+async def execute_extractor_job(job: ExtractorJob, extractor: BaseExtractor) -> bool:
+    if job.doc.path is None:
+        raise ValueError("Document path is not set")
+
+    output = await extractor.extract(
+        extraction_input=ExtractionInput(
+            path=job.doc.original_file.attachment.resolve_path(job.doc.path.parent),
+            mime_type=job.doc.original_file.mime_type,
+        )
+    )
+
+    extraction = Extraction(
+        parent=job.doc,
+        extractor_config_id=job.extractor_config.id,
+        output=KilnAttachmentModel.from_data(
+            data=output.content,
+            mime_type=output.content_format,
+        ),
+        source=ExtractionSource.PASSTHROUGH
+        if output.is_passthrough
+        else ExtractionSource.PROCESSED,
+    )
+    extraction.save_to_file()
+
+    return True
+
+
+async def execute_chunker_job(job: ChunkerJob, chunker: BaseChunker) -> bool:
+    extraction_output_content = await job.extraction.output_content()
+    if extraction_output_content is None:
+        raise ValueError("Extraction output content is not set")
+
+    chunking_result = await chunker.chunk(
+        extraction_output_content,
+    )
+    if chunking_result is None:
+        raise ValueError("Chunking result is not set")
+
+    chunked_document = ChunkedDocument(
+        parent=job.extraction,
+        chunker_config_id=job.chunker_config.id,
+        chunks=[
+            Chunk(
+                content=KilnAttachmentModel.from_data(
+                    data=chunk.text,
+                    mime_type="text/plain",
+                ),
+            )
+            for chunk in chunking_result.chunks
+        ],
+    )
+    chunked_document.save_to_file()
+    return True
+
+
+async def execute_embedding_job(
+    job: EmbeddingJob, embedding_adapter: BaseEmbeddingAdapter
+) -> bool:
+    chunks_text = await job.chunked_document.load_chunks_text()
+    if chunks_text is None or len(chunks_text) == 0:
+        raise ValueError(
+            f"Failed to load chunks for chunked document: {job.chunked_document.id}"
+        )
+
+    chunk_embedding_result = await embedding_adapter.generate_embeddings(
+        input_texts=chunks_text
+    )
+    if chunk_embedding_result is None:
+        raise ValueError(
+            f"Failed to generate embeddings for chunked document: {job.chunked_document.id}"
+        )
+
+    chunk_embeddings = ChunkEmbeddings(
+        parent=job.chunked_document,
+        embedding_config_id=job.embedding_config.id,
+        embeddings=[
+            Embedding(
+                vector=embedding.vector,
+            )
+            for embedding in chunk_embedding_result.embeddings
+        ],
+    )
+
+    chunk_embeddings.save_to_file()
+    return True
+
+
+class AbstractRagStepRunner(ABC):
+    @abstractmethod
+    def stage(self) -> RagWorkflowStepNames:
+        pass
+
+    # The async keyword on the abstract prototype causes a type error in pyright,
+    # so we omit it here; concrete implementations should declare async.
+    @abstractmethod
+    def run(
+        self, document_ids: list[ID_TYPE] | None = None
+    ) -> AsyncGenerator[RagStepRunnerProgress, None]:
+        pass
+
+
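The comment in AbstractRagStepRunner.run is worth a standalone illustration: if the abstract prototype were declared async def (with no yield in its body), pyright would type it as returning Coroutine[..., AsyncGenerator[...]], which a concrete async generator override would not match. A self-contained sketch of the pattern, independent of this file's types:

from abc import ABC, abstractmethod
from typing import AsyncGenerator


class Step(ABC):
    # Declared without `async`: the return annotation alone describes the
    # async generator, so overrides type-check cleanly.
    @abstractmethod
    def run(self) -> AsyncGenerator[int, None]:
        pass


class ConcreteStep(Step):
    # An async generator function has type () -> AsyncGenerator[int, None],
    # which matches the abstract prototype above.
    async def run(self) -> AsyncGenerator[int, None]:
        yield 1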
+class RagExtractionStepRunner(AbstractRagStepRunner):
+    def __init__(
+        self,
+        project: Project,
+        extractor_config: ExtractorConfig,
+        concurrency: int = 10,
+        rag_config: RagConfig | None = None,
+        filesystem_cache: FilesystemCache | None = None,
+    ):
+        self.project = project
+        self.extractor_config = extractor_config
+        self.lock_key = f"docs:extract:{self.extractor_config.id}"
+        self.concurrency = concurrency
+        self.rag_config = rag_config
+        self.filesystem_cache = filesystem_cache
+
+    def stage(self) -> RagWorkflowStepNames:
+        return RagWorkflowStepNames.EXTRACTING
+
+    def has_extraction(self, document: Document, extractor_id: ID_TYPE) -> bool:
+        for ex in document.extractions(readonly=True):
+            if ex.extractor_config_id == extractor_id:
+                return True
+        return False
+
+    async def collect_jobs(
+        self, document_ids: list[ID_TYPE] | None = None
+    ) -> list[ExtractorJob]:
+        jobs: list[ExtractorJob] = []
+        target_extractor_config_id = self.extractor_config.id
+
+        documents = self.project.documents(readonly=True)
+        if self.rag_config and self.rag_config.tags:
+            documents = filter_documents_by_tags(documents, self.rag_config.tags)
+
+        for document in documents:
+            if (
+                document_ids is not None
+                and len(document_ids) > 0
+                and document.id not in document_ids
+            ):
+                continue
+            if not self.has_extraction(document, target_extractor_config_id):
+                jobs.append(
+                    ExtractorJob(
+                        doc=document,
+                        extractor_config=self.extractor_config,
+                    )
+                )
+        return jobs
+
+    async def run(
+        self, document_ids: list[ID_TYPE] | None = None
+    ) -> AsyncGenerator[RagStepRunnerProgress, None]:
+        async with shared_async_lock_manager.acquire(
+            self.lock_key, timeout=LOCK_TIMEOUT_SECONDS
+        ):
+            jobs = await self.collect_jobs(document_ids=document_ids)
+            extractor = extractor_adapter_from_type(
+                self.extractor_config.extractor_type,
+                self.extractor_config,
+                self.filesystem_cache,
+            )
+
+            observer = GenericErrorCollector()
+            runner = AsyncJobRunner(
+                jobs=jobs,
+                run_job_fn=lambda job: execute_extractor_job(job, extractor),
+                concurrency=self.concurrency,
+                observers=[observer],
+            )
+
+            error_idx = 0
+            async for progress in runner.run():
+                yield RagStepRunnerProgress(
+                    success_count=progress.complete,
+                    error_count=observer.get_error_count(),
+                )
+
+            # Errors accumulate in the observer, so flush them to the caller.
+            if observer.get_error_count() > 0:
+                errors, error_idx = observer.get_errors(error_idx)
+                for job, error in errors:
+                    yield RagStepRunnerProgress(
+                        logs=[
+                            LogMessage(
+                                level="error",
+                                message=f"Error extracting document: {job.doc.path}: {error}",
+                            )
+                        ],
+                    )
+
+
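Each step runner can also be driven on its own, outside RagWorkflowRunner. A sketch of consuming the extraction step directly; it assumes a Project and a saved ExtractorConfig already exist (the arguments here are hypothetical), and uses only the RagStepRunnerProgress fields defined above plus the LogMessage fields as constructed in this file:

async def run_extraction_only(
    project: Project, extractor_config: ExtractorConfig
) -> None:
    step = RagExtractionStepRunner(project, extractor_config, concurrency=4)
    async for progress in step.run():
        if progress.success_count is not None:
            print(f"extracted={progress.success_count} errors={progress.error_count}")
        for log in progress.logs:
            print(f"[{log.level}] {log.message}")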
+class RagChunkingStepRunner(AbstractRagStepRunner):
+    def __init__(
+        self,
+        project: Project,
+        extractor_config: ExtractorConfig,
+        chunker_config: ChunkerConfig,
+        concurrency: int = 10,
+        rag_config: RagConfig | None = None,
+    ):
+        self.project = project
+        self.extractor_config = extractor_config
+        self.chunker_config = chunker_config
+        self.lock_key = f"docs:chunk:{self.chunker_config.id}"
+        self.concurrency = concurrency
+        self.rag_config = rag_config
+
+    def stage(self) -> RagWorkflowStepNames:
+        return RagWorkflowStepNames.CHUNKING
+
+    def has_chunks(self, extraction: Extraction, chunker_id: ID_TYPE) -> bool:
+        for cd in extraction.chunked_documents(readonly=True):
+            if cd.chunker_config_id == chunker_id:
+                return True
+        return False
+
+    async def collect_jobs(
+        self, document_ids: list[ID_TYPE] | None = None
+    ) -> list[ChunkerJob]:
+        target_extractor_config_id = self.extractor_config.id
+        target_chunker_config_id = self.chunker_config.id
+
+        jobs: list[ChunkerJob] = []
+        documents = self.project.documents(readonly=True)
+        if self.rag_config and self.rag_config.tags:
+            documents = filter_documents_by_tags(documents, self.rag_config.tags)
+
+        for document in documents:
+            if (
+                document_ids is not None
+                and len(document_ids) > 0
+                and document.id not in document_ids
+            ):
+                continue
+            for extraction in deduplicate_extractions(
+                document.extractions(readonly=True)
+            ):
+                if extraction.extractor_config_id == target_extractor_config_id:
+                    if not self.has_chunks(extraction, target_chunker_config_id):
+                        jobs.append(
+                            ChunkerJob(
+                                extraction=extraction,
+                                chunker_config=self.chunker_config,
+                            )
+                        )
+        return jobs
+
+    async def run(
+        self, document_ids: list[ID_TYPE] | None = None
+    ) -> AsyncGenerator[RagStepRunnerProgress, None]:
+        async with shared_async_lock_manager.acquire(
+            self.lock_key, timeout=LOCK_TIMEOUT_SECONDS
+        ):
+            jobs = await self.collect_jobs(document_ids=document_ids)
+            chunker = chunker_adapter_from_type(
+                self.chunker_config.chunker_type,
+                self.chunker_config,
+            )
+            observer = GenericErrorCollector()
+            runner = AsyncJobRunner(
+                jobs=jobs,
+                run_job_fn=lambda job: execute_chunker_job(job, chunker),
+                concurrency=self.concurrency,
+                observers=[observer],
+            )
+
+            error_idx = 0
+            async for progress in runner.run():
+                yield RagStepRunnerProgress(
+                    success_count=progress.complete,
+                    error_count=observer.get_error_count(),
+                )
+
+            # Errors accumulate in the observer, so flush them to the caller.
+            if observer.get_error_count() > 0:
+                errors, error_idx = observer.get_errors(error_idx)
+                for job, error in errors:
+                    yield RagStepRunnerProgress(
+                        logs=[
+                            LogMessage(
+                                level="error",
+                                message=f"Error chunking document: {job.extraction.path}: {error}",
+                            )
+                        ],
+                    )
+
+
+class RagEmbeddingStepRunner(AbstractRagStepRunner):
+    def __init__(
+        self,
+        project: Project,
+        extractor_config: ExtractorConfig,
+        chunker_config: ChunkerConfig,
+        embedding_config: EmbeddingConfig,
+        concurrency: int = 10,
+        rag_config: RagConfig | None = None,
+    ):
+        self.project = project
+        self.extractor_config = extractor_config
+        self.chunker_config = chunker_config
+        self.embedding_config = embedding_config
+        self.concurrency = concurrency
+        self.rag_config = rag_config
+        self.lock_key = f"docs:embedding:{self.embedding_config.id}"
+
+    def stage(self) -> RagWorkflowStepNames:
+        return RagWorkflowStepNames.EMBEDDING
+
+    def has_embeddings(self, chunked: ChunkedDocument, embedding_id: ID_TYPE) -> bool:
+        for emb in chunked.chunk_embeddings(readonly=True):
+            if emb.embedding_config_id == embedding_id:
+                return True
+        return False
+
+    async def collect_jobs(
+        self, document_ids: list[ID_TYPE] | None = None
+    ) -> list[EmbeddingJob]:
+        target_extractor_config_id = self.extractor_config.id
+        target_chunker_config_id = self.chunker_config.id
+        target_embedding_config_id = self.embedding_config.id
+
+        jobs: list[EmbeddingJob] = []
+        documents = self.project.documents(readonly=True)
+        if self.rag_config and self.rag_config.tags:
+            documents = filter_documents_by_tags(documents, self.rag_config.tags)
+
+        for document in documents:
+            if (
+                document_ids is not None
+                and len(document_ids) > 0
+                and document.id not in document_ids
+            ):
+                continue
+            for extraction in deduplicate_extractions(
+                document.extractions(readonly=True)
+            ):
+                if extraction.extractor_config_id == target_extractor_config_id:
+                    for chunked_document in deduplicate_chunked_documents(
+                        extraction.chunked_documents(readonly=True)
+                    ):
+                        if (
+                            chunked_document.chunker_config_id
+                            == target_chunker_config_id
+                        ):
+                            if not self.has_embeddings(
+                                chunked_document, target_embedding_config_id
+                            ):
+                                jobs.append(
+                                    EmbeddingJob(
+                                        chunked_document=chunked_document,
+                                        embedding_config=self.embedding_config,
+                                    )
+                                )
+        return jobs
+
+    async def run(
+        self, document_ids: list[ID_TYPE] | None = None
+    ) -> AsyncGenerator[RagStepRunnerProgress, None]:
+        async with shared_async_lock_manager.acquire(
+            self.lock_key, timeout=LOCK_TIMEOUT_SECONDS
+        ):
+            jobs = await self.collect_jobs(document_ids=document_ids)
+            embedding_adapter = embedding_adapter_from_type(
+                self.embedding_config,
+            )
+
+            observer = GenericErrorCollector()
+            runner = AsyncJobRunner(
+                jobs=jobs,
+                run_job_fn=lambda job: execute_embedding_job(job, embedding_adapter),
+                concurrency=self.concurrency,
+                observers=[observer],
+            )
+
+            error_idx = 0
+            async for progress in runner.run():
+                yield RagStepRunnerProgress(
+                    success_count=progress.complete,
+                    error_count=observer.get_error_count(),
+                )
+
+            # Errors accumulate in the observer, so flush them to the caller.
+            if observer.get_error_count() > 0:
+                errors, error_idx = observer.get_errors(error_idx)
+                for job, error in errors:
+                    yield RagStepRunnerProgress(
+                        logs=[
+                            LogMessage(
+                                level="error",
+                                message=f"Error embedding document: {job.chunked_document.path}: {error}",
+                            )
+                        ],
+                    )
+
+
+class RagIndexingStepRunner(AbstractRagStepRunner):
+    def __init__(
+        self,
+        project: Project,
+        extractor_config: ExtractorConfig,
+        chunker_config: ChunkerConfig,
+        embedding_config: EmbeddingConfig,
+        vector_store_config: VectorStoreConfig,
+        rag_config: RagConfig,
+        concurrency: int = 10,
+        batch_size: int = 20,
+    ):
+        self.project = project
+        self.extractor_config = extractor_config
+        self.chunker_config = chunker_config
+        self.embedding_config = embedding_config
+        self.vector_store_config = vector_store_config
+        self.rag_config = rag_config
+        self.concurrency = concurrency
+        self.batch_size = batch_size
+
+    @property
+    def lock_key(self) -> str:
+        return f"rag:index:{self.vector_store_config.id}"
+
+    def stage(self) -> RagWorkflowStepNames:
+        return RagWorkflowStepNames.INDEXING
+
+    async def collect_records(
+        self,
+        batch_size: int,
+        document_ids: list[ID_TYPE] | None = None,
+    ) -> AsyncGenerator[list[DocumentWithChunksAndEmbeddings], None]:
+        target_extractor_config_id = self.extractor_config.id
+        target_chunker_config_id = self.chunker_config.id
+        target_embedding_config_id = self.embedding_config.id
+
+        # (document_id, chunked_document, embedding)
+        jobs: list[DocumentWithChunksAndEmbeddings] = []
+        documents = self.project.documents(readonly=True)
+        if self.rag_config and self.rag_config.tags:
+            documents = filter_documents_by_tags(documents, self.rag_config.tags)
+
+        for document in documents:
+            if (
+                document_ids is not None
+                and len(document_ids) > 0
+                and document.id not in document_ids
+            ):
+                continue
+            for extraction in deduplicate_extractions(
+                document.extractions(readonly=True)
+            ):
+                if extraction.extractor_config_id == target_extractor_config_id:
+                    for chunked_document in deduplicate_chunked_documents(
+                        extraction.chunked_documents(readonly=True)
+                    ):
+                        if (
+                            chunked_document.chunker_config_id
+                            == target_chunker_config_id
+                        ):
+                            for chunk_embedding in deduplicate_chunk_embeddings(
+                                chunked_document.chunk_embeddings(readonly=True)
+                            ):
+                                if (
+                                    chunk_embedding.embedding_config_id
+                                    == target_embedding_config_id
+                                ):
+                                    jobs.append(
+                                        DocumentWithChunksAndEmbeddings(
+                                            document_id=str(document.id),
+                                            chunked_document=chunked_document,
+                                            chunk_embeddings=chunk_embedding,
+                                        )
+                                    )
+
+                                    if len(jobs) >= batch_size:
+                                        yield jobs
+                                        jobs.clear()
+
+        if len(jobs) > 0:
+            yield jobs
+            jobs.clear()
+
+    async def count_total_chunks(self) -> int:
+        total_chunk_count = 0
+        async for documents in self.collect_records(batch_size=1):
+            total_chunk_count += len(documents[0].chunks)
+        return total_chunk_count
+
+    def get_all_target_document_ids(self) -> Set[str]:
+        documents = self.project.documents(readonly=True)
+        if self.rag_config and self.rag_config.tags:
+            documents = filter_documents_by_tags(documents, self.rag_config.tags)
+        return {str(document.id) for document in documents}
+
+    async def run(
+        self, document_ids: list[ID_TYPE] | None = None
+    ) -> AsyncGenerator[RagStepRunnerProgress, None]:
+        async with shared_async_lock_manager.acquire(
+            self.lock_key, timeout=LOCK_TIMEOUT_SECONDS
+        ):
+            vector_dimensions: int | None = None
+
+            # Infer dimensionality: we peek at the first record to get the vector dimensions.
+            # They are not stored in the config because they are derived from the model and,
+            # in some cases, from dynamic shortening of the vector (Matryoshka Representation Learning).
+            async for doc_batch in self.collect_records(
+                batch_size=1,
+            ):
+                if len(doc_batch) == 0:
+                    # There are no records; the upstream steps may not have produced anything yet.
+                    return
+                else:
+                    doc = doc_batch[0]
+                    embedding = doc.embeddings[0]
+                    vector_dimensions = len(embedding.vector)
+                    break
+
+            if vector_dimensions is None:
+                raise ValueError("Vector dimensions are not set")
+
+            vector_store = await vector_store_adapter_for_config(
+                self.rag_config,
+                self.vector_store_config,
+            )
+
+            yield RagStepRunnerProgress(
+                success_count=0,
+                error_count=0,
+            )
+
+            async for doc_batch in self.collect_records(
+                batch_size=self.batch_size, document_ids=document_ids
+            ):
+                batch_chunk_count = 0
+                for doc in doc_batch:
+                    batch_chunk_count += len(doc.chunks)
+
+                try:
+                    await vector_store.add_chunks_with_embeddings(doc_batch)
+                    yield RagStepRunnerProgress(
+                        success_count=batch_chunk_count,
+                        error_count=0,
+                    )
+                except Exception as e:
+                    error_msg = f"Error indexing document batch starting with {doc_batch[0].document_id}: {e}"
+                    logger.error(error_msg, exc_info=True)
+                    yield RagStepRunnerProgress(
+                        success_count=0,
+                        error_count=batch_chunk_count,
+                        logs=[
+                            LogMessage(
+                                level="error",
+                                message=error_msg,
+                            ),
+                        ],
+                    )
+
+            # Needed to reconcile and delete any chunks that are currently indexed but
+            # are no longer in our target set (because they were deleted or untagged).
+            await vector_store.delete_nodes_not_in_set(
+                self.get_all_target_document_ids()
+            )
+
+
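collect_records above batches by document: it yields a full batch as soon as batch_size is reached and flushes any remainder at the end. The same pattern in isolation (one caveat worth noting: the runner yields its shared list and clears it after resuming, so a consumer that keeps a batch across iterations should copy it; this sketch copies defensively):

from typing import AsyncGenerator


async def batched(
    items: list[str], batch_size: int
) -> AsyncGenerator[list[str], None]:
    batch: list[str] = []
    for item in items:
        batch.append(item)
        if len(batch) >= batch_size:
            yield list(batch)  # copy, so the caller can safely hold a reference
            batch.clear()
    if batch:
        yield list(batch)  # flush the final partial batch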
+class RagWorkflowRunnerConfiguration(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    step_runners: list[AbstractRagStepRunner] = Field(
+        description="The step runners to run",
+    )
+
+    initial_progress: RagProgress = Field(
+        description="Initial progress state provided by the caller - progress will build on top of this",
+    )
+
+    rag_config: RagConfig = Field(
+        description="The rag config to use for the workflow",
+    )
+
+    extractor_config: ExtractorConfig = Field(
+        description="The extractor config to use for the workflow",
+    )
+
+    chunker_config: ChunkerConfig = Field(
+        description="The chunker config to use for the workflow",
+    )
+
+    embedding_config: EmbeddingConfig = Field(
+        description="The embedding config to use for the workflow",
+    )
+
+
+class RagWorkflowRunner:
+    def __init__(
+        self,
+        project: Project,
+        configuration: RagWorkflowRunnerConfiguration,
+    ):
+        self.project = project
+        self.configuration = configuration
+        self.step_runners: list[AbstractRagStepRunner] = configuration.step_runners
+        self.initial_progress = self.configuration.initial_progress
+        self.current_progress = self.initial_progress.model_copy()
+
+    @property
+    def lock_key(self) -> str:
+        return f"rag:run:{self.configuration.rag_config.id}"
+
+    def update_workflow_progress(
+        self, step_name: RagWorkflowStepNames, step_progress: RagStepRunnerProgress
+    ) -> RagProgress:
+        # Merge the simpler step-specific progress with the broader RAG progress.
+        match step_name:
+            case RagWorkflowStepNames.EXTRACTING:
+                if step_progress.success_count is not None:
+                    self.current_progress.total_document_extracted_count = max(
+                        self.current_progress.total_document_extracted_count,
+                        step_progress.success_count
+                        + self.initial_progress.total_document_extracted_count,
+                    )
+                if step_progress.error_count is not None:
+                    self.current_progress.total_document_extracted_error_count = max(
+                        self.current_progress.total_document_extracted_error_count,
+                        step_progress.error_count
+                        + self.initial_progress.total_document_extracted_error_count,
+                    )
+            case RagWorkflowStepNames.CHUNKING:
+                if step_progress.success_count is not None:
+                    self.current_progress.total_document_chunked_count = max(
+                        self.current_progress.total_document_chunked_count,
+                        step_progress.success_count
+                        + self.initial_progress.total_document_chunked_count,
+                    )
+                if step_progress.error_count is not None:
+                    self.current_progress.total_document_chunked_error_count = max(
+                        self.current_progress.total_document_chunked_error_count,
+                        step_progress.error_count
+                        + self.initial_progress.total_document_chunked_error_count,
+                    )
+            case RagWorkflowStepNames.EMBEDDING:
+                if step_progress.success_count is not None:
+                    self.current_progress.total_document_embedded_count = max(
+                        self.current_progress.total_document_embedded_count,
+                        step_progress.success_count
+                        + self.initial_progress.total_document_embedded_count,
+                    )
+                if step_progress.error_count is not None:
+                    self.current_progress.total_document_embedded_error_count = max(
+                        self.current_progress.total_document_embedded_error_count,
+                        step_progress.error_count
+                        + self.initial_progress.total_document_embedded_error_count,
+                    )
+            case RagWorkflowStepNames.INDEXING:
+                if step_progress.success_count is not None:
+                    self.current_progress.total_chunks_indexed_count += (
+                        step_progress.success_count
+                    )
+                if step_progress.error_count is not None:
+                    self.current_progress.total_chunks_indexed_error_count += (
+                        step_progress.error_count
+                    )
+            case _:
+                raise_exhaustive_enum_error(step_name)
+
+        self.current_progress.total_document_completed_count = min(
+            self.current_progress.total_document_extracted_count,
+            self.current_progress.total_document_chunked_count,
+            self.current_progress.total_document_embedded_count,
+        )
+
+        self.current_progress.total_chunk_completed_count = (
+            self.current_progress.total_chunks_indexed_count
+        )
+
+        self.current_progress.logs = step_progress.logs
+        return self.current_progress
+
+    async def run(
+        self,
+        stages_to_run: list[RagWorkflowStepNames] | None = None,
+        document_ids: list[ID_TYPE] | None = None,
+    ) -> AsyncGenerator[RagProgress, None]:
+        """
+        Runs the RAG workflow for the given stages and document ids.
+
+        :param stages_to_run: The stages to run. If None, all stages will be run.
+        :param document_ids: The document ids to run the workflow for. If None, all documents will be run.
+        """
+        yield self.initial_progress
+
+        async with shared_async_lock_manager.acquire(
+            self.lock_key, timeout=LOCK_TIMEOUT_SECONDS
+        ):
+            for step in self.step_runners:
+                if stages_to_run is not None and step.stage() not in stages_to_run:
+                    continue
+
+                # We need to know the total number of chunks to index to be able to
+                # calculate the progress on the client.
+                if step.stage() == RagWorkflowStepNames.INDEXING and isinstance(
+                    step, RagIndexingStepRunner
+                ):
+                    self.current_progress.total_chunk_count = (
+                        await step.count_total_chunks()
+                    )
+                    # Reset the indexing progress to 0, since we go through all the chunks again.
+                    if not document_ids:
+                        self.initial_progress.total_chunks_indexed_count = 0
+                        self.current_progress.total_chunks_indexed_count = 0
+
+                yield self.update_workflow_progress(
+                    step.stage(),
+                    RagStepRunnerProgress(
+                        success_count=0,
+                        error_count=0,
+                    ),
+                )
+
+                async for progress in step.run(document_ids=document_ids):
+                    yield self.update_workflow_progress(step.stage(), progress)
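Putting the pieces together, a consumption sketch for the full workflow. It assumes the Project, the four configs, and the RagConfig were created and saved through the kiln_ai datamodel beforehand, and that RagProgress (from kiln_ai.adapters.rag.progress) can be default-constructed as a fresh starting point; both are assumptions, not shown in this diff:

async def run_rag_workflow(
    project: Project,
    rag_config: RagConfig,
    extractor_config: ExtractorConfig,
    chunker_config: ChunkerConfig,
    embedding_config: EmbeddingConfig,
    vector_store_config: VectorStoreConfig,
) -> None:
    configuration = RagWorkflowRunnerConfiguration(
        step_runners=[
            RagExtractionStepRunner(project, extractor_config, rag_config=rag_config),
            RagChunkingStepRunner(
                project, extractor_config, chunker_config, rag_config=rag_config
            ),
            RagEmbeddingStepRunner(
                project,
                extractor_config,
                chunker_config,
                embedding_config,
                rag_config=rag_config,
            ),
            RagIndexingStepRunner(
                project,
                extractor_config,
                chunker_config,
                embedding_config,
                vector_store_config,
                rag_config,
            ),
        ],
        initial_progress=RagProgress(),  # assumed: progress counters default to zero
        rag_config=rag_config,
        extractor_config=extractor_config,
        chunker_config=chunker_config,
        embedding_config=embedding_config,
    )
    runner = RagWorkflowRunner(project, configuration)
    async for progress in runner.run():
        for log in progress.logs:
            print(f"[{log.level}] {log.message}")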