kiln_ai-0.19.0-py3-none-any.whl → kiln_ai-0.21.0-py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +8 -2
- kiln_ai/adapters/adapter_registry.py +43 -208
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +9 -3
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +2 -4
- kiln_ai/adapters/eval/test_g_eval.py +4 -5
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +761 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +205 -47
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +657 -85
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +251 -1
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +254 -8
- kiln_ai/adapters/test_remote_config.py +651 -58
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +39 -34
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +28 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +40 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +4 -4
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +320 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +105 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +246 -0
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +562 -0
- kiln_ai/tools/tool_registry.py +85 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +24 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +131 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.19.0.dist-info/RECORD +0 -115
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/adapters/extractors/litellm_extractor.py (new file)
@@ -0,0 +1,386 @@

import asyncio
import hashlib
import logging
from pathlib import Path
from typing import Any, List

import litellm
from litellm.types.utils import Choices, ModelResponse

from kiln_ai.adapters.extractors.base_extractor import (
    BaseExtractor,
    ExtractionInput,
    ExtractionOutput,
)
from kiln_ai.adapters.extractors.encoding import to_base64_url
from kiln_ai.adapters.ml_model_list import built_in_models_from_provider
from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, Kind
from kiln_ai.utils.filesystem_cache import FilesystemCache
from kiln_ai.utils.litellm import get_litellm_provider_info
from kiln_ai.utils.pdf_utils import split_pdf_into_pages


def max_pdf_page_concurrency_for_model(model_name: str) -> int:
    # we assume each batch takes ~5s to complete (likely more in practice)
    # lowest rate limit is 150 RPM for Tier 1 accounts for gemini-2.5-pro
    if model_name == "gemini/gemini-2.5-pro":
        return 2
    # other models support at least 500 RPM for lowest tier accounts
    return 5


logger = logging.getLogger(__name__)

MIME_TYPES_SUPPORTED = {
    Kind.DOCUMENT: [
        "application/pdf",
        "text/plain",
        "text/markdown",  # not officially listed, but works
        "text/html",
        "text/md",
        "text/csv",
    ],
    Kind.IMAGE: [
        "image/png",
        "image/jpeg",
        "image/jpg",
    ],
    Kind.VIDEO: [
        "video/mp4",
        "video/mov",  # the correct type is video/quicktime, but Google lists it as video/mov
        "video/quicktime",
    ],
    Kind.AUDIO: [
        "audio/wav",
        "audio/mpeg",  # this is the official MP3 mimetype, audio/mp3 is often used but not correct
        "audio/ogg",
    ],
}


def encode_file_litellm_format(path: Path, mime_type: str) -> dict[str, Any]:
    # There are different formats that LiteLLM supports, the docs are scattered
    # and incomplete:
    # - https://docs.litellm.ai/docs/completion/document_understanding#base64
    # - https://docs.litellm.ai/docs/completion/vision#explicitly-specify-image-type

    # this is the most generic format that seems to work for all / most mime types
    if mime_type in [
        "application/pdf",
        "text/csv",
        "text/html",
        "text/markdown",
        "text/plain",
    ] or any(mime_type.startswith(m) for m in ["video/", "audio/"]):
        pdf_bytes = path.read_bytes()
        return {
            "type": "file",
            "file": {
                "file_data": to_base64_url(mime_type, pdf_bytes),
            },
        }

    # image has its own format (but also appears to work with the file format)
    if mime_type.startswith("image/"):
        image_bytes = path.read_bytes()
        return {
            "type": "image_url",
            "image_url": {
                "url": to_base64_url(mime_type, image_bytes),
            },
        }

    raise ValueError(f"Unsupported MIME type: {mime_type} for {path}")


class LitellmExtractor(BaseExtractor):
    def __init__(
        self,
        extractor_config: ExtractorConfig,
        litellm_core_config: LiteLlmCoreConfig,
        filesystem_cache: FilesystemCache | None = None,
    ):
        if extractor_config.extractor_type != ExtractorType.LITELLM:
            raise ValueError(
                f"LitellmExtractor must be initialized with a litellm extractor_type config. Got {extractor_config.extractor_type}"
            )

        prompt_document = extractor_config.prompt_document()
        if prompt_document is None or prompt_document == "":
            raise ValueError(
                "properties.prompt_document is required for LitellmExtractor"
            )
        prompt_video = extractor_config.prompt_video()
        if prompt_video is None or prompt_video == "":
            raise ValueError("properties.prompt_video is required for LitellmExtractor")
        prompt_audio = extractor_config.prompt_audio()
        if prompt_audio is None or prompt_audio == "":
            raise ValueError("properties.prompt_audio is required for LitellmExtractor")
        prompt_image = extractor_config.prompt_image()
        if prompt_image is None or prompt_image == "":
            raise ValueError("properties.prompt_image is required for LitellmExtractor")

        self.filesystem_cache = filesystem_cache

        super().__init__(extractor_config)
        self.prompt_for_kind = {
            Kind.DOCUMENT: prompt_document,
            Kind.VIDEO: prompt_video,
            Kind.AUDIO: prompt_audio,
            Kind.IMAGE: prompt_image,
        }

        self.litellm_core_config = litellm_core_config

    def pdf_page_cache_key(self, pdf_path: Path, page_number: int) -> str:
        """
        Generate a cache key for a page of a PDF. The PDF path must be the full path to the PDF file,
        not the path to the page - since page path is temporary and changes on each run.
        """
        if self.extractor_config.id is None:
            raise ValueError("Extractor config ID is required for PDF page cache key")

        raw_key = f"{pdf_path.resolve()}::{page_number}"
        digest = hashlib.md5(raw_key.encode("utf-8")).hexdigest()
        return f"{self.extractor_config.id}_{digest}"

    async def get_page_content_from_cache(
        self, pdf_path: Path, page_number: int
    ) -> str | None:
        if self.filesystem_cache is None:
            return None

        page_bytes = await self.filesystem_cache.get(
            self.pdf_page_cache_key(pdf_path, page_number)
        )

        if page_bytes is not None:
            logger.debug(f"Cache hit for page {page_number} of {pdf_path}")
            try:
                return page_bytes.decode("utf-8")
            except UnicodeDecodeError:
                logger.warning(
                    "Cached bytes for page %s of %s are not valid UTF-8; treating as miss.",
                    page_number,
                    pdf_path,
                    exc_info=True,
                )

        logger.debug(f"Cache miss for page {page_number} of {pdf_path}")
        return None

    async def _extract_single_pdf_page(
        self, pdf_path: Path, page_path: Path, prompt: str, page_number: int
    ) -> str:
        try:
            page_input = ExtractionInput(
                path=str(page_path), mime_type="application/pdf"
            )
            completion_kwargs = self._build_completion_kwargs(prompt, page_input)
            response = await litellm.acompletion(**completion_kwargs)
        except Exception as e:
            raise RuntimeError(
                f"Error extracting page {page_number} in file {page_path}: {e}"
            ) from e

        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices for page {page_number}, got {type(response)}."
            )

        if response.choices[0].message.content is None:
            raise ValueError(
                f"No text returned from LiteLLM when extracting page {page_number}"
            )

        content = response.choices[0].message.content
        if not content:
            raise ValueError(
                f"No text returned from extraction model when extracting page {page_number} for {page_path}"
            )

        if self.filesystem_cache is not None:
            # we don't want to fail the whole extraction just because cache write fails
            # as that would block the whole flow
            try:
                logger.debug(f"Caching page {page_number} of {page_path} in cache")
                await self.filesystem_cache.set(
                    self.pdf_page_cache_key(pdf_path, page_number),
                    content.encode("utf-8"),
                )
            except Exception:
                logger.warning(
                    "Failed to cache page %s of %s; continuing without cache.",
                    page_number,
                    page_path,
                    exc_info=True,
                )

        return content

    async def _extract_pdf_page_by_page(self, pdf_path: Path, prompt: str) -> str:
        async with split_pdf_into_pages(pdf_path) as page_paths:
            page_outcomes: List[str | Exception | None] = [None] * len(page_paths)

            extract_page_jobs: list = []
            page_indices_for_jobs: list = []  # Track which page index each job corresponds to

            # we extract from each page individually and then combine the results
            # this ensures the model stays focused on the current page and does not
            # start summarizing the later pages
            for i, page_path in enumerate(page_paths):
                page_content = await self.get_page_content_from_cache(pdf_path, i)
                if page_content is not None:
                    page_outcomes[i] = page_content
                    continue

                extract_page_jobs.append(
                    self._extract_single_pdf_page(pdf_path, page_path, prompt, i)
                )
                page_indices_for_jobs.append(i)

                if (
                    len(extract_page_jobs)
                    >= max_pdf_page_concurrency_for_model(self.litellm_model_slug())
                    or i == len(page_paths) - 1
                ):
                    extraction_results = await asyncio.gather(
                        *extract_page_jobs, return_exceptions=True
                    )

                    for batch_i, extraction_result in enumerate(extraction_results):
                        page_index = page_indices_for_jobs[batch_i]
                        # we let it continue even if there is an error - the success results will be cached
                        # and can be reused on the next run
                        if isinstance(extraction_result, Exception):
                            page_outcomes[page_index] = extraction_result
                        elif isinstance(extraction_result, str):
                            page_outcomes[page_index] = extraction_result
                        else:
                            raise ValueError(
                                f"Unexpected type {type(extraction_result)} for page {page_index}"
                            )
                    extract_page_jobs.clear()
                    page_indices_for_jobs.clear()

            exceptions: list[tuple[int, Exception]] = [
                (page_index, result)
                for page_index, result in enumerate(page_outcomes)
                if isinstance(result, Exception)
            ]
            if len(exceptions) > 0:
                msg = f"Error extracting PDF {pdf_path}: "
                for page_index, exception in exceptions:
                    msg += f"Page {page_index}: {exception}\n"
                raise RuntimeError(msg)

            return "\n\n".join(
                [outcome for outcome in page_outcomes if isinstance(outcome, str)]
            )

    def _get_kind_from_mime_type(self, mime_type: str) -> Kind | None:
        for kind, mime_types in MIME_TYPES_SUPPORTED.items():
            if mime_type in mime_types:
                return kind
        return None

    def _build_completion_kwargs(
        self, prompt: str, extraction_input: ExtractionInput
    ) -> dict[str, Any]:
        completion_kwargs = {
            "model": self.litellm_model_slug(),
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        encode_file_litellm_format(
                            Path(extraction_input.path), extraction_input.mime_type
                        ),
                    ],
                }
            ],
        }

        if self.litellm_core_config.base_url:
            completion_kwargs["base_url"] = self.litellm_core_config.base_url

        if self.litellm_core_config.default_headers:
            completion_kwargs["default_headers"] = (
                self.litellm_core_config.default_headers
            )

        if self.litellm_core_config.additional_body_options:
            completion_kwargs.update(self.litellm_core_config.additional_body_options)

        return completion_kwargs

    async def _extract(self, extraction_input: ExtractionInput) -> ExtractionOutput:
        kind = self._get_kind_from_mime_type(extraction_input.mime_type)
        if kind is None:
            raise ValueError(
                f"Unsupported MIME type: {extraction_input.mime_type} for {extraction_input.path}"
            )

        prompt = self.prompt_for_kind.get(kind)
        if prompt is None:
            raise ValueError(f"No prompt found for kind: {kind}")

        # special handling for PDFs - process each page individually
        if extraction_input.mime_type == "application/pdf":
            content = await self._extract_pdf_page_by_page(
                Path(extraction_input.path), prompt
            )
            return ExtractionOutput(
                is_passthrough=False,
                content=content,
                content_format=self.extractor_config.output_format,
            )

        completion_kwargs = self._build_completion_kwargs(prompt, extraction_input)

        response = await litellm.acompletion(**completion_kwargs)

        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices, got {type(response)}."
            )

        if response.choices[0].message.content is None:
            raise ValueError("No text returned from LiteLLM when extracting document")

        return ExtractionOutput(
            is_passthrough=False,
            content=response.choices[0].message.content,
            content_format=self.extractor_config.output_format,
        )

    def litellm_model_slug(self) -> str:
        kiln_model_provider = built_in_models_from_provider(
            ModelProviderName(self.extractor_config.model_provider_name),
            self.extractor_config.model_name,
        )

        if kiln_model_provider is None:
            raise ValueError(
                f"Model provider {self.extractor_config.model_provider_name} not found in the list of built-in models"
            )

        # need to translate into LiteLLM model slug
        litellm_provider_name = get_litellm_provider_info(
            kiln_model_provider,
        )

        return litellm_provider_name.litellm_model_id
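
For orientation, here is a minimal usage sketch of the new extractor (not part of the diff). The ExtractorConfig fields and prompt property names mirror the test fixtures later in this diff; constructing LiteLlmCoreConfig with no arguments and relying on environment credentials is an assumption, and the real constructor may require provider settings.

# Usage sketch only: names follow the tests in this diff; LiteLlmCoreConfig()
# with defaults (API keys taken from the environment) is an assumption.
import asyncio

from kiln_ai.adapters.extractors.base_extractor import ExtractionInput
from kiln_ai.adapters.extractors.litellm_extractor import LitellmExtractor
from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, OutputFormat


async def main() -> None:
    config = ExtractorConfig(
        name="pdf-extractor",
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        extractor_type=ExtractorType.LITELLM,
        output_format=OutputFormat.MARKDOWN,
        properties={
            "prompt_document": "Transcribe this document as markdown.",
            "prompt_image": "Describe this image in detail.",
            "prompt_video": "Describe this video in detail.",
            "prompt_audio": "Transcribe this audio.",
        },
    )
    extractor = LitellmExtractor(
        extractor_config=config,
        litellm_core_config=LiteLlmCoreConfig(),  # assumption: defaults are sufficient
    )
    # extract() is the public entry point defined on BaseExtractor; PDFs are
    # split into pages and extracted page by page with bounded concurrency.
    output = await extractor.extract(
        ExtractionInput(path="example.pdf", mime_type="application/pdf")
    )
    print(output.content)


asyncio.run(main())
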
kiln_ai/adapters/extractors/test_base_extractor.py (new file)
@@ -0,0 +1,244 @@

from typing import Any
from unittest.mock import patch

import pytest

from kiln_ai.adapters.extractors.base_extractor import (
    BaseExtractor,
    ExtractionInput,
    ExtractionOutput,
)
from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, OutputFormat


class MockBaseExtractor(BaseExtractor):
    async def _extract(self, input: ExtractionInput) -> ExtractionOutput:
        return ExtractionOutput(
            is_passthrough=False,
            content="mock concrete extractor output",
            content_format=OutputFormat.MARKDOWN,
        )


@pytest.fixture
def mock_litellm_properties():
    return {
        "prompt_document": "mock prompt for document",
        "prompt_image": "mock prompt for image",
        "prompt_video": "mock prompt for video",
        "prompt_audio": "mock prompt for audio",
    }


@pytest.fixture
def mock_extractor(mock_litellm_properties):
    return MockBaseExtractor(
        ExtractorConfig(
            name="mock",
            model_provider_name="gemini_api",
            model_name="gemini-2.0-flash",
            extractor_type=ExtractorType.LITELLM,
            output_format=OutputFormat.MARKDOWN,
            properties=mock_litellm_properties,
        )
    )


def mock_extractor_with_passthroughs(
    properties: dict[str, Any],
    mimetypes: list[OutputFormat],
    output_format: OutputFormat,
):
    return MockBaseExtractor(
        ExtractorConfig(
            name="mock",
            model_provider_name="gemini_api",
            model_name="gemini-2.0-flash",
            extractor_type=ExtractorType.LITELLM,
            passthrough_mimetypes=mimetypes,
            output_format=output_format,
            properties=properties,
        )
    )


def test_should_passthrough(mock_litellm_properties):
    extractor = mock_extractor_with_passthroughs(
        mock_litellm_properties,
        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
        OutputFormat.TEXT,
    )

    # should passthrough
    assert extractor._should_passthrough("text/plain")
    assert extractor._should_passthrough("text/markdown")

    # should not passthrough
    assert not extractor._should_passthrough("image/png")
    assert not extractor._should_passthrough("application/pdf")
    assert not extractor._should_passthrough("text/html")
    assert not extractor._should_passthrough("image/jpeg")


async def test_extract_passthrough(mock_litellm_properties):
    """
    Tests that when a file's MIME type is configured for passthrough, the extractor skips
    the concrete extraction method and returns the file's contents directly with the
    correct passthrough output format.
    """
    extractor = mock_extractor_with_passthroughs(
        mock_litellm_properties,
        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
        OutputFormat.TEXT,
    )
    with (
        patch.object(
            extractor,
            "_extract",
            return_value=ExtractionOutput(
                is_passthrough=False,
                content="mock concrete extractor output",
                content_format=OutputFormat.TEXT,
            ),
        ) as mock_extract,
        patch(
            "pathlib.Path.read_text",
            return_value=b"test content",
        ),
    ):
        result = await extractor.extract(
            ExtractionInput(
                path="test.txt",
                mime_type="text/plain",
            )
        )

        # Verify _extract was not called
        mock_extract.assert_not_called()

        # Verify correct passthrough result
        assert result.is_passthrough
        assert result.content == "test content"
        assert result.content_format == OutputFormat.TEXT


@pytest.mark.parametrize(
    "output_format",
    [
        "text/plain",
        "text/markdown",
    ],
)
async def test_extract_passthrough_output_format(
    mock_litellm_properties, output_format
):
    extractor = mock_extractor_with_passthroughs(
        mock_litellm_properties,
        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
        output_format,
    )
    with (
        patch.object(
            extractor,
            "_extract",
            return_value=ExtractionOutput(
                is_passthrough=False,
                content="mock concrete extractor output",
                content_format=output_format,
            ),
        ) as mock_extract,
        patch(
            "pathlib.Path.read_text",
            return_value="test content",
        ),
    ):
        result = await extractor.extract(
            ExtractionInput(
                path="test.txt",
                mime_type="text/plain",
            )
        )

        # Verify _extract was not called
        mock_extract.assert_not_called()

        # Verify correct passthrough result
        assert result.is_passthrough
        assert result.content == "test content"
        assert result.content_format == output_format


@pytest.mark.parametrize(
    "path, mime_type, output_format",
    [
        ("test.mp3", "audio/mpeg", OutputFormat.TEXT),
        ("test.png", "image/png", OutputFormat.TEXT),
        ("test.pdf", "application/pdf", OutputFormat.TEXT),
        ("test.txt", "text/plain", OutputFormat.MARKDOWN),
        ("test.txt", "text/markdown", OutputFormat.MARKDOWN),
        ("test.html", "text/html", OutputFormat.MARKDOWN),
    ],
)
async def test_extract_non_passthrough(
    mock_extractor, path: str, mime_type: str, output_format: OutputFormat
):
    with (
        patch.object(
            mock_extractor,
            "_extract",
            return_value=ExtractionOutput(
                is_passthrough=False,
                content="mock concrete extractor output",
                content_format=output_format,
            ),
        ) as mock_extract,
    ):
        # first we call the base class extract method
        result = await mock_extractor.extract(
            ExtractionInput(
                path=path,
                mime_type=mime_type,
            )
        )

        # then we call the subclass _extract method and add validated mime_type
        mock_extract.assert_called_once_with(
            ExtractionInput(
                path=path,
                mime_type=mime_type,
            )
        )

        assert not result.is_passthrough
        assert result.content == "mock concrete extractor output"
        assert result.content_format == output_format


async def test_default_output_format(mock_litellm_properties):
    config = ExtractorConfig(
        name="mock",
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        extractor_type=ExtractorType.LITELLM,
        properties=mock_litellm_properties,
    )
    assert config.output_format == OutputFormat.MARKDOWN


async def test_extract_failure_from_concrete_extractor(mock_extractor):
    with patch.object(
        mock_extractor,
        "_extract",
        side_effect=Exception("error from concrete extractor"),
    ):
        with pytest.raises(ValueError, match="error from concrete extractor"):
            await mock_extractor.extract(
                ExtractionInput(
                    path="test.txt",
                    mime_type="text/plain",
                )
            )


async def test_output_format(mock_extractor):
    assert mock_extractor.output_format() == OutputFormat.MARKDOWN
kiln_ai/adapters/extractors/test_encoding.py (new file)
@@ -0,0 +1,54 @@

from pathlib import Path

import pytest

from conftest import MockFileFactoryMimeType
from kiln_ai.adapters.extractors.encoding import from_base64_url, to_base64_url


async def test_to_base64_url(mock_file_factory):
    mock_file = mock_file_factory(MockFileFactoryMimeType.JPEG)

    byte_data = Path(mock_file).read_bytes()

    # encode the byte data
    base64_url = to_base64_url("image/jpeg", byte_data)
    assert base64_url.startswith("data:image/jpeg;base64,")

    # decode the base64 url
    assert from_base64_url(base64_url) == byte_data


def test_from_base64_url_invalid_format_no_data_prefix():
    """Test that from_base64_url raises ValueError when input doesn't start with 'data:'"""
    with pytest.raises(ValueError, match="Invalid base64 URL format"):
        from_base64_url("not-a-data-url")


def test_from_base64_url_invalid_format_no_comma():
    """Test that from_base64_url raises ValueError when input doesn't contain a comma"""
    with pytest.raises(ValueError, match="Invalid base64 URL format"):
        from_base64_url("data:image/jpeg;base64")


def test_from_base64_url_invalid_parts():
    """Test that from_base64_url raises ValueError when splitting by comma doesn't result in exactly 2 parts"""
    with pytest.raises(ValueError, match="Invalid base64 URL format"):
        from_base64_url("data:image/jpeg;base64,part1,part2")


def test_from_base64_url_base64_decode_failure():
    """Test that from_base64_url raises ValueError when base64 decoding fails"""
    with pytest.raises(ValueError, match="Failed to decode base64 data"):
        from_base64_url("data:image/jpeg;base64,invalid-base64-data!")


def test_from_base64_url_valid_format():
    """Test that from_base64_url works with valid base64 URL format"""
    # Create a simple valid base64 URL
    test_data = b"Hello, World!"
    base64_encoded = "SGVsbG8sIFdvcmxkIQ=="
    base64_url = f"data:text/plain;base64,{base64_encoded}"

    result = from_base64_url(base64_url)
    assert result == test_data
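
The encoding helpers exercised above live in kiln_ai/adapters/extractors/encoding.py (+20 lines in this release), which is not expanded in this diff view. A minimal sketch consistent with the behaviour these tests pin down (a required "data:" prefix, exactly one comma, and distinct ValueError messages for format and decode failures) could look like the following; it is an illustration, not the package's actual implementation.

# Illustrative sketch of data-URL helpers matching the tested contract above;
# the real encoding.py in the package may differ.
import base64


def to_base64_url(mime_type: str, data: bytes) -> str:
    # Encode raw bytes as a data URL: data:<mime>;base64,<payload>
    payload = base64.b64encode(data).decode("ascii")
    return f"data:{mime_type};base64,{payload}"


def from_base64_url(base64_url: str) -> bytes:
    # Reject anything that is not a well-formed data URL with a single comma
    if not base64_url.startswith("data:") or "," not in base64_url:
        raise ValueError(f"Invalid base64 URL format: {base64_url}")
    parts = base64_url.split(",")
    if len(parts) != 2:
        raise ValueError(f"Invalid base64 URL format: {base64_url}")
    try:
        return base64.b64decode(parts[1], validate=True)
    except Exception as e:
        raise ValueError(f"Failed to decode base64 data: {e}") from e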