kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (133)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +494 -0
  38. kiln_ai/adapters/ml_model_list.py +876 -18
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
  41. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  42. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  43. kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
  44. kiln_ai/adapters/ollama_tools.py +69 -12
  45. kiln_ai/adapters/provider_tools.py +190 -46
  46. kiln_ai/adapters/rag/deduplication.py +49 -0
  47. kiln_ai/adapters/rag/progress.py +252 -0
  48. kiln_ai/adapters/rag/rag_runners.py +844 -0
  49. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  50. kiln_ai/adapters/rag/test_progress.py +785 -0
  51. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  52. kiln_ai/adapters/remote_config.py +80 -8
  53. kiln_ai/adapters/test_adapter_registry.py +579 -86
  54. kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
  55. kiln_ai/adapters/test_ml_model_list.py +202 -0
  56. kiln_ai/adapters/test_ollama_tools.py +340 -1
  57. kiln_ai/adapters/test_prompt_builders.py +1 -1
  58. kiln_ai/adapters/test_provider_tools.py +199 -8
  59. kiln_ai/adapters/test_remote_config.py +551 -56
  60. kiln_ai/adapters/vector_store/__init__.py +1 -0
  61. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  62. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  63. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  64. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  65. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  66. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  67. kiln_ai/datamodel/__init__.py +16 -13
  68. kiln_ai/datamodel/basemodel.py +201 -4
  69. kiln_ai/datamodel/chunk.py +158 -0
  70. kiln_ai/datamodel/datamodel_enums.py +27 -0
  71. kiln_ai/datamodel/embedding.py +64 -0
  72. kiln_ai/datamodel/external_tool_server.py +206 -54
  73. kiln_ai/datamodel/extraction.py +317 -0
  74. kiln_ai/datamodel/project.py +33 -1
  75. kiln_ai/datamodel/rag.py +79 -0
  76. kiln_ai/datamodel/task.py +5 -0
  77. kiln_ai/datamodel/task_output.py +41 -11
  78. kiln_ai/datamodel/test_attachment.py +649 -0
  79. kiln_ai/datamodel/test_basemodel.py +270 -14
  80. kiln_ai/datamodel/test_chunk_models.py +317 -0
  81. kiln_ai/datamodel/test_dataset_split.py +1 -1
  82. kiln_ai/datamodel/test_datasource.py +50 -0
  83. kiln_ai/datamodel/test_embedding_models.py +448 -0
  84. kiln_ai/datamodel/test_eval_model.py +6 -6
  85. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  86. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  87. kiln_ai/datamodel/test_extraction_model.py +501 -0
  88. kiln_ai/datamodel/test_rag.py +641 -0
  89. kiln_ai/datamodel/test_task.py +35 -1
  90. kiln_ai/datamodel/test_tool_id.py +187 -1
  91. kiln_ai/datamodel/test_vector_store.py +320 -0
  92. kiln_ai/datamodel/tool_id.py +58 -0
  93. kiln_ai/datamodel/vector_store.py +141 -0
  94. kiln_ai/tools/base_tool.py +12 -3
  95. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  96. kiln_ai/tools/kiln_task_tool.py +158 -0
  97. kiln_ai/tools/mcp_server_tool.py +2 -2
  98. kiln_ai/tools/mcp_session_manager.py +51 -22
  99. kiln_ai/tools/rag_tools.py +164 -0
  100. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  101. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  102. kiln_ai/tools/test_mcp_session_manager.py +187 -227
  103. kiln_ai/tools/test_rag_tools.py +929 -0
  104. kiln_ai/tools/test_tool_registry.py +290 -7
  105. kiln_ai/tools/tool_registry.py +69 -16
  106. kiln_ai/utils/__init__.py +3 -0
  107. kiln_ai/utils/async_job_runner.py +62 -17
  108. kiln_ai/utils/config.py +2 -2
  109. kiln_ai/utils/env.py +15 -0
  110. kiln_ai/utils/filesystem.py +14 -0
  111. kiln_ai/utils/filesystem_cache.py +60 -0
  112. kiln_ai/utils/litellm.py +94 -0
  113. kiln_ai/utils/lock.py +100 -0
  114. kiln_ai/utils/mime_type.py +38 -0
  115. kiln_ai/utils/open_ai_types.py +19 -2
  116. kiln_ai/utils/pdf_utils.py +59 -0
  117. kiln_ai/utils/test_async_job_runner.py +151 -35
  118. kiln_ai/utils/test_env.py +142 -0
  119. kiln_ai/utils/test_filesystem_cache.py +316 -0
  120. kiln_ai/utils/test_litellm.py +206 -0
  121. kiln_ai/utils/test_lock.py +185 -0
  122. kiln_ai/utils/test_mime_type.py +66 -0
  123. kiln_ai/utils/test_open_ai_types.py +88 -12
  124. kiln_ai/utils/test_pdf_utils.py +86 -0
  125. kiln_ai/utils/test_uuid.py +111 -0
  126. kiln_ai/utils/test_validation.py +524 -0
  127. kiln_ai/utils/uuid.py +9 -0
  128. kiln_ai/utils/validation.py +90 -0
  129. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
  130. kiln_ai-0.22.0.dist-info/RECORD +213 -0
  131. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  132. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  133. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,406 @@
