kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (133) hide show
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +494 -0
  38. kiln_ai/adapters/ml_model_list.py +876 -18
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
  41. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  42. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  43. kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
  44. kiln_ai/adapters/ollama_tools.py +69 -12
  45. kiln_ai/adapters/provider_tools.py +190 -46
  46. kiln_ai/adapters/rag/deduplication.py +49 -0
  47. kiln_ai/adapters/rag/progress.py +252 -0
  48. kiln_ai/adapters/rag/rag_runners.py +844 -0
  49. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  50. kiln_ai/adapters/rag/test_progress.py +785 -0
  51. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  52. kiln_ai/adapters/remote_config.py +80 -8
  53. kiln_ai/adapters/test_adapter_registry.py +579 -86
  54. kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
  55. kiln_ai/adapters/test_ml_model_list.py +202 -0
  56. kiln_ai/adapters/test_ollama_tools.py +340 -1
  57. kiln_ai/adapters/test_prompt_builders.py +1 -1
  58. kiln_ai/adapters/test_provider_tools.py +199 -8
  59. kiln_ai/adapters/test_remote_config.py +551 -56
  60. kiln_ai/adapters/vector_store/__init__.py +1 -0
  61. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  62. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  63. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  64. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  65. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  66. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  67. kiln_ai/datamodel/__init__.py +16 -13
  68. kiln_ai/datamodel/basemodel.py +201 -4
  69. kiln_ai/datamodel/chunk.py +158 -0
  70. kiln_ai/datamodel/datamodel_enums.py +27 -0
  71. kiln_ai/datamodel/embedding.py +64 -0
  72. kiln_ai/datamodel/external_tool_server.py +206 -54
  73. kiln_ai/datamodel/extraction.py +317 -0
  74. kiln_ai/datamodel/project.py +33 -1
  75. kiln_ai/datamodel/rag.py +79 -0
  76. kiln_ai/datamodel/task.py +5 -0
  77. kiln_ai/datamodel/task_output.py +41 -11
  78. kiln_ai/datamodel/test_attachment.py +649 -0
  79. kiln_ai/datamodel/test_basemodel.py +270 -14
  80. kiln_ai/datamodel/test_chunk_models.py +317 -0
  81. kiln_ai/datamodel/test_dataset_split.py +1 -1
  82. kiln_ai/datamodel/test_datasource.py +50 -0
  83. kiln_ai/datamodel/test_embedding_models.py +448 -0
  84. kiln_ai/datamodel/test_eval_model.py +6 -6
  85. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  86. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  87. kiln_ai/datamodel/test_extraction_model.py +501 -0
  88. kiln_ai/datamodel/test_rag.py +641 -0
  89. kiln_ai/datamodel/test_task.py +35 -1
  90. kiln_ai/datamodel/test_tool_id.py +187 -1
  91. kiln_ai/datamodel/test_vector_store.py +320 -0
  92. kiln_ai/datamodel/tool_id.py +58 -0
  93. kiln_ai/datamodel/vector_store.py +141 -0
  94. kiln_ai/tools/base_tool.py +12 -3
  95. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  96. kiln_ai/tools/kiln_task_tool.py +158 -0
  97. kiln_ai/tools/mcp_server_tool.py +2 -2
  98. kiln_ai/tools/mcp_session_manager.py +51 -22
  99. kiln_ai/tools/rag_tools.py +164 -0
  100. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  101. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  102. kiln_ai/tools/test_mcp_session_manager.py +187 -227
  103. kiln_ai/tools/test_rag_tools.py +929 -0
  104. kiln_ai/tools/test_tool_registry.py +290 -7
  105. kiln_ai/tools/tool_registry.py +69 -16
  106. kiln_ai/utils/__init__.py +3 -0
  107. kiln_ai/utils/async_job_runner.py +62 -17
  108. kiln_ai/utils/config.py +2 -2
  109. kiln_ai/utils/env.py +15 -0
  110. kiln_ai/utils/filesystem.py +14 -0
  111. kiln_ai/utils/filesystem_cache.py +60 -0
  112. kiln_ai/utils/litellm.py +94 -0
  113. kiln_ai/utils/lock.py +100 -0
  114. kiln_ai/utils/mime_type.py +38 -0
  115. kiln_ai/utils/open_ai_types.py +19 -2
  116. kiln_ai/utils/pdf_utils.py +59 -0
  117. kiln_ai/utils/test_async_job_runner.py +151 -35
  118. kiln_ai/utils/test_env.py +142 -0
  119. kiln_ai/utils/test_filesystem_cache.py +316 -0
  120. kiln_ai/utils/test_litellm.py +206 -0
  121. kiln_ai/utils/test_lock.py +185 -0
  122. kiln_ai/utils/test_mime_type.py +66 -0
  123. kiln_ai/utils/test_open_ai_types.py +88 -12
  124. kiln_ai/utils/test_pdf_utils.py +86 -0
  125. kiln_ai/utils/test_uuid.py +111 -0
  126. kiln_ai/utils/test_validation.py +524 -0
  127. kiln_ai/utils/uuid.py +9 -0
  128. kiln_ai/utils/validation.py +90 -0
  129. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
  130. kiln_ai-0.22.0.dist-info/RECORD +213 -0
  131. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  132. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  133. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,844 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import AsyncGenerator, Generic, Set, Tuple, TypeVar
