kiln-ai 0.20.1__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (117)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  38. kiln_ai/adapters/ml_model_list.py +382 -4
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +7 -69
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +1 -1
  41. kiln_ai/adapters/model_adapters/test_structured_output.py +3 -1
  42. kiln_ai/adapters/ollama_tools.py +69 -12
  43. kiln_ai/adapters/provider_tools.py +190 -46
  44. kiln_ai/adapters/rag/deduplication.py +49 -0
  45. kiln_ai/adapters/rag/progress.py +252 -0
  46. kiln_ai/adapters/rag/rag_runners.py +844 -0
  47. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  48. kiln_ai/adapters/rag/test_progress.py +785 -0
  49. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  50. kiln_ai/adapters/remote_config.py +80 -8
  51. kiln_ai/adapters/test_adapter_registry.py +579 -86
  52. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  53. kiln_ai/adapters/test_ml_model_list.py +212 -0
  54. kiln_ai/adapters/test_ollama_tools.py +340 -1
  55. kiln_ai/adapters/test_prompt_builders.py +1 -1
  56. kiln_ai/adapters/test_provider_tools.py +199 -8
  57. kiln_ai/adapters/test_remote_config.py +551 -56
  58. kiln_ai/adapters/vector_store/__init__.py +1 -0
  59. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  60. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  61. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  62. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  63. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  64. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  65. kiln_ai/datamodel/__init__.py +16 -13
  66. kiln_ai/datamodel/basemodel.py +170 -1
  67. kiln_ai/datamodel/chunk.py +158 -0
  68. kiln_ai/datamodel/datamodel_enums.py +27 -0
  69. kiln_ai/datamodel/embedding.py +64 -0
  70. kiln_ai/datamodel/extraction.py +303 -0
  71. kiln_ai/datamodel/project.py +33 -1
  72. kiln_ai/datamodel/rag.py +79 -0
  73. kiln_ai/datamodel/test_attachment.py +649 -0
  74. kiln_ai/datamodel/test_basemodel.py +1 -1
  75. kiln_ai/datamodel/test_chunk_models.py +317 -0
  76. kiln_ai/datamodel/test_dataset_split.py +1 -1
  77. kiln_ai/datamodel/test_embedding_models.py +448 -0
  78. kiln_ai/datamodel/test_eval_model.py +6 -6
  79. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  80. kiln_ai/datamodel/test_extraction_model.py +470 -0
  81. kiln_ai/datamodel/test_rag.py +641 -0
  82. kiln_ai/datamodel/test_tool_id.py +81 -0
  83. kiln_ai/datamodel/test_vector_store.py +320 -0
  84. kiln_ai/datamodel/tool_id.py +22 -0
  85. kiln_ai/datamodel/vector_store.py +141 -0
  86. kiln_ai/tools/mcp_session_manager.py +4 -1
  87. kiln_ai/tools/rag_tools.py +157 -0
  88. kiln_ai/tools/test_mcp_session_manager.py +1 -1
  89. kiln_ai/tools/test_rag_tools.py +848 -0
  90. kiln_ai/tools/test_tool_registry.py +91 -2
  91. kiln_ai/tools/tool_registry.py +21 -0
  92. kiln_ai/utils/__init__.py +3 -0
  93. kiln_ai/utils/async_job_runner.py +62 -17
  94. kiln_ai/utils/config.py +2 -2
  95. kiln_ai/utils/env.py +15 -0
  96. kiln_ai/utils/filesystem.py +14 -0
  97. kiln_ai/utils/filesystem_cache.py +60 -0
  98. kiln_ai/utils/litellm.py +94 -0
  99. kiln_ai/utils/lock.py +100 -0
  100. kiln_ai/utils/mime_type.py +38 -0
  101. kiln_ai/utils/pdf_utils.py +38 -0
  102. kiln_ai/utils/test_async_job_runner.py +151 -35
  103. kiln_ai/utils/test_env.py +142 -0
  104. kiln_ai/utils/test_filesystem_cache.py +316 -0
  105. kiln_ai/utils/test_litellm.py +206 -0
  106. kiln_ai/utils/test_lock.py +185 -0
  107. kiln_ai/utils/test_mime_type.py +66 -0
  108. kiln_ai/utils/test_pdf_utils.py +73 -0
  109. kiln_ai/utils/test_uuid.py +111 -0
  110. kiln_ai/utils/test_validation.py +524 -0
  111. kiln_ai/utils/uuid.py +9 -0
  112. kiln_ai/utils/validation.py +90 -0
  113. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +7 -1
  114. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  115. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  116. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  117. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -160,8 +160,12 @@ class EvalRunner:
         """
         jobs = self.collect_tasks()
 
-        runner = AsyncJobRunner(concurrency=concurrency)
-        async for progress in runner.run(jobs, self.run_job):
+        runner = AsyncJobRunner(
+            concurrency=concurrency,
+            jobs=jobs,
+            run_job_fn=self.run_job,
+        )
+        async for progress in runner.run():
             yield progress
 
     async def run_job(self, job: EvalJob) -> bool:
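
The hunk above reflects the new AsyncJobRunner API: the job list and per-job worker now go into the constructor, and run() takes no arguments. A minimal usage sketch based only on what this diff shows; the integer jobs and the process worker are illustrative, and whether AsyncJobRunner accepts arbitrary job types is an assumption:

    from kiln_ai.utils.async_job_runner import AsyncJobRunner

    async def process(job: int) -> bool:
        # hypothetical worker: return True on success, False on failure
        return job >= 0

    async def run_all(jobs: list[int]) -> None:
        runner = AsyncJobRunner(concurrency=25, jobs=jobs, run_job_fn=process)
        async for progress in runner.run():
            # each yielded Progress reports work completed so far
            print(progress)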
@@ -307,9 +307,7 @@ async def test_run_method():
     evaluator = EvalTester(eval_config, run_config.run_config())
 
     # Run the evaluation
-    task_run, eval_scores, intermediate_outputs = await evaluator.run_task_and_eval(
-        "test input"
-    )
+    task_run, eval_scores, _ = await evaluator.run_task_and_eval("test input")
 
     # Verify task run was created
     assert task_run.input == "test input"
@@ -188,7 +188,7 @@ async def test_run_g_eval_e2e(
     g_eval = GEval(test_eval_config, test_run_config)
 
     # Run the evaluation
-    task_run, scores, intermediate_outputs = await g_eval.run_task_and_eval("chickens")
+    _, scores, intermediate_outputs = await g_eval.run_task_and_eval("chickens")
 
     # Verify the evaluation results
    assert isinstance(scores, dict)
@@ -0,0 +1,18 @@
+"""
+File extractors for processing different document types.
+
+This package provides a framework for extracting content from files
+using different extraction methods.
+"""
+
+from . import base_extractor, extractor_registry, extractor_runner, litellm_extractor
+from .base_extractor import ExtractionInput, ExtractionOutput
+
+__all__ = [
+    "ExtractionInput",
+    "ExtractionOutput",
+    "base_extractor",
+    "extractor_registry",
+    "extractor_runner",
+    "litellm_extractor",
+]
@@ -0,0 +1,72 @@
+import logging
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+from pydantic import BaseModel, Field
+
+from kiln_ai.datamodel.extraction import ExtractorConfig, OutputFormat
+
+logger = logging.getLogger(__name__)
+
+
+class ExtractionInput(BaseModel):
+    path: Path | str = Field(description="The absolute path to the file to extract.")
+    mime_type: str = Field(description="The mime type of the file.")
+
+
+class ExtractionOutput(BaseModel):
+    """
+    The output of an extraction. This is the data that will be saved to the data store.
+    """
+
+    is_passthrough: bool = Field(
+        default=False, description="Whether the extractor returned the file as is."
+    )
+    content_format: OutputFormat = Field(
+        description="The format of the extracted data."
+    )
+    content: str = Field(description="The extracted data.")
+
+
+class BaseExtractor(ABC):
+    """
+    Base class for all extractors.
+
+    Should be subclassed by each extractor.
+    """
+
+    def __init__(self, extractor_config: ExtractorConfig):
+        self.extractor_config = extractor_config
+
+    @abstractmethod
+    async def _extract(self, extraction_input: ExtractionInput) -> ExtractionOutput:
+        pass
+
+    async def extract(
+        self,
+        extraction_input: ExtractionInput,
+    ) -> ExtractionOutput:
+        """
+        Extracts content from a file by delegating to the concrete extractor implementation.
+        """
+        try:
+            if self._should_passthrough(extraction_input.mime_type):
+                return ExtractionOutput(
+                    is_passthrough=True,
+                    content=Path(extraction_input.path).read_text(encoding="utf-8"),
+                    content_format=self.extractor_config.output_format,
+                )
+
+            return await self._extract(
+                extraction_input,
+            )
+        except Exception as e:
+            raise ValueError(f"Error extracting {extraction_input.path}: {e}") from e
+
+    def _should_passthrough(self, mime_type: str) -> bool:
+        return mime_type.lower() in {
+            mt.lower() for mt in self.extractor_config.passthrough_mimetypes
+        }
+
+    def output_format(self) -> OutputFormat:
+        return self.extractor_config.output_format
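
BaseExtractor is the extension point here: subclasses implement _extract and inherit passthrough handling and error wrapping from extract(). A minimal sketch of a custom subclass using only the types shown in this file; the PlainTextExtractor name and its behavior are hypothetical, not part of the package:

    from pathlib import Path

    from kiln_ai.adapters.extractors.base_extractor import (
        BaseExtractor,
        ExtractionInput,
        ExtractionOutput,
    )

    class PlainTextExtractor(BaseExtractor):
        # hypothetical extractor: read the file as UTF-8 text and return it unchanged
        async def _extract(self, extraction_input: ExtractionInput) -> ExtractionOutput:
            text = Path(extraction_input.path).read_text(encoding="utf-8")
            return ExtractionOutput(
                is_passthrough=False,
                content=text,
                content_format=self.extractor_config.output_format,
            )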
@@ -0,0 +1,20 @@
+import base64
+
+
+def to_base64_url(mime_type: str, bytes: bytes) -> str:
+    base64_url = f"data:{mime_type};base64,{base64.b64encode(bytes).decode('utf-8')}"
+    return base64_url
+
+
+def from_base64_url(base64_url: str) -> bytes:
+    if not base64_url.startswith("data:") or "," not in base64_url:
+        raise ValueError("Invalid base64 URL format")
+
+    parts = base64_url.split(",")
+    if len(parts) != 2:
+        raise ValueError("Invalid base64 URL format")
+
+    try:
+        return base64.b64decode(parts[1])
+    except Exception as e:
+        raise ValueError(f"Failed to decode base64 data: {e}")
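
A quick round trip through these helpers, using nothing beyond what the module defines:

    from kiln_ai.adapters.extractors.encoding import from_base64_url, to_base64_url

    data = b"hello world"
    url = to_base64_url("text/plain", data)  # "data:text/plain;base64,aGVsbG8gd29ybGQ="
    assert from_base64_url(url) == data  # decodes back to the original bytes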
@@ -0,0 +1,44 @@
+from kiln_ai.adapters.extractors.base_extractor import BaseExtractor
+from kiln_ai.adapters.extractors.litellm_extractor import LitellmExtractor
+from kiln_ai.adapters.ml_model_list import ModelProviderName
+from kiln_ai.adapters.provider_tools import (
+    core_provider,
+    lite_llm_core_config_for_provider,
+)
+from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+from kiln_ai.utils.filesystem_cache import FilesystemCache
+
+
+def extractor_adapter_from_type(
+    extractor_type: ExtractorType,
+    extractor_config: ExtractorConfig,
+    filesystem_cache: FilesystemCache | None = None,
+) -> BaseExtractor:
+    match extractor_type:
+        case ExtractorType.LITELLM:
+            try:
+                provider_enum = ModelProviderName(extractor_config.model_provider_name)
+            except ValueError:
+                raise ValueError(
+                    f"Unsupported model provider name: {extractor_config.model_provider_name}. "
+                )
+
+            core_provider_name = core_provider(
+                extractor_config.model_name, provider_enum
+            )
+
+            provider_config = lite_llm_core_config_for_provider(core_provider_name)
+            if provider_config is None:
+                raise ValueError(
+                    f"No configuration found for core provider: {core_provider_name.value}. "
+                )
+
+            return LitellmExtractor(
+                extractor_config,
+                provider_config,
+                filesystem_cache,
+            )
+        case _:
+            # type checking will catch missing cases
+            raise_exhaustive_enum_error(extractor_type)
@@ -0,0 +1,112 @@
+import logging
+from collections import defaultdict
+from dataclasses import dataclass
+from pathlib import Path
+from typing import AsyncGenerator, Dict, List, Set
+
+from kiln_ai.adapters.extractors.base_extractor import BaseExtractor, ExtractionInput
+from kiln_ai.adapters.extractors.extractor_registry import extractor_adapter_from_type
+from kiln_ai.datamodel.basemodel import ID_TYPE, KilnAttachmentModel
+from kiln_ai.datamodel.extraction import (
+    Document,
+    Extraction,
+    ExtractionSource,
+    ExtractorConfig,
+)
+from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ExtractorJob:
+    doc: Document
+    extractor_config: ExtractorConfig
+
+
+class ExtractorRunner:
+    def __init__(
+        self,
+        documents: List[Document],
+        extractor_configs: List[ExtractorConfig],
+    ):
+        if len(extractor_configs) == 0:
+            raise ValueError("Extractor runner requires at least one extractor config")
+
+        self.documents = documents
+        self.extractor_configs = extractor_configs
+
+    def collect_jobs(self) -> List[ExtractorJob]:
+        jobs = []
+
+        # we want to avoid re-running the same document for the same extractor config
+        already_extracted: Dict[ID_TYPE, Set[ID_TYPE]] = defaultdict(set)
+        for document in self.documents:
+            for extraction in document.extractions():
+                already_extracted[extraction.extractor_config_id].add(document.id)
+
+        for extractor_config in self.extractor_configs:
+            for document in self.documents:
+                if document.id not in already_extracted[extractor_config.id]:
+                    jobs.append(
+                        ExtractorJob(
+                            doc=document,
+                            extractor_config=extractor_config,
+                        )
+                    )
+
+        return jobs
+
+    async def run(self, concurrency: int = 25) -> AsyncGenerator[Progress, None]:
+        jobs = self.collect_jobs()
+
+        runner = AsyncJobRunner(
+            concurrency=concurrency,
+            jobs=jobs,
+            run_job_fn=self.run_job,
+        )
+        async for progress in runner.run():
+            yield progress
+
+    async def run_job(self, job: ExtractorJob) -> bool:
+        try:
+            extractor = extractor_adapter_from_type(
+                job.extractor_config.extractor_type,
+                job.extractor_config,
+            )
+            if not isinstance(extractor, BaseExtractor):
+                raise ValueError("Not able to create extractor from extractor config")
+
+            if job.doc.path is None:
+                raise ValueError("Document path is not set")
+
+            output = await extractor.extract(
+                extraction_input=ExtractionInput(
+                    path=Path(
+                        job.doc.original_file.attachment.resolve_path(
+                            job.doc.path.parent
+                        )
+                    ),
+                    mime_type=job.doc.original_file.mime_type,
+                )
+            )
+
+            extraction = Extraction(
+                parent=job.doc,
+                extractor_config_id=job.extractor_config.id,
+                output=KilnAttachmentModel.from_data(
+                    data=output.content,
+                    mime_type=output.content_format,
+                ),
+                source=ExtractionSource.PASSTHROUGH
+                if output.is_passthrough
+                else ExtractionSource.PROCESSED,
+            )
+            extraction.save_to_file()
+
+            return True
+        except Exception as e:
+            logger.error(
+                f"Error running extraction job for dataset item {job.doc.id}: {e}"
+            )
+            return False
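
Driving ExtractorRunner is then a matter of handing it existing documents and configs and iterating the progress stream. A sketch assuming the Document and ExtractorConfig lists are already loaded from a Kiln project; the loading step is not shown in this diff:

    from typing import List

    from kiln_ai.adapters.extractors.extractor_runner import ExtractorRunner
    from kiln_ai.datamodel.extraction import Document, ExtractorConfig

    async def extract_all(
        documents: List[Document], extractor_configs: List[ExtractorConfig]
    ) -> None:
        runner = ExtractorRunner(documents=documents, extractor_configs=extractor_configs)
        async for progress in runner.run(concurrency=10):
            print(progress)  # one update per completed or failed job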
@@ -0,0 +1,386 @@
+import asyncio
+import hashlib
+import logging
+from pathlib import Path
+from typing import Any, List
+
+import litellm
+from litellm.types.utils import Choices, ModelResponse
+
+from kiln_ai.adapters.extractors.base_extractor import (
+    BaseExtractor,
+    ExtractionInput,
+    ExtractionOutput,
+)
+from kiln_ai.adapters.extractors.encoding import to_base64_url
+from kiln_ai.adapters.ml_model_list import built_in_models_from_provider
+from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
+from kiln_ai.datamodel.datamodel_enums import ModelProviderName
+from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, Kind
+from kiln_ai.utils.filesystem_cache import FilesystemCache
+from kiln_ai.utils.litellm import get_litellm_provider_info
+from kiln_ai.utils.pdf_utils import split_pdf_into_pages
+
+
+def max_pdf_page_concurrency_for_model(model_name: str) -> int:
+    # we assume each batch takes ~5s to complete (likely more in practice)
+    # lowest rate limit is 150 RPM for Tier 1 accounts for gemini-2.5-pro
+    if model_name == "gemini/gemini-2.5-pro":
+        return 2
+    # other models support at least 500 RPM for lowest tier accounts
+    return 5
+
+
+logger = logging.getLogger(__name__)
+
+MIME_TYPES_SUPPORTED = {
+    Kind.DOCUMENT: [
+        "application/pdf",
+        "text/plain",
+        "text/markdown",  # not officially listed, but works
+        "text/html",
+        "text/md",
+        "text/csv",
+    ],
+    Kind.IMAGE: [
+        "image/png",
+        "image/jpeg",
+        "image/jpg",
+    ],
+    Kind.VIDEO: [
+        "video/mp4",
+        "video/mov",  # the correct type is video/quicktime, but Google lists it as video/mov
+        "video/quicktime",
+    ],
+    Kind.AUDIO: [
+        "audio/wav",
+        "audio/mpeg",  # this is the official MP3 mimetype, audio/mp3 is often used but not correct
+        "audio/ogg",
+    ],
+}
+
+
+def encode_file_litellm_format(path: Path, mime_type: str) -> dict[str, Any]:
+    # There are different formats that LiteLLM supports, the docs are scattered
+    # and incomplete:
+    # - https://docs.litellm.ai/docs/completion/document_understanding#base64
+    # - https://docs.litellm.ai/docs/completion/vision#explicitly-specify-image-type
+
+    # this is the most generic format that seems to work for all / most mime types
+    if mime_type in [
+        "application/pdf",
+        "text/csv",
+        "text/html",
+        "text/markdown",
+        "text/plain",
+    ] or any(mime_type.startswith(m) for m in ["video/", "audio/"]):
+        pdf_bytes = path.read_bytes()
+        return {
+            "type": "file",
+            "file": {
+                "file_data": to_base64_url(mime_type, pdf_bytes),
+            },
+        }
+
+    # image has its own format (but also appears to work with the file format)
+    if mime_type.startswith("image/"):
+        image_bytes = path.read_bytes()
+        return {
+            "type": "image_url",
+            "image_url": {
+                "url": to_base64_url(mime_type, image_bytes),
+            },
+        }
+
+    raise ValueError(f"Unsupported MIME type: {mime_type} for {path}")
+
+
+class LitellmExtractor(BaseExtractor):
+    def __init__(
+        self,
+        extractor_config: ExtractorConfig,
+        litellm_core_config: LiteLlmCoreConfig,
+        filesystem_cache: FilesystemCache | None = None,
+    ):
+        if extractor_config.extractor_type != ExtractorType.LITELLM:
+            raise ValueError(
+                f"LitellmExtractor must be initialized with a litellm extractor_type config. Got {extractor_config.extractor_type}"
+            )
+
+        prompt_document = extractor_config.prompt_document()
+        if prompt_document is None or prompt_document == "":
+            raise ValueError(
+                "properties.prompt_document is required for LitellmExtractor"
+            )
+        prompt_video = extractor_config.prompt_video()
+        if prompt_video is None or prompt_video == "":
+            raise ValueError("properties.prompt_video is required for LitellmExtractor")
+        prompt_audio = extractor_config.prompt_audio()
+        if prompt_audio is None or prompt_audio == "":
+            raise ValueError("properties.prompt_audio is required for LitellmExtractor")
+        prompt_image = extractor_config.prompt_image()
+        if prompt_image is None or prompt_image == "":
+            raise ValueError("properties.prompt_image is required for LitellmExtractor")
+
+        self.filesystem_cache = filesystem_cache
+
+        super().__init__(extractor_config)
+        self.prompt_for_kind = {
+            Kind.DOCUMENT: prompt_document,
+            Kind.VIDEO: prompt_video,
+            Kind.AUDIO: prompt_audio,
+            Kind.IMAGE: prompt_image,
+        }
+
+        self.litellm_core_config = litellm_core_config
+
+    def pdf_page_cache_key(self, pdf_path: Path, page_number: int) -> str:
+        """
+        Generate a cache key for a page of a PDF. The PDF path must be the full path to the PDF file,
+        not the path to the page - since page path is temporary and changes on each run.
+        """
+        if self.extractor_config.id is None:
+            raise ValueError("Extractor config ID is required for PDF page cache key")
+
+        raw_key = f"{pdf_path.resolve()}::{page_number}"
+        digest = hashlib.md5(raw_key.encode("utf-8")).hexdigest()
+        return f"{self.extractor_config.id}_{digest}"
+
+    async def get_page_content_from_cache(
+        self, pdf_path: Path, page_number: int
+    ) -> str | None:
+        if self.filesystem_cache is None:
+            return None
+
+        page_bytes = await self.filesystem_cache.get(
+            self.pdf_page_cache_key(pdf_path, page_number)
+        )
+
+        if page_bytes is not None:
+            logger.debug(f"Cache hit for page {page_number} of {pdf_path}")
+            try:
+                return page_bytes.decode("utf-8")
+            except UnicodeDecodeError:
+                logger.warning(
+                    "Cached bytes for page %s of %s are not valid UTF-8; treating as miss.",
+                    page_number,
+                    pdf_path,
+                    exc_info=True,
+                )
+
+        logger.debug(f"Cache miss for page {page_number} of {pdf_path}")
+        return None
+
+    async def _extract_single_pdf_page(
+        self, pdf_path: Path, page_path: Path, prompt: str, page_number: int
+    ) -> str:
+        try:
+            page_input = ExtractionInput(
+                path=str(page_path), mime_type="application/pdf"
+            )
+            completion_kwargs = self._build_completion_kwargs(prompt, page_input)
+            response = await litellm.acompletion(**completion_kwargs)
+        except Exception as e:
+            raise RuntimeError(
+                f"Error extracting page {page_number} in file {page_path}: {e}"
+            ) from e
+
+        if (
+            not isinstance(response, ModelResponse)
+            or not response.choices
+            or len(response.choices) == 0
+            or not isinstance(response.choices[0], Choices)
+        ):
+            raise RuntimeError(
+                f"Expected ModelResponse with Choices for page {page_number}, got {type(response)}."
+            )
+
+        if response.choices[0].message.content is None:
+            raise ValueError(
+                f"No text returned from LiteLLM when extracting page {page_number}"
+            )
+
+        content = response.choices[0].message.content
+        if not content:
+            raise ValueError(
+                f"No text returned from extraction model when extracting page {page_number} for {page_path}"
+            )
+
+        if self.filesystem_cache is not None:
+            # we don't want to fail the whole extraction just because cache write fails
+            # as that would block the whole flow
+            try:
+                logger.debug(f"Caching page {page_number} of {page_path} in cache")
+                await self.filesystem_cache.set(
+                    self.pdf_page_cache_key(pdf_path, page_number),
+                    content.encode("utf-8"),
+                )
+            except Exception:
+                logger.warning(
+                    "Failed to cache page %s of %s; continuing without cache.",
+                    page_number,
+                    page_path,
+                    exc_info=True,
+                )
+
+        return content
+
+    async def _extract_pdf_page_by_page(self, pdf_path: Path, prompt: str) -> str:
+        async with split_pdf_into_pages(pdf_path) as page_paths:
+            page_outcomes: List[str | Exception | None] = [None] * len(page_paths)
+
+            extract_page_jobs: list = []
+            page_indices_for_jobs: list = []  # Track which page index each job corresponds to
+
+            # we extract from each page individually and then combine the results
+            # this ensures the model stays focused on the current page and does not
+            # start summarizing the later pages
+            for i, page_path in enumerate(page_paths):
+                page_content = await self.get_page_content_from_cache(pdf_path, i)
+                if page_content is not None:
+                    page_outcomes[i] = page_content
+                    continue
+
+                extract_page_jobs.append(
+                    self._extract_single_pdf_page(pdf_path, page_path, prompt, i)
+                )
+                page_indices_for_jobs.append(i)
+
+                if (
+                    len(extract_page_jobs)
+                    >= max_pdf_page_concurrency_for_model(self.litellm_model_slug())
+                    or i == len(page_paths) - 1
+                ):
+                    extraction_results = await asyncio.gather(
+                        *extract_page_jobs, return_exceptions=True
+                    )
+
+                    for batch_i, extraction_result in enumerate(extraction_results):
+                        page_index = page_indices_for_jobs[batch_i]
+                        # we let it continue even if there is an error - the success results will be cached
+                        # and can be reused on the next run
+                        if isinstance(extraction_result, Exception):
+                            page_outcomes[page_index] = extraction_result
+                        elif isinstance(extraction_result, str):
+                            page_outcomes[page_index] = extraction_result
+                        else:
+                            raise ValueError(
+                                f"Unexpected type {type(extraction_result)} for page {page_index}"
+                            )
+                    extract_page_jobs.clear()
+                    page_indices_for_jobs.clear()
+
+            exceptions: list[tuple[int, Exception]] = [
+                (page_index, result)
+                for page_index, result in enumerate(page_outcomes)
+                if isinstance(result, Exception)
+            ]
+            if len(exceptions) > 0:
+                msg = f"Error extracting PDF {pdf_path}: "
+                for page_index, exception in exceptions:
+                    msg += f"Page {page_index}: {exception}\n"
+                raise RuntimeError(msg)
+
+            return "\n\n".join(
+                [outcome for outcome in page_outcomes if isinstance(outcome, str)]
+            )
+
+    def _get_kind_from_mime_type(self, mime_type: str) -> Kind | None:
+        for kind, mime_types in MIME_TYPES_SUPPORTED.items():
+            if mime_type in mime_types:
+                return kind
+        return None
+
+    def _build_completion_kwargs(
+        self, prompt: str, extraction_input: ExtractionInput
+    ) -> dict[str, Any]:
+        completion_kwargs = {
+            "model": self.litellm_model_slug(),
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": prompt},
+                        encode_file_litellm_format(
+                            Path(extraction_input.path), extraction_input.mime_type
+                        ),
+                    ],
+                }
+            ],
+        }
+
+        if self.litellm_core_config.base_url:
+            completion_kwargs["base_url"] = self.litellm_core_config.base_url
+
+        if self.litellm_core_config.default_headers:
+            completion_kwargs["default_headers"] = (
+                self.litellm_core_config.default_headers
+            )
+
+        if self.litellm_core_config.additional_body_options:
+            completion_kwargs.update(self.litellm_core_config.additional_body_options)
+
+        return completion_kwargs
+
+    async def _extract(self, extraction_input: ExtractionInput) -> ExtractionOutput:
+        kind = self._get_kind_from_mime_type(extraction_input.mime_type)
+        if kind is None:
+            raise ValueError(
+                f"Unsupported MIME type: {extraction_input.mime_type} for {extraction_input.path}"
+            )
+
+        prompt = self.prompt_for_kind.get(kind)
+        if prompt is None:
+            raise ValueError(f"No prompt found for kind: {kind}")
+
+        # special handling for PDFs - process each page individually
+        if extraction_input.mime_type == "application/pdf":
+            content = await self._extract_pdf_page_by_page(
+                Path(extraction_input.path), prompt
+            )
+            return ExtractionOutput(
+                is_passthrough=False,
+                content=content,
+                content_format=self.extractor_config.output_format,
+            )
+
+        completion_kwargs = self._build_completion_kwargs(prompt, extraction_input)
+
+        response = await litellm.acompletion(**completion_kwargs)
+
+        if (
+            not isinstance(response, ModelResponse)
+            or not response.choices
+            or len(response.choices) == 0
+            or not isinstance(response.choices[0], Choices)
+        ):
+            raise RuntimeError(
+                f"Expected ModelResponse with Choices, got {type(response)}."
+            )
+
+        if response.choices[0].message.content is None:
+            raise ValueError("No text returned from LiteLLM when extracting document")
+
+        return ExtractionOutput(
+            is_passthrough=False,
+            content=response.choices[0].message.content,
+            content_format=self.extractor_config.output_format,
+        )
+
+    def litellm_model_slug(self) -> str:
+        kiln_model_provider = built_in_models_from_provider(
+            ModelProviderName(self.extractor_config.model_provider_name),
+            self.extractor_config.model_name,
+        )
+
+        if kiln_model_provider is None:
+            raise ValueError(
+                f"Model provider {self.extractor_config.model_provider_name} not found in the list of built-in models"
+            )
+
+        # need to translate into LiteLLM model slug
+        litellm_provider_name = get_litellm_provider_info(
+            kiln_model_provider,
+        )
+
+        return litellm_provider_name.litellm_model_id
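
Since encode_file_litellm_format is a plain module-level helper, the payload the extractor sends per file is easy to inspect in isolation. A small sketch of the message part it produces and how _build_completion_kwargs combines it with a prompt; the sample.pdf path and prompt text are illustrative:

    from pathlib import Path

    from kiln_ai.adapters.extractors.litellm_extractor import encode_file_litellm_format

    part = encode_file_litellm_format(Path("sample.pdf"), "application/pdf")
    # -> {"type": "file", "file": {"file_data": "data:application/pdf;base64,..."}}

    # _build_completion_kwargs wraps a part like this, plus the prompt, into one user message:
    message = {
        "role": "user",
        "content": [{"type": "text", "text": "Transcribe this document."}, part],
    }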