1
+ import asyncio
2
+ import hashlib
3
+ import logging
4
+ from functools import cached_property
5
+ from pathlib import Path
6
+ from typing import Any, List
7
+
8
+ import litellm
9
+ from litellm.types.utils import Choices, ModelResponse
10
+
11
+ from kiln_ai.adapters.extractors.base_extractor import (
12
+ BaseExtractor,
13
+ ExtractionInput,
14
+ ExtractionOutput,
15
+ )
16
+ from kiln_ai.adapters.extractors.encoding import to_base64_url
17
+ from kiln_ai.adapters.ml_model_list import (
18
+ KilnModelProvider,
19
+ built_in_models_from_provider,
20
+ )
21
+ from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
22
+ from kiln_ai.datamodel.datamodel_enums import ModelProviderName
23
+ from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, Kind
24
+ from kiln_ai.utils.filesystem_cache import FilesystemCache
25
+ from kiln_ai.utils.litellm import get_litellm_provider_info
26
+ from kiln_ai.utils.pdf_utils import convert_pdf_to_images, split_pdf_into_pages
27
+
28
logger = logging.getLogger(__name__)

# MIME types accepted by LitellmExtractor, grouped by content kind. The kind
# selects which prompt (document / image / video / audio) is sent to the model.
MIME_TYPES_SUPPORTED = {
    Kind.DOCUMENT: [
        "application/pdf",
        "text/plain",
        # not officially listed by providers, but works in practice
        "text/markdown",
        "text/html",
        "text/md",
        "text/csv",
    ],
    Kind.IMAGE: [
        "image/png",
        "image/jpeg",
        "image/jpg",
    ],
    Kind.VIDEO: [
        "video/mp4",
        # the correct type is video/quicktime, but Google lists it as video/mov
        "video/mov",
        "video/quicktime",
    ],
    Kind.AUDIO: [
        "audio/wav",
        # official MP3 MIME type; audio/mp3 is common but not correct
        "audio/mpeg",
        "audio/ogg",
    ],
}
55
+
56
+
57
def encode_file_litellm_format(path: Path, mime_type: str) -> dict[str, Any]:
    """Encode a local file as a LiteLLM message content part.

    The file's bytes are read and embedded as a base64 data URL.

    Args:
        path: File to read.
        mime_type: MIME type of the file; selects the content-part shape.

    Returns:
        A ``{"type": "file", ...}`` part for PDF, text, video and audio files,
        or an ``{"type": "image_url", ...}`` part for images.

    Raises:
        ValueError: If the MIME type is not handled.
    """
    # There are different formats that LiteLLM supports; the docs are scattered
    # and incomplete:
    # - https://docs.litellm.ai/docs/completion/document_understanding#base64
    # - https://docs.litellm.ai/docs/completion/vision#explicitly-specify-image-type

    # "text/md" is accepted as a document type by MIME_TYPES_SUPPORTED but is not
    # a real MIME type, and previously fell through to the ValueError below.
    # Send it as "text/markdown", which is known to work.
    if mime_type == "text/md":
        mime_type = "text/markdown"

    generic_file_mime_types = (
        "application/pdf",
        "text/csv",
        "text/html",
        "text/markdown",
        "text/plain",
    )

    # The generic "file" format seems to work for all / most MIME types.
    if mime_type in generic_file_mime_types or mime_type.startswith(
        ("video/", "audio/")
    ):
        file_bytes = path.read_bytes()
        return {
            "type": "file",
            "file": {
                "file_data": to_base64_url(mime_type, file_bytes),
            },
        }

    # Images have their own format (but also appear to work with the file format).
    if mime_type.startswith("image/"):
        image_bytes = path.read_bytes()
        return {
            "type": "image_url",
            "image_url": {
                "url": to_base64_url(mime_type, image_bytes),
            },
        }

    raise ValueError(f"Unsupported MIME type: {mime_type} for {path}")