6
+
7
+ from kiln_ai.adapters.chunkers.base_chunker import BaseChunker
8
+ from kiln_ai.adapters.chunkers.chunker_registry import chunker_adapter_from_type
9
+ from kiln_ai.adapters.embedding.base_embedding_adapter import BaseEmbeddingAdapter
10
+ from kiln_ai.adapters.embedding.embedding_registry import embedding_adapter_from_type
11
+ from kiln_ai.adapters.extractors.base_extractor import BaseExtractor, ExtractionInput
12
+ from kiln_ai.adapters.extractors.extractor_registry import extractor_adapter_from_type
13
+ from kiln_ai.adapters.rag.deduplication import (
14
+ deduplicate_chunk_embeddings,
15
+ deduplicate_chunked_documents,
16
+ deduplicate_extractions,
17
+ filter_documents_by_tags,
18
+ )
19
+ from kiln_ai.adapters.rag.progress import LogMessage, RagProgress
20
+ from kiln_ai.adapters.vector_store.base_vector_store_adapter import (
21
+ DocumentWithChunksAndEmbeddings,
22
+ )
23
+ from kiln_ai.adapters.vector_store.vector_store_registry import (
24
+ vector_store_adapter_for_config,
25
+ )
26
+ from kiln_ai.datamodel import Project
27
+ from kiln_ai.datamodel.basemodel import ID_TYPE, KilnAttachmentModel
28
+ from kiln_ai.datamodel.chunk import Chunk, ChunkedDocument, ChunkerConfig
29
+ from kiln_ai.datamodel.embedding import ChunkEmbeddings, Embedding, EmbeddingConfig
30
+ from kiln_ai.datamodel.extraction import (
31
+ Document,
32
+ Extraction,
33
+ ExtractionSource,
34
+ ExtractorConfig,
35
+ )
36
+ from kiln_ai.datamodel.rag import RagConfig
37
+ from kiln_ai.datamodel.vector_store import VectorStoreConfig
38
+ from kiln_ai.utils.async_job_runner import AsyncJobRunner, AsyncJobRunnerObserver
39
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
40
+ from kiln_ai.utils.filesystem_cache import FilesystemCache
41
+ from kiln_ai.utils.lock import shared_async_lock_manager
42
+ from pydantic import BaseModel, ConfigDict, Field
43
+
44
# The lock timeout is deliberately generous: with the current UX a user is
# likely to trigger several RAG workflows that share a sub-config (e.g. the
# same extractor). Each of those can take a long time to complete, so a
# workflow waiting on the lock could otherwise time out before it gets a
# chance to start.
LOCK_TIMEOUT_SECONDS = 60 * 60  # 1 hour

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
51
+
52
+
53
@dataclass
class ExtractorJob:
    """One unit of extraction work: a document plus the extractor config to apply to it."""

    # The document whose original file will be extracted.
    doc: Document
    # The extractor configuration; its id is recorded on the resulting Extraction.
    extractor_config: ExtractorConfig
57
+
58
+
59
@dataclass
class ChunkerJob:
    """One unit of chunking work: an extraction plus the chunker config to apply to it."""

    # The extraction whose output content will be chunked.
    extraction: Extraction
    # The chunker configuration; its id is recorded on the resulting ChunkedDocument.
    chunker_config: ChunkerConfig
63
+
64
+
65
@dataclass
class EmbeddingJob:
    """One unit of embedding work: a chunked document plus the embedding config to apply."""

    # The chunked document whose chunk texts will be embedded.
    chunked_document: ChunkedDocument
    # The embedding configuration; its id is recorded on the resulting ChunkEmbeddings.
    embedding_config: EmbeddingConfig
69
+
70
+
71
class RagStepRunnerProgress(BaseModel):
    # Progress update emitted by a single RAG step runner.
    #
    # Every field is optional: the step runners in this module emit some
    # updates that carry only counters and others that carry only log
    # messages, so consumers must treat a count of None as "not reported in
    # this update".

    # Items processed so far, or None when this update carries no success count.
    success_count: int | None = Field(
        default=None,
        description="The number of items that have been processed",
    )

    # Items that failed so far, or None when this update carries no error count.
    error_count: int | None = Field(
        default=None,
        description="The number of items that have errored",
    )

    # User-facing log messages attached to this update (empty by default).
    logs: list[LogMessage] = Field(
        default_factory=list,
        description="A list of log messages to display to the user",
    )
84
+
85
+
86
# Job type handled by the observed job runner.
T = TypeVar("T")