90
+
91
+
92
class LitellmExtractor(BaseExtractor):
    """Extractor that sends files to a multimodal model through LiteLLM.

    The prompt used for a file is chosen by its kind (document, image, video,
    audio), derived from its MIME type via ``MIME_TYPES_SUPPORTED``. PDFs are
    split into single pages. Each page is extracted on its own (and optionally
    cached), and the page texts are joined with blank lines.
    """

    def __init__(
        self,
        extractor_config: ExtractorConfig,
        litellm_core_config: LiteLlmCoreConfig,
        filesystem_cache: FilesystemCache | None = None,
        default_max_parallel_requests: int = 5,
    ):
        """Validate the config and store the per-kind prompts.

        Args:
            extractor_config: Must have ``extractor_type == ExtractorType.LITELLM``
                and non-empty document, video, audio and image prompts.
            litellm_core_config: Base URL, default headers and extra body options
                forwarded to ``litellm.acompletion``.
            filesystem_cache: Optional cache for per-page PDF extraction results.
            default_max_parallel_requests: PDF page concurrency used when the
                model provider does not define ``max_parallel_requests``.

        Raises:
            ValueError: On a non-LiteLLM extractor type or a missing/empty prompt.
        """
        if extractor_config.extractor_type != ExtractorType.LITELLM:
            raise ValueError(
                f"LitellmExtractor must be initialized with a litellm extractor_type config. Got {extractor_config.extractor_type}"
            )

        prompt_document = extractor_config.prompt_document()
        if prompt_document is None or prompt_document == "":
            raise ValueError(
                "properties.prompt_document is required for LitellmExtractor"
            )
        prompt_video = extractor_config.prompt_video()
        if prompt_video is None or prompt_video == "":
            raise ValueError("properties.prompt_video is required for LitellmExtractor")
        prompt_audio = extractor_config.prompt_audio()
        if prompt_audio is None or prompt_audio == "":
            raise ValueError("properties.prompt_audio is required for LitellmExtractor")
        prompt_image = extractor_config.prompt_image()
        if prompt_image is None or prompt_image == "":
            raise ValueError("properties.prompt_image is required for LitellmExtractor")

        self.filesystem_cache = filesystem_cache

        super().__init__(extractor_config)
        self.prompt_for_kind = {
            Kind.DOCUMENT: prompt_document,
            Kind.VIDEO: prompt_video,
            Kind.AUDIO: prompt_audio,
            Kind.IMAGE: prompt_image,
        }

        self.litellm_core_config = litellm_core_config
        self.default_max_parallel_requests = default_max_parallel_requests

    def pdf_page_cache_key(self, pdf_path: Path, page_number: int) -> str:
        """
        Generate a cache key for a page of a PDF. The PDF path must be the full path to the PDF file,
        not the path to the page - since page path is temporary and changes on each run.

        The key is ``<extractor config id>_<md5 hex of "<resolved pdf path>::<page number>">``.

        Raises:
            ValueError: If the extractor config has no ID.
        """
        if self.extractor_config.id is None:
            raise ValueError("Extractor config ID is required for PDF page cache key")

        raw_key = f"{pdf_path.resolve()}::{page_number}"
        # md5 is only used to derive a compact key, not for security; flagging it
        # keeps this working on FIPS-restricted Python builds.
        digest = hashlib.md5(
            raw_key.encode("utf-8"), usedforsecurity=False
        ).hexdigest()
        return f"{self.extractor_config.id}_{digest}"

    async def get_page_content_from_cache(
        self, pdf_path: Path, page_number: int
    ) -> str | None:
        """Return the cached text of a PDF page, or None on a miss.

        Returns None when no cache is configured, when the key is absent, or
        when the cached bytes are not valid UTF-8 (logged as a warning).
        """
        if self.filesystem_cache is None:
            return None

        page_bytes = await self.filesystem_cache.get(
            self.pdf_page_cache_key(pdf_path, page_number)
        )

        if page_bytes is not None:
            logger.debug(f"Cache hit for page {page_number} of {pdf_path}")
            try:
                return page_bytes.decode("utf-8")
            except UnicodeDecodeError:
                logger.warning(
                    "Cached bytes for page %s of %s are not valid UTF-8; treating as miss.",
                    page_number,
                    pdf_path,
                    exc_info=True,
                )

        logger.debug(f"Cache miss for page {page_number} of {pdf_path}")
        return None

    async def convert_pdf_page_to_image_input(
        self, page_path: Path, page_number: int
    ) -> ExtractionInput:
        """Render a single-page PDF to a PNG next to it and wrap it as an image input.

        Raises:
            ValueError: If the conversion does not yield exactly one image.
        """
        image_paths = await convert_pdf_to_images(page_path, page_path.parent)
        if len(image_paths) != 1:
            raise ValueError(
                f"Expected 1 image, got {len(image_paths)} for page {page_number} in {page_path}"
            )
        image_path = image_paths[0]
        page_input = ExtractionInput(path=str(image_path), mime_type="image/png")
        return page_input

    async def _extract_single_pdf_page(
        self,
        pdf_path: Path,
        page_path: Path,
        prompt: str,
        page_number: int,
    ) -> str:
        """Extract the text of one PDF page and cache it on success.

        Args:
            pdf_path: The original PDF; used for the cache key.
            page_path: Temporary single-page PDF to send to the model.
            prompt: Document prompt for the model.
            page_number: Zero-based page index.

        Raises:
            RuntimeError: If the request fails or the response shape is unexpected.
            ValueError: If the model returned no text.
        """
        try:
            # some providers only accept PDFs rendered as images
            if self.model_provider.multimodal_requires_pdf_as_image:
                page_input = await self.convert_pdf_page_to_image_input(
                    page_path, page_number
                )
            else:
                page_input = ExtractionInput(
                    path=str(page_path), mime_type="application/pdf"
                )

            completion_kwargs = self._build_completion_kwargs(prompt, page_input)
            response = await litellm.acompletion(**completion_kwargs)
        except Exception as e:
            raise RuntimeError(
                f"Error extracting page {page_number} in file {page_path}: {e}"
            ) from e

        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices for page {page_number}, got {type(response)}."
            )

        if response.choices[0].message.content is None:
            raise ValueError(
                f"No text returned from LiteLLM when extracting page {page_number}"
            )

        content = response.choices[0].message.content
        if self.filesystem_cache is not None:
            # we don't want to fail the whole extraction just because cache write fails
            # as that would block the whole flow
            try:
                logger.debug(f"Caching page {page_number} of {page_path} in cache")
                await self.filesystem_cache.set(
                    self.pdf_page_cache_key(pdf_path, page_number),
                    content.encode("utf-8"),
                )
            except Exception:
                logger.warning(
                    "Failed to cache page %s of %s; continuing without cache.",
                    page_number,
                    page_path,
                    exc_info=True,
                )

        return content

    async def _run_page_batch(
        self,
        jobs: list,
        page_indices: list[int],
        page_outcomes: List[str | Exception | None],
    ) -> None:
        """Await one batch of page jobs concurrently and record each outcome by page index."""
        results = await asyncio.gather(*jobs, return_exceptions=True)
        for page_index, result in zip(page_indices, results):
            # we let it continue even if there is an error - the success results
            # are cached and can be reused on the next run
            if isinstance(result, (Exception, str)):
                page_outcomes[page_index] = result
            else:
                raise ValueError(
                    f"Unexpected type {type(result)} for page {page_index}"
                )

    async def _extract_pdf_page_by_page(self, pdf_path: Path, prompt: str) -> str:
        """Extract a PDF page by page and join the page texts with blank lines.

        Cached pages are reused. Uncached pages are extracted in concurrent
        batches of at most ``max_parallel_requests_for_model``.

        Raises:
            RuntimeError: If any page failed. The message lists every failed page;
                successful pages are still cached for the next run.
        """
        async with split_pdf_into_pages(pdf_path) as page_paths:
            page_outcomes: List[str | Exception | None] = [None] * len(page_paths)

            pending_jobs: list = []
            pending_page_indices: list[int] = []  # page index for each pending job

            # we extract from each page individually and then combine the results
            # this ensures the model stays focused on the current page and does not
            # start summarizing the later pages
            for i, page_path in enumerate(page_paths):
                page_content = await self.get_page_content_from_cache(pdf_path, i)
                if page_content is not None:
                    page_outcomes[i] = page_content
                    continue

                pending_jobs.append(
                    self._extract_single_pdf_page(
                        pdf_path, page_path, prompt, page_number=i
                    )
                )
                pending_page_indices.append(i)

                if len(pending_jobs) >= self.max_parallel_requests_for_model:
                    await self._run_page_batch(
                        pending_jobs, pending_page_indices, page_outcomes
                    )
                    pending_jobs.clear()
                    pending_page_indices.clear()

            # Flush the final partial batch after the loop. If the trailing pages
            # were cache hits, the loop `continue`s past the batch check above, and
            # queued jobs would otherwise never be awaited (silently dropping pages).
            if pending_jobs:
                await self._run_page_batch(
                    pending_jobs, pending_page_indices, page_outcomes
                )

        exceptions: list[tuple[int, Exception]] = [
            (page_index, result)
            for page_index, result in enumerate(page_outcomes)
            if isinstance(result, Exception)
        ]
        if len(exceptions) > 0:
            msg = f"Error extracting PDF {pdf_path}: "
            for page_index, exception in exceptions:
                msg += f"Page {page_index}: {exception}\n"
            raise RuntimeError(msg)

        return "\n\n".join(
            [outcome for outcome in page_outcomes if isinstance(outcome, str)]
        )

    def _get_kind_from_mime_type(self, mime_type: str) -> Kind | None:
        """Return the content kind listing ``mime_type`` in MIME_TYPES_SUPPORTED, else None."""
        for kind, mime_types in MIME_TYPES_SUPPORTED.items():
            if mime_type in mime_types:
                return kind
        return None

    def _build_completion_kwargs(
        self, prompt: str, extraction_input: ExtractionInput
    ) -> dict[str, Any]:
        """Build ``litellm.acompletion`` kwargs: a user message with the prompt and the encoded file.

        Adds ``base_url`` and ``default_headers`` and merges ``additional_body_options``
        from the core config when they are set.
        """
        completion_kwargs = {
            "model": self.litellm_model_slug,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        encode_file_litellm_format(
                            Path(extraction_input.path), extraction_input.mime_type
                        ),
                    ],
                }
            ],
        }

        if self.litellm_core_config.base_url:
            completion_kwargs["base_url"] = self.litellm_core_config.base_url

        if self.litellm_core_config.default_headers:
            completion_kwargs["default_headers"] = (
                self.litellm_core_config.default_headers
            )

        if self.litellm_core_config.additional_body_options:
            completion_kwargs.update(self.litellm_core_config.additional_body_options)

        return completion_kwargs

    async def _extract(self, extraction_input: ExtractionInput) -> ExtractionOutput:
        """Extract text from one file using the prompt for its MIME kind.

        PDFs are handled page by page; every other supported type is sent in a
        single request.

        Raises:
            ValueError: For unsupported MIME types, a missing prompt, or an empty response.
            RuntimeError: If the response shape is unexpected.
        """
        kind = self._get_kind_from_mime_type(extraction_input.mime_type)
        if kind is None:
            raise ValueError(
                f"Unsupported MIME type: {extraction_input.mime_type} for {extraction_input.path}"
            )

        prompt = self.prompt_for_kind.get(kind)
        if prompt is None:
            raise ValueError(f"No prompt found for kind: {kind}")

        # special handling for PDFs - process each page individually
        if extraction_input.mime_type == "application/pdf":
            content = await self._extract_pdf_page_by_page(
                Path(extraction_input.path), prompt
            )
            return ExtractionOutput(
                is_passthrough=False,
                content=content,
                content_format=self.extractor_config.output_format,
            )

        completion_kwargs = self._build_completion_kwargs(prompt, extraction_input)

        response = await litellm.acompletion(**completion_kwargs)

        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices, got {type(response)}."
            )

        if response.choices[0].message.content is None:
            raise ValueError("No text returned from LiteLLM when extracting document")

        return ExtractionOutput(
            is_passthrough=False,
            content=response.choices[0].message.content,
            content_format=self.extractor_config.output_format,
        )

    @cached_property
    def model_provider(self) -> KilnModelProvider:
        """Built-in provider entry for the configured provider/model pair.

        Raises:
            ValueError: If the pair is not in the built-in model list.
        """
        kiln_model_provider = built_in_models_from_provider(
            ModelProviderName(self.extractor_config.model_provider_name),
            self.extractor_config.model_name,
        )
        if kiln_model_provider is None:
            raise ValueError(
                f"Model provider {self.extractor_config.model_provider_name} not found in the list of built-in models"
            )
        return kiln_model_provider

    @cached_property
    def max_parallel_requests_for_model(self) -> int:
        """Provider's ``max_parallel_requests``, or the constructor default when unset."""
        value = self.model_provider.max_parallel_requests
        return value if value is not None else self.default_max_parallel_requests

    @cached_property
    def litellm_model_slug(self) -> str:
        """LiteLLM model id for the configured provider/model."""
        litellm_provider_name = get_litellm_provider_info(
            self.model_provider,
        )
        return litellm_provider_name.litellm_model_id