class GenericErrorCollector(AsyncJobRunnerObserver[T], Generic[T]):
    """Job-runner observer that records every failed job with the exception it raised.

    Errors are appended in the order they are reported, which lets callers
    read them incrementally via `get_errors(start_idx)`.
    """

    def __init__(
        self,
    ):
        # (job, error) pairs in the order they were reported.
        self.errors: list[Tuple[T, Exception]] = []

    async def on_success(self, job: T):
        """Ignore successful jobs; only failures are collected."""
        pass

    async def on_error(self, job: T, error: Exception):
        """Record a failed job together with its exception."""
        self.errors.append((job, error))

    def get_errors(
        self,
        start_idx: int = 0,
    ) -> tuple[list[Tuple[T, Exception]], int]:
        """Return the errors recorded from `start_idx` onwards.

        Args:
            start_idx: Index of the first error to return. Pass the index
                returned by the previous call to receive only new errors.

        Returns:
            A tuple `(errors, next_idx)`: `errors` is a new list holding the
            (job, error) pairs from `start_idx` onwards, and `next_idx` is the
            index to pass to the next call. When there is nothing new,
            `errors` is empty and `next_idx` equals `start_idx`.

        Raises:
            ValueError: If `start_idx` is negative.
        """
        if start_idx < 0:
            raise ValueError("start_idx must be non-negative")
        if start_idx >= len(self.errors):
            return [], start_idx
        # Always return a snapshot, never the internal list itself. Handing
        # out `self.errors` directly would let a caller that iterates the
        # result across await/yield points also see errors appended in the
        # meantime, while the returned index still reflects the old length,
        # so those errors would be reported again on the next call.
        return self.errors[start_idx:], len(self.errors)

    def get_error_count(self) -> int:
        """Return the total number of errors recorded so far."""
        return len(self.errors)
116
+
117
+
118
class RagWorkflowStepNames(str, Enum):
    """Names of the RAG workflow steps; each member is also a plain string."""

    # Step run by the extraction step runner.
    EXTRACTING = "extracting"
    # Step run by the chunking step runner.
    CHUNKING = "chunking"
    # Step run by the embedding step runner.
    EMBEDDING = "embedding"
    # Step run by the indexing step runner.
    INDEXING = "indexing"
123
+
124
+
125
async def execute_extractor_job(job: ExtractorJob, extractor: BaseExtractor) -> bool:
    """Extract one document and save the result as a child Extraction.

    Args:
        job: The document to extract and the extractor config to attribute it to.
        extractor: The extractor adapter that performs the extraction.

    Returns:
        True once the extraction has been saved.

    Raises:
        ValueError: If the document has no path.
    """
    document = job.doc
    if document.path is None:
        raise ValueError("Document path is not set")

    # The attachment path is resolved against the directory holding the document.
    original_file = document.original_file
    extraction_input = ExtractionInput(
        path=original_file.attachment.resolve_path(document.path.parent),
        mime_type=original_file.mime_type,
    )
    output = await extractor.extract(extraction_input=extraction_input)

    output_attachment = KilnAttachmentModel.from_data(
        data=output.content,
        mime_type=output.content_format,
    )
    if output.is_passthrough:
        source = ExtractionSource.PASSTHROUGH
    else:
        source = ExtractionSource.PROCESSED

    Extraction(
        parent=document,
        extractor_config_id=job.extractor_config.id,
        output=output_attachment,
        source=source,
    ).save_to_file()

    return True
150
+
151
+
152
async def execute_chunker_job(job: ChunkerJob, chunker: BaseChunker) -> bool:
    """Chunk one extraction's output and save the result as a child ChunkedDocument.

    Args:
        job: The extraction to chunk and the chunker config to attribute it to.
        chunker: The chunker adapter that splits the text.

    Returns:
        True once the chunked document has been saved.

    Raises:
        ValueError: If the extraction has no output content, or the chunker
            returns no result.
    """
    extraction = job.extraction

    text = await extraction.output_content()
    if text is None:
        raise ValueError("Extraction output content is not set")

    result = await chunker.chunk(text)
    if result is None:
        raise ValueError("Chunking result is not set")

    chunker_config_id = job.chunker_config.id

    # Each chunk's text is stored as a plain-text attachment.
    chunks = []
    for piece in result.chunks:
        attachment = KilnAttachmentModel.from_data(
            data=piece.text,
            mime_type="text/plain",
        )
        chunks.append(Chunk(content=attachment))

    ChunkedDocument(
        parent=extraction,
        chunker_config_id=chunker_config_id,
        chunks=chunks,
    ).save_to_file()
    return True
178
+
179
+
180
async def execute_embedding_job(
    job: EmbeddingJob, embedding_adapter: BaseEmbeddingAdapter
) -> bool:
    """Embed every chunk of one chunked document and save the result as child ChunkEmbeddings.

    Args:
        job: The chunked document to embed and the embedding config to attribute it to.
        embedding_adapter: The adapter that generates the embeddings.

    Returns:
        True once the chunk embeddings have been saved.

    Raises:
        ValueError: If no chunk text could be loaded, or the adapter returns
            no result.
    """
    chunked_document = job.chunked_document

    texts = await chunked_document.load_chunks_text()
    if not texts:
        raise ValueError(
            f"Failed to load chunks for chunked document: {job.chunked_document.id}"
        )

    result = await embedding_adapter.generate_embeddings(input_texts=texts)
    if result is None:
        raise ValueError(
            f"Failed to generate embeddings for chunked document: {job.chunked_document.id}"
        )

    embedding_config_id = job.embedding_config.id

    # Keep only the vector of each embedding returned by the adapter.
    vectors = []
    for item in result.embeddings:
        vectors.append(Embedding(vector=item.vector))

    ChunkEmbeddings(
        parent=chunked_document,
        embedding_config_id=embedding_config_id,
        embeddings=vectors,
    ).save_to_file()
    return True
210
+
211
+
212
class AbstractRagStepRunner(ABC):
    """Common interface of the RAG workflow step runners."""

    @abstractmethod
    def stage(self) -> RagWorkflowStepNames:
        """Return the workflow step this runner implements."""
        ...

    # The prototype is intentionally not declared `async`: doing so causes a
    # type error in pyright. Concrete implementations should still declare
    # `async def run(...)`.
    @abstractmethod
    def run(
        self, document_ids: list[ID_TYPE] | None = None
    ) -> AsyncGenerator[RagStepRunnerProgress, None]:
        """Run the step, yielding progress updates as work proceeds.

        Args:
            document_ids: Optional list of document ids to restrict the run to.
        """
        ...
224
+
225
+
226
class RagExtractionStepRunner(AbstractRagStepRunner):
    """Step runner that extracts every target document lacking an extraction for the configured extractor."""

    def __init__(
        self,
        project: Project,
        extractor_config: ExtractorConfig,
        concurrency: int = 10,
        rag_config: RagConfig | None = None,
        filesystem_cache: FilesystemCache | None = None,
    ):
        self.project = project
        self.extractor_config = extractor_config
        self.concurrency = concurrency
        self.rag_config = rag_config
        self.filesystem_cache = filesystem_cache
        # Runs sharing the same extractor config are serialized on this key.
        self.lock_key = f"docs:extract:{self.extractor_config.id}"

    def stage(self) -> RagWorkflowStepNames:
        """Identify this runner as the extraction step."""
        return RagWorkflowStepNames.EXTRACTING

    def has_extraction(self, document: Document, extractor_id: ID_TYPE) -> bool:
        """Return True if the document already has an extraction made with `extractor_id`."""
        return any(
            existing.extractor_config_id == extractor_id
            for existing in document.extractions(readonly=True)
        )

    async def collect_jobs(
        self, document_ids: list[ID_TYPE] | None = None
    ) -> list[ExtractorJob]:
        """Build one job per target document that still needs extracting.

        Args:
            document_ids: When given and non-empty, only documents whose id is
                in this list are considered.
        """
        candidates = self.project.documents(readonly=True)
        if self.rag_config and self.rag_config.tags:
            candidates = filter_documents_by_tags(candidates, self.rag_config.tags)

        wanted_extractor_id = self.extractor_config.id
        pending: list[ExtractorJob] = []
        for document in candidates:
            # A missing or empty id list means "no restriction".
            if document_ids and document.id not in document_ids:
                continue
            if self.has_extraction(document, wanted_extractor_id):
                continue
            pending.append(
                ExtractorJob(doc=document, extractor_config=self.extractor_config)
            )
        return pending

    async def run(
        self, document_ids: list[ID_TYPE] | None = None
    ) -> AsyncGenerator[RagStepRunnerProgress, None]:
        """Extract all pending documents while holding the extractor lock.

        Yields a counter update for each progress report of the job runner,
        followed by one log-only update per newly recorded error.
        """
        async with shared_async_lock_manager.acquire(
            self.lock_key, timeout=LOCK_TIMEOUT_SECONDS
        ):
            pending = await self.collect_jobs(document_ids=document_ids)
            extractor = extractor_adapter_from_type(
                self.extractor_config.extractor_type,
                self.extractor_config,
                self.filesystem_cache,
            )

            def run_job(job):
                return execute_extractor_job(job, extractor)

            observer = GenericErrorCollector()
            runner = AsyncJobRunner(
                jobs=pending,
                run_job_fn=run_job,
                concurrency=self.concurrency,
                observers=[observer],
            )

            # Index of the first error not yet forwarded to the caller.
            next_error_idx = 0
            async for progress in runner.run():
                yield RagStepRunnerProgress(
                    success_count=progress.complete,
                    error_count=observer.get_error_count(),
                )

                # Errors accumulate in the observer; forward the new ones.
                if observer.get_error_count() > 0:
                    new_errors, next_error_idx = observer.get_errors(next_error_idx)
                    for failed_job, error in new_errors:
                        log_entry = LogMessage(
                            level="error",
                            message=f"Error extracting document: {failed_job.doc.path}: {error}",
                        )
                        yield RagStepRunnerProgress(logs=[log_entry])