@@ -0,0 +1,244 @@
1
+ from typing import Any
2
+ from unittest.mock import patch
3
+
4
+ import pytest
5
+
6
+ from kiln_ai.adapters.extractors.base_extractor import (
7
+ BaseExtractor,
8
+ ExtractionInput,
9
+ ExtractionOutput,
10
+ )
11
+ from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, OutputFormat
12
+
13
+
14
class MockBaseExtractor(BaseExtractor):
    """Minimal concrete extractor whose ``_extract`` always returns a fixed markdown payload."""

    async def _extract(self, input: ExtractionInput) -> ExtractionOutput:
        fixed_output = ExtractionOutput(
            content="mock concrete extractor output",
            content_format=OutputFormat.MARKDOWN,
            is_passthrough=False,
        )
        return fixed_output
21
+
22
+
23
@pytest.fixture
def mock_litellm_properties():
    """Extractor properties with one prompt per content kind."""
    return dict(
        prompt_document="mock prompt for document",
        prompt_image="mock prompt for image",
        prompt_video="mock prompt for video",
        prompt_audio="mock prompt for audio",
    )
31
+
32
+
33
@pytest.fixture
def mock_extractor(mock_litellm_properties):
    """A MockBaseExtractor configured for Gemini with markdown output."""
    config = ExtractorConfig(
        name="mock",
        extractor_type=ExtractorType.LITELLM,
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        output_format=OutputFormat.MARKDOWN,
        properties=mock_litellm_properties,
    )
    return MockBaseExtractor(config)