317
+
318
+
319
class RagChunkingStepRunner(AbstractRagStepRunner):
    """Step runner that chunks every matching extraction lacking chunks for the configured chunker."""

    def __init__(
        self,
        project: Project,
        extractor_config: ExtractorConfig,
        chunker_config: ChunkerConfig,
        concurrency: int = 10,
        rag_config: RagConfig | None = None,
    ):
        self.project = project
        self.extractor_config = extractor_config
        self.chunker_config = chunker_config
        self.concurrency = concurrency
        self.rag_config = rag_config
        # Runs sharing the same chunker config are serialized on this key.
        self.lock_key = f"docs:chunk:{self.chunker_config.id}"

    def stage(self) -> RagWorkflowStepNames:
        """Identify this runner as the chunking step."""
        return RagWorkflowStepNames.CHUNKING

    def has_chunks(self, extraction: Extraction, chunker_id: ID_TYPE) -> bool:
        """Return True if the extraction already has a chunked document made with `chunker_id`."""
        return any(
            existing.chunker_config_id == chunker_id
            for existing in extraction.chunked_documents(readonly=True)
        )

    async def collect_jobs(
        self, document_ids: list[ID_TYPE] | None = None
    ) -> list[ChunkerJob]:
        """Build one job per matching extraction that still needs chunking.

        Only extractions made with this runner's extractor config are
        considered.

        Args:
            document_ids: When given and non-empty, only documents whose id is
                in this list are considered.
        """
        wanted_extractor_id = self.extractor_config.id
        wanted_chunker_id = self.chunker_config.id

        candidates = self.project.documents(readonly=True)
        if self.rag_config and self.rag_config.tags:
            candidates = filter_documents_by_tags(candidates, self.rag_config.tags)

        pending: list[ChunkerJob] = []
        for document in candidates:
            # A missing or empty id list means "no restriction".
            if document_ids and document.id not in document_ids:
                continue
            extractions = deduplicate_extractions(document.extractions(readonly=True))
            for extraction in extractions:
                if extraction.extractor_config_id != wanted_extractor_id:
                    continue
                if self.has_chunks(extraction, wanted_chunker_id):
                    continue
                pending.append(
                    ChunkerJob(
                        extraction=extraction,
                        chunker_config=self.chunker_config,
                    )
                )
        return pending

    async def run(
        self, document_ids: list[ID_TYPE] | None = None
    ) -> AsyncGenerator[RagStepRunnerProgress, None]:
        """Chunk all pending extractions while holding the chunker lock.

        Yields a counter update for each progress report of the job runner,
        followed by one log-only update per newly recorded error.
        """
        async with shared_async_lock_manager.acquire(
            self.lock_key, timeout=LOCK_TIMEOUT_SECONDS
        ):
            pending = await self.collect_jobs(document_ids=document_ids)
            chunker = chunker_adapter_from_type(
                self.chunker_config.chunker_type,
                self.chunker_config,
            )

            def run_job(job):
                return execute_chunker_job(job, chunker)

            observer = GenericErrorCollector()
            runner = AsyncJobRunner(
                jobs=pending,
                run_job_fn=run_job,
                concurrency=self.concurrency,
                observers=[observer],
            )

            # Index of the first error not yet forwarded to the caller.
            next_error_idx = 0
            async for progress in runner.run():
                yield RagStepRunnerProgress(
                    success_count=progress.complete,
                    error_count=observer.get_error_count(),
                )

                # Errors accumulate in the observer; forward the new ones.
                if observer.get_error_count() > 0:
                    new_errors, next_error_idx = observer.get_errors(next_error_idx)
                    for failed_job, error in new_errors:
                        log_entry = LogMessage(
                            level="error",
                            message=f"Error chunking document: {failed_job.extraction.path}: {error}",
                        )
                        yield RagStepRunnerProgress(logs=[log_entry])
413
+
414
+
415
class RagEmbeddingStepRunner(AbstractRagStepRunner):
    """Step runner that embeds every matching chunked document lacking embeddings for the configured embedding config."""

    def __init__(
        self,
        project: Project,
        extractor_config: ExtractorConfig,
        chunker_config: ChunkerConfig,
        embedding_config: EmbeddingConfig,
        concurrency: int = 10,
        rag_config: RagConfig | None = None,
    ):
        self.project = project
        self.extractor_config = extractor_config
        self.chunker_config = chunker_config
        self.embedding_config = embedding_config
        self.concurrency = concurrency
        self.rag_config = rag_config
        # Runs sharing the same embedding config are serialized on this key.
        self.lock_key = f"docs:embedding:{self.embedding_config.id}"

    def stage(self) -> RagWorkflowStepNames:
        """Identify this runner as the embedding step."""
        return RagWorkflowStepNames.EMBEDDING

    def has_embeddings(self, chunked: ChunkedDocument, embedding_id: ID_TYPE) -> bool:
        """Return True if the chunked document already has embeddings made with `embedding_id`."""
        return any(
            existing.embedding_config_id == embedding_id
            for existing in chunked.chunk_embeddings(readonly=True)
        )

    async def collect_jobs(
        self, document_ids: list[ID_TYPE] | None = None
    ) -> list[EmbeddingJob]:
        """Build one job per matching chunked document that still needs embedding.

        Only chunked documents produced by this runner's extractor config and
        chunker config are considered.

        Args:
            document_ids: When given and non-empty, only documents whose id is
                in this list are considered.
        """
        wanted_extractor_id = self.extractor_config.id
        wanted_chunker_id = self.chunker_config.id
        wanted_embedding_id = self.embedding_config.id

        candidates = self.project.documents(readonly=True)
        if self.rag_config and self.rag_config.tags:
            candidates = filter_documents_by_tags(candidates, self.rag_config.tags)

        pending: list[EmbeddingJob] = []
        for document in candidates:
            # A missing or empty id list means "no restriction".
            if document_ids and document.id not in document_ids:
                continue
            extractions = deduplicate_extractions(document.extractions(readonly=True))
            for extraction in extractions:
                if extraction.extractor_config_id != wanted_extractor_id:
                    continue
                chunked_documents = deduplicate_chunked_documents(
                    extraction.chunked_documents(readonly=True)
                )
                for chunked_document in chunked_documents:
                    if chunked_document.chunker_config_id != wanted_chunker_id:
                        continue
                    if self.has_embeddings(chunked_document, wanted_embedding_id):
                        continue
                    pending.append(
                        EmbeddingJob(
                            chunked_document=chunked_document,
                            embedding_config=self.embedding_config,
                        )
                    )
        return pending

    async def run(
        self, document_ids: list[ID_TYPE] | None = None
    ) -> AsyncGenerator[RagStepRunnerProgress, None]:
        """Embed all pending chunked documents while holding the embedding lock.

        Yields a counter update for each progress report of the job runner,
        followed by one log-only update per newly recorded error.
        """
        async with shared_async_lock_manager.acquire(
            self.lock_key, timeout=LOCK_TIMEOUT_SECONDS
        ):
            pending = await self.collect_jobs(document_ids=document_ids)
            embedding_adapter = embedding_adapter_from_type(
                self.embedding_config,
            )

            def run_job(job):
                return execute_embedding_job(job, embedding_adapter)

            observer = GenericErrorCollector()
            runner = AsyncJobRunner(
                jobs=pending,
                run_job_fn=run_job,
                concurrency=self.concurrency,
                observers=[observer],
            )

            # Index of the first error not yet forwarded to the caller.
            next_error_idx = 0
            async for progress in runner.run():
                yield RagStepRunnerProgress(
                    success_count=progress.complete,
                    error_count=observer.get_error_count(),
                )

                # Errors accumulate in the observer; forward the new ones.
                if observer.get_error_count() > 0:
                    new_errors, next_error_idx = observer.get_errors(next_error_idx)
                    for failed_job, error in new_errors:
                        log_entry = LogMessage(
                            level="error",
                            message=f"Error embedding document: {failed_job.chunked_document.path}: {error}",
                        )
                        yield RagStepRunnerProgress(logs=[log_entry])