45
+
46
+
47
def mock_extractor_with_passthroughs(
    properties: dict[str, Any],
    mimetypes: list[OutputFormat],
    output_format: OutputFormat,
):
    """Build a MockBaseExtractor that passes the given MIME types straight through."""
    config = ExtractorConfig(
        name="mock",
        extractor_type=ExtractorType.LITELLM,
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        properties=properties,
        output_format=output_format,
        passthrough_mimetypes=mimetypes,
    )
    return MockBaseExtractor(config)
63
+
64
+
65
def test_should_passthrough(mock_litellm_properties):
    """Only the configured text MIME types are passed through."""
    extractor = mock_extractor_with_passthroughs(
        mock_litellm_properties,
        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
        OutputFormat.TEXT,
    )

    for passthrough_type in ("text/plain", "text/markdown"):
        assert extractor._should_passthrough(passthrough_type)

    for extracted_type in ("image/png", "application/pdf", "text/html", "image/jpeg"):
        assert not extractor._should_passthrough(extracted_type)
81
+
82
+
83
async def test_extract_passthrough(mock_litellm_properties):
    """
    Tests that when a file's MIME type is configured for passthrough, the extractor skips
    the concrete extraction method and returns the file's contents directly with the
    correct passthrough output format.
    """
    extractor = mock_extractor_with_passthroughs(
        mock_litellm_properties,
        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
        OutputFormat.TEXT,
    )
    with (
        patch.object(
            extractor,
            "_extract",
            return_value=ExtractionOutput(
                is_passthrough=False,
                content="mock concrete extractor output",
                content_format=OutputFormat.TEXT,
            ),
        ) as mock_extract,
        # Path.read_text returns str; mocking it with bytes only passed through
        # implicit coercion and did not reflect the real API.
        patch(
            "pathlib.Path.read_text",
            return_value="test content",
        ),
    ):
        result = await extractor.extract(
            ExtractionInput(
                path="test.txt",
                mime_type="text/plain",
            )
        )

        # Verify _extract was not called
        mock_extract.assert_not_called()

        # Verify correct passthrough result
        assert result.is_passthrough
        assert result.content == "test content"
        assert result.content_format == OutputFormat.TEXT
123
+
124
+
125
@pytest.mark.parametrize(
    "output_format",
    [
        "text/plain",
        "text/markdown",
    ],
)
async def test_extract_passthrough_output_format(
    mock_litellm_properties, output_format
):
    """Passthrough output is tagged with the configured output format."""
    extractor = mock_extractor_with_passthroughs(
        mock_litellm_properties,
        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
        output_format,
    )
    unused_output = ExtractionOutput(
        is_passthrough=False,
        content="mock concrete extractor output",
        content_format=output_format,
    )
    source_file = ExtractionInput(path="test.txt", mime_type="text/plain")

    with (
        patch.object(extractor, "_extract", return_value=unused_output) as mock_extract,
        patch("pathlib.Path.read_text", return_value="test content"),
    ):
        result = await extractor.extract(source_file)

        # the concrete extractor must be bypassed entirely
        mock_extract.assert_not_called()

        assert result.is_passthrough
        assert result.content == "test content"
        assert result.content_format == output_format
169
+
170
+
171
@pytest.mark.parametrize(
    "path, mime_type, output_format",
    [
        ("test.mp3", "audio/mpeg", OutputFormat.TEXT),
        ("test.png", "image/png", OutputFormat.TEXT),
        ("test.pdf", "application/pdf", OutputFormat.TEXT),
        ("test.txt", "text/plain", OutputFormat.MARKDOWN),
        ("test.txt", "text/markdown", OutputFormat.MARKDOWN),
        ("test.html", "text/html", OutputFormat.MARKDOWN),
    ],
)
async def test_extract_non_passthrough(
    mock_extractor, path: str, mime_type: str, output_format: OutputFormat
):
    """Non-passthrough inputs are delegated to the subclass ``_extract`` hook."""
    concrete_output = ExtractionOutput(
        is_passthrough=False,
        content="mock concrete extractor output",
        content_format=output_format,
    )

    with patch.object(
        mock_extractor, "_extract", return_value=concrete_output
    ) as mock_extract:
        # public entry point on the base class
        result = await mock_extractor.extract(
            ExtractionInput(path=path, mime_type=mime_type)
        )

        # the subclass hook receives an equivalent (validated) input
        mock_extract.assert_called_once_with(
            ExtractionInput(path=path, mime_type=mime_type)
        )

        assert not result.is_passthrough
        assert result.content == "mock concrete extractor output"
        assert result.content_format == output_format