521
+
522
+
523
class RagIndexingStepRunner(AbstractRagStepRunner):
    """Step runner that writes embedded chunks to the vector store.

    It gathers the records matching the configured extractor, chunker and
    embedding configs, adds them to the vector store in batches, and finally
    asks the store to delete nodes that are not in the current target
    document set.
    """

    def __init__(
        self,
        project: Project,
        extractor_config: ExtractorConfig,
        chunker_config: ChunkerConfig,
        embedding_config: EmbeddingConfig,
        vector_store_config: VectorStoreConfig,
        rag_config: RagConfig,
        concurrency: int = 10,
        batch_size: int = 20,
    ):
        self.project = project
        self.extractor_config = extractor_config
        self.chunker_config = chunker_config
        self.embedding_config = embedding_config
        self.vector_store_config = vector_store_config
        self.rag_config = rag_config
        self.concurrency = concurrency
        # Number of records handed to the vector store per call in run().
        self.batch_size = batch_size

    @property
    def lock_key(self) -> str:
        """Lock key that serializes runs sharing the same vector store config."""
        return f"rag:index:{self.vector_store_config.id}"

    def stage(self) -> RagWorkflowStepNames:
        """Identify this runner as the indexing step."""
        return RagWorkflowStepNames.INDEXING

    async def collect_records(
        self,
        batch_size: int,
        document_ids: list[ID_TYPE] | None = None,
    ) -> AsyncGenerator[list[DocumentWithChunksAndEmbeddings], None]:
        """Yield the records to index in batches of at most `batch_size`.

        A record is produced for every chunk-embeddings entry whose
        extraction, chunked document and embeddings match this runner's
        extractor, chunker and embedding configs. Empty batches are never
        yielded, and each yielded batch is a distinct list that the caller
        may keep.

        Args:
            batch_size: Maximum number of records per yielded batch.
            document_ids: When given and non-empty, only documents whose id is
                in this list are considered.
        """
        target_extractor_config_id = self.extractor_config.id
        target_chunker_config_id = self.chunker_config.id
        target_embedding_config_id = self.embedding_config.id

        batch: list[DocumentWithChunksAndEmbeddings] = []
        documents = self.project.documents(readonly=True)
        if self.rag_config and self.rag_config.tags:
            documents = filter_documents_by_tags(documents, self.rag_config.tags)

        for document in documents:
            if (
                document_ids is not None
                and len(document_ids) > 0
                and document.id not in document_ids
            ):
                continue
            for extraction in deduplicate_extractions(
                document.extractions(readonly=True)
            ):
                if extraction.extractor_config_id != target_extractor_config_id:
                    continue
                for chunked_document in deduplicate_chunked_documents(
                    extraction.chunked_documents(readonly=True)
                ):
                    if chunked_document.chunker_config_id != target_chunker_config_id:
                        continue
                    for chunk_embedding in deduplicate_chunk_embeddings(
                        chunked_document.chunk_embeddings(readonly=True)
                    ):
                        if (
                            chunk_embedding.embedding_config_id
                            != target_embedding_config_id
                        ):
                            continue
                        batch.append(
                            DocumentWithChunksAndEmbeddings(
                                document_id=str(document.id),
                                chunked_document=chunked_document,
                                chunk_embeddings=chunk_embedding,
                            )
                        )

                        if len(batch) >= batch_size:
                            yield batch
                            # Start a fresh list instead of clearing the one
                            # just yielded, so a consumer that keeps a batch
                            # does not see it emptied behind its back.
                            batch = []

        if batch:
            yield batch

    async def count_total_chunks(self) -> int:
        """Return the total number of chunks across all records to index."""
        total_chunk_count = 0
        async for documents in self.collect_records(batch_size=1):
            # Sum over the whole batch rather than assuming it holds exactly
            # one record.
            total_chunk_count += sum(len(record.chunks) for record in documents)
        return total_chunk_count

    def get_all_target_document_ids(self) -> Set[str]:
        """Return the ids (as strings) of all project documents, tag-filtered when the rag config has tags."""
        documents = self.project.documents(readonly=True)
        if self.rag_config and self.rag_config.tags:
            documents = filter_documents_by_tags(documents, self.rag_config.tags)
        return {str(document.id) for document in documents}

    async def run(
        self, document_ids: list[ID_TYPE] | None = None
    ) -> AsyncGenerator[RagStepRunnerProgress, None]:
        """Index all matching records while holding the vector store lock.

        Yields an initial zeroed update, then one update per batch: the
        batch's chunk count is reported as `success_count` when the batch was
        added, or as `error_count` (with an error log message) when adding it
        raised.

        Args:
            document_ids: When given and non-empty, only records of these
                documents are added. The initial peek and the final
                reconciliation are not restricted by it.

        Raises:
            ValueError: If there is no record to index at all.
        """
        async with shared_async_lock_manager.acquire(
            self.lock_key, timeout=LOCK_TIMEOUT_SECONDS
        ):
            vector_dimensions: int | None = None

            # Infer dimensionality by peeking at the first record. Vector
            # dimensions are not stored in the config because they are derived
            # from the model, and in some cases from dynamic shortening of the
            # vector (Matryoshka Representation Learning). Here the value is
            # only used to confirm that at least one embedded record exists.
            async for doc_batch in self.collect_records(
                batch_size=1,
            ):
                if len(doc_batch) == 0:
                    # Defensive only: collect_records does not yield empty
                    # batches, so an empty pipeline is reported by the
                    # ValueError below instead.
                    return
                else:
                    doc = doc_batch[0]
                    embedding = doc.embeddings[0]
                    vector_dimensions = len(embedding.vector)
                    break

            if vector_dimensions is None:
                raise ValueError("Vector dimensions are not set")

            vector_store = await vector_store_adapter_for_config(
                self.rag_config,
                self.vector_store_config,
            )

            yield RagStepRunnerProgress(
                success_count=0,
                error_count=0,
            )

            async for doc_batch in self.collect_records(
                batch_size=self.batch_size, document_ids=document_ids
            ):
                batch_chunk_count = sum(len(doc.chunks) for doc in doc_batch)

                try:
                    await vector_store.add_chunks_with_embeddings(doc_batch)
                except Exception as e:
                    # Best effort: a failing batch is reported and the run
                    # continues with the next batch.
                    error_msg = f"Error indexing document batch starting with {doc_batch[0].document_id}: {e}"
                    logger.error(error_msg, exc_info=True)
                    yield RagStepRunnerProgress(
                        success_count=0,
                        error_count=batch_chunk_count,
                        logs=[
                            LogMessage(
                                level="error",
                                message=error_msg,
                            ),
                        ],
                    )
                else:
                    # Yield outside the try body so an exception thrown into
                    # the generator at this point is not mistaken for an
                    # indexing failure.
                    yield RagStepRunnerProgress(
                        success_count=batch_chunk_count,
                        error_count=0,
                    )

            # Needed to reconcile and delete any chunks that are currently
            # indexed but are no longer in our target set (because they were
            # deleted or untagged).
            await vector_store.delete_nodes_not_in_set(
                self.get_all_target_document_ids()
            )