215
+
216
+
217
async def test_default_output_format(mock_litellm_properties):
    """An ExtractorConfig without an explicit output_format defaults to markdown."""
    config_without_format = ExtractorConfig(
        name="mock",
        extractor_type=ExtractorType.LITELLM,
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        properties=mock_litellm_properties,
    )
    expected = OutputFormat.MARKDOWN
    assert config_without_format.output_format == expected
226
+
227
+
228
async def test_extract_failure_from_concrete_extractor(mock_extractor):
    """An error raised by the concrete extractor surfaces from extract() as a ValueError."""
    failing_extract = patch.object(
        mock_extractor,
        "_extract",
        side_effect=Exception("error from concrete extractor"),
    )
    with failing_extract, pytest.raises(
        ValueError, match="error from concrete extractor"
    ):
        await mock_extractor.extract(
            ExtractionInput(path="test.txt", mime_type="text/plain")
        )
241
+
242
+
243
async def test_output_format(mock_extractor):
    """output_format() reflects the config's markdown setting."""
    expected = OutputFormat.MARKDOWN
    assert mock_extractor.output_format() == expected
@@ -0,0 +1,54 @@
1
+ from pathlib import Path
2
+
3
+ import pytest
4
+
5
+ from conftest import MockFileFactoryMimeType
6
+ from kiln_ai.adapters.extractors.encoding import from_base64_url, to_base64_url
7
+
8
+
9
async def test_to_base64_url(mock_file_factory):
    """A JPEG survives a to_base64_url -> from_base64_url round trip."""
    jpeg_path = Path(mock_file_factory(MockFileFactoryMimeType.JPEG))
    original_bytes = jpeg_path.read_bytes()

    data_url = to_base64_url("image/jpeg", original_bytes)
    assert data_url.startswith("data:image/jpeg;base64,")

    round_tripped = from_base64_url(data_url)
    assert round_tripped == original_bytes