687
+
688
+
689
class RagWorkflowRunnerConfiguration(BaseModel):
    # Everything a RAG workflow runner needs: the step runners, the progress
    # state to build on, and the configs the workflow uses. All fields are
    # required.

    # The step runners are plain classes rather than pydantic models, so
    # arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Step runners to execute.
    step_runners: list[AbstractRagStepRunner] = Field(
        description="The step runners to run",
    )

    # Progress state supplied by the caller; workflow progress is added on top.
    initial_progress: RagProgress = Field(
        description="Initial progress state provided by the caller - progress will build on top of this",
    )

    # RAG config of the workflow.
    rag_config: RagConfig = Field(
        description="The rag config to use for the workflow",
    )

    # Extractor config of the workflow.
    extractor_config: ExtractorConfig = Field(
        description="The extractor config to use for the workflow",
    )

    # Chunker config of the workflow.
    chunker_config: ChunkerConfig = Field(
        description="The chunker config to use for the workflow",
    )

    # Embedding config of the workflow.
    embedding_config: EmbeddingConfig = Field(
        description="The embedding config to use for the workflow",
    )
715
+
716
+
717
+ class RagWorkflowRunner:
718
+ def __init__(
719
+ self,
720
+ project: Project,
721
+ configuration: RagWorkflowRunnerConfiguration,
722
+ ):
723
+ self.project = project
724
+ self.configuration = configuration
725
+ self.step_runners: list[AbstractRagStepRunner] = configuration.step_runners
726
+ self.initial_progress = self.configuration.initial_progress
727
+ self.current_progress = self.initial_progress.model_copy()
728
+
729
+ @property
730
+ def lock_key(self) -> str:
731
+ return f"rag:run:{self.configuration.rag_config.id}"
732
+
733
def update_workflow_progress(
    self, step_name: RagWorkflowStepNames, step_progress: RagStepRunnerProgress
) -> RagProgress:
    """
    Fold one step-level progress report into the workflow-wide progress.

    Extracting, chunking and embedding reports are offset by the matching
    counter of ``self.initial_progress`` and can only raise the stored value
    (``max``). Indexing reports are added on top of the stored counters
    instead. A success or error count of ``None`` leaves its counter untouched.

    :param step_name: The stage the report belongs to.
    :param step_progress: The counts and logs reported by that stage.
    :return: ``self.current_progress``, the same object that was updated in place.
    """
    overall = self.current_progress
    baseline = self.initial_progress
    succeeded = step_progress.success_count
    failed = step_progress.error_count

    if step_name == RagWorkflowStepNames.EXTRACTING:
        if succeeded is not None:
            overall.total_document_extracted_count = max(
                overall.total_document_extracted_count,
                succeeded + baseline.total_document_extracted_count,
            )
        if failed is not None:
            overall.total_document_extracted_error_count = max(
                overall.total_document_extracted_error_count,
                failed + baseline.total_document_extracted_error_count,
            )
    elif step_name == RagWorkflowStepNames.CHUNKING:
        if succeeded is not None:
            overall.total_document_chunked_count = max(
                overall.total_document_chunked_count,
                succeeded + baseline.total_document_chunked_count,
            )
        if failed is not None:
            overall.total_document_chunked_error_count = max(
                overall.total_document_chunked_error_count,
                failed + baseline.total_document_chunked_error_count,
            )
    elif step_name == RagWorkflowStepNames.EMBEDDING:
        if succeeded is not None:
            overall.total_document_embedded_count = max(
                overall.total_document_embedded_count,
                succeeded + baseline.total_document_embedded_count,
            )
        if failed is not None:
            overall.total_document_embedded_error_count = max(
                overall.total_document_embedded_error_count,
                failed + baseline.total_document_embedded_error_count,
            )
    elif step_name == RagWorkflowStepNames.INDEXING:
        # unlike the document stages, indexing counts are accumulated
        if succeeded is not None:
            overall.total_chunks_indexed_count += succeeded
        if failed is not None:
            overall.total_chunks_indexed_error_count += failed
    else:
        # presumably raises for a step name no branch handles - TODO confirm
        raise_exhaustive_enum_error(step_name)

    # a document counts as completed once every document stage has covered it,
    # i.e. the smallest of the three stage counters
    overall.total_document_completed_count = min(
        overall.total_document_extracted_count,
        overall.total_document_chunked_count,
        overall.total_document_embedded_count,
    )

    overall.total_chunk_completed_count = overall.total_chunks_indexed_count

    # logs are replaced by the latest report, not appended to
    overall.logs = step_progress.logs
    return overall
801
+
802
+ async def run(
803
+ self,
804
+ stages_to_run: list[RagWorkflowStepNames] | None = None,
805
+ document_ids: list[ID_TYPE] | None = None,
806
+ ) -> AsyncGenerator[RagProgress, None]:
807
+ """
808
+ Runs the RAG workflow for the given stages and document ids.
809
+
810
+ :param stages_to_run: The stages to run. If None, all stages will be run.
811
+ :param document_ids: The document ids to run the workflow for. If None, all documents will be run.
812
+ """
813
+ yield self.initial_progress
814
+
815
+ async with shared_async_lock_manager.acquire(
816
+ self.lock_key, timeout=LOCK_TIMEOUT_SECONDS
817
+ ):
818
+ for step in self.step_runners:
819
+ if stages_to_run is not None and step.stage() not in stages_to_run:
820
+ continue
821
+
822
+ # we need to know the total number of chunks to index to be able to
823
+ # calculate the progress on the client
824
+ if step.stage() == RagWorkflowStepNames.INDEXING and isinstance(
825
+ step, RagIndexingStepRunner
826
+ ):
827
+ self.current_progress.total_chunk_count = (
828
+ await step.count_total_chunks()
829
+ )
830
+ # reset the indexing progress to 0 since we go through all the chunks again
831
+ if not document_ids:
832
+ self.initial_progress.total_chunks_indexed_count = 0
833
+ self.current_progress.total_chunks_indexed_count = 0
834
+
835
+ yield self.update_workflow_progress(
836
+ step.stage(),
837
+ RagStepRunnerProgress(
838
+ success_count=0,
839
+ error_count=0,
840
+ ),
841
+ )
842
+
843
+ async for progress in step.run(document_ids=document_ids):
844
+ yield self.update_workflow_progress(step.stage(), progress)