20
+
21
+
22
def test_from_base64_url_invalid_format_no_data_prefix():
    """A string without the 'data:' scheme is rejected with ValueError."""
    not_a_data_url = "not-a-data-url"
    with pytest.raises(ValueError, match="Invalid base64 URL format"):
        from_base64_url(not_a_data_url)
26
+
27
+
28
def test_from_base64_url_invalid_format_no_comma():
    """A data URL missing the comma separator is rejected with ValueError."""
    missing_separator = "data:image/jpeg;base64"
    with pytest.raises(ValueError, match="Invalid base64 URL format"):
        from_base64_url(missing_separator)
32
+
33
+
34
def test_from_base64_url_invalid_parts():
    """A data URL with more than one comma (not exactly two parts) is rejected with ValueError."""
    too_many_parts = "data:image/jpeg;base64,part1,part2"
    with pytest.raises(ValueError, match="Invalid base64 URL format"):
        from_base64_url(too_many_parts)
38
+
39
+
40
def test_from_base64_url_base64_decode_failure():
    """A well-formed data URL whose payload is not valid base64 is rejected with ValueError."""
    undecodable = "data:image/jpeg;base64,invalid-base64-data!"
    with pytest.raises(ValueError, match="Failed to decode base64 data"):
        from_base64_url(undecodable)
44
+
45
+
46
def test_from_base64_url_valid_format():
    """A valid text/plain data URL decodes to its original bytes."""
    expected_bytes = b"Hello, World!"
    encoded_payload = "SGVsbG8sIFdvcmxkIQ=="

    decoded = from_base64_url(f"data:text/plain;base64,{encoded_payload}")

    assert decoded == expected_bytes