kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic.

Files changed (158)
  1. kiln_ai/adapters/__init__.py +8 -2
  2. kiln_ai/adapters/adapter_registry.py +43 -208
  3. kiln_ai/adapters/chat/chat_formatter.py +8 -12
  4. kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
  5. kiln_ai/adapters/chunkers/__init__.py +13 -0
  6. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  7. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  8. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  9. kiln_ai/adapters/chunkers/helpers.py +23 -0
  10. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  11. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  12. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  13. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  14. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  15. kiln_ai/adapters/docker_model_runner_tools.py +119 -0
  16. kiln_ai/adapters/embedding/__init__.py +0 -0
  17. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  18. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  19. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  20. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  21. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  22. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  23. kiln_ai/adapters/eval/base_eval.py +2 -2
  24. kiln_ai/adapters/eval/eval_runner.py +9 -3
  25. kiln_ai/adapters/eval/g_eval.py +2 -2
  26. kiln_ai/adapters/eval/test_base_eval.py +2 -4
  27. kiln_ai/adapters/eval/test_g_eval.py +4 -5
  28. kiln_ai/adapters/extractors/__init__.py +18 -0
  29. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  30. kiln_ai/adapters/extractors/encoding.py +20 -0
  31. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  32. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  33. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  34. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  35. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  36. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  37. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  38. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  39. kiln_ai/adapters/fine_tune/__init__.py +1 -1
  40. kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
  41. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  42. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
  43. kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
  44. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  45. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  46. kiln_ai/adapters/ml_model_list.py +761 -37
  47. kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
  48. kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
  49. kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
  50. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
  51. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
  52. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
  53. kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
  54. kiln_ai/adapters/ollama_tools.py +69 -12
  55. kiln_ai/adapters/parsers/__init__.py +1 -1
  56. kiln_ai/adapters/provider_tools.py +205 -47
  57. kiln_ai/adapters/rag/deduplication.py +49 -0
  58. kiln_ai/adapters/rag/progress.py +252 -0
  59. kiln_ai/adapters/rag/rag_runners.py +844 -0
  60. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  61. kiln_ai/adapters/rag/test_progress.py +785 -0
  62. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  63. kiln_ai/adapters/remote_config.py +80 -8
  64. kiln_ai/adapters/repair/test_repair_task.py +12 -9
  65. kiln_ai/adapters/run_output.py +3 -0
  66. kiln_ai/adapters/test_adapter_registry.py +657 -85
  67. kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
  68. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  69. kiln_ai/adapters/test_ml_model_list.py +251 -1
  70. kiln_ai/adapters/test_ollama_tools.py +340 -1
  71. kiln_ai/adapters/test_prompt_adaptors.py +13 -6
  72. kiln_ai/adapters/test_prompt_builders.py +1 -1
  73. kiln_ai/adapters/test_provider_tools.py +254 -8
  74. kiln_ai/adapters/test_remote_config.py +651 -58
  75. kiln_ai/adapters/vector_store/__init__.py +1 -0
  76. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  77. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  78. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  79. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  80. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  81. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  82. kiln_ai/datamodel/__init__.py +39 -34
  83. kiln_ai/datamodel/basemodel.py +170 -1
  84. kiln_ai/datamodel/chunk.py +158 -0
  85. kiln_ai/datamodel/datamodel_enums.py +28 -0
  86. kiln_ai/datamodel/embedding.py +64 -0
  87. kiln_ai/datamodel/eval.py +1 -1
  88. kiln_ai/datamodel/external_tool_server.py +298 -0
  89. kiln_ai/datamodel/extraction.py +303 -0
  90. kiln_ai/datamodel/json_schema.py +25 -10
  91. kiln_ai/datamodel/project.py +40 -1
  92. kiln_ai/datamodel/rag.py +79 -0
  93. kiln_ai/datamodel/registry.py +0 -15
  94. kiln_ai/datamodel/run_config.py +62 -0
  95. kiln_ai/datamodel/task.py +2 -77
  96. kiln_ai/datamodel/task_output.py +6 -1
  97. kiln_ai/datamodel/task_run.py +41 -0
  98. kiln_ai/datamodel/test_attachment.py +649 -0
  99. kiln_ai/datamodel/test_basemodel.py +4 -4
  100. kiln_ai/datamodel/test_chunk_models.py +317 -0
  101. kiln_ai/datamodel/test_dataset_split.py +1 -1
  102. kiln_ai/datamodel/test_embedding_models.py +448 -0
  103. kiln_ai/datamodel/test_eval_model.py +6 -6
  104. kiln_ai/datamodel/test_example_models.py +175 -0
  105. kiln_ai/datamodel/test_external_tool_server.py +691 -0
  106. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  107. kiln_ai/datamodel/test_extraction_model.py +470 -0
  108. kiln_ai/datamodel/test_rag.py +641 -0
  109. kiln_ai/datamodel/test_registry.py +8 -3
  110. kiln_ai/datamodel/test_task.py +15 -47
  111. kiln_ai/datamodel/test_tool_id.py +320 -0
  112. kiln_ai/datamodel/test_vector_store.py +320 -0
  113. kiln_ai/datamodel/tool_id.py +105 -0
  114. kiln_ai/datamodel/vector_store.py +141 -0
  115. kiln_ai/tools/__init__.py +8 -0
  116. kiln_ai/tools/base_tool.py +82 -0
  117. kiln_ai/tools/built_in_tools/__init__.py +13 -0
  118. kiln_ai/tools/built_in_tools/math_tools.py +124 -0
  119. kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
  120. kiln_ai/tools/mcp_server_tool.py +95 -0
  121. kiln_ai/tools/mcp_session_manager.py +246 -0
  122. kiln_ai/tools/rag_tools.py +157 -0
  123. kiln_ai/tools/test_base_tools.py +199 -0
  124. kiln_ai/tools/test_mcp_server_tool.py +457 -0
  125. kiln_ai/tools/test_mcp_session_manager.py +1585 -0
  126. kiln_ai/tools/test_rag_tools.py +848 -0
  127. kiln_ai/tools/test_tool_registry.py +562 -0
  128. kiln_ai/tools/tool_registry.py +85 -0
  129. kiln_ai/utils/__init__.py +3 -0
  130. kiln_ai/utils/async_job_runner.py +62 -17
  131. kiln_ai/utils/config.py +24 -2
  132. kiln_ai/utils/env.py +15 -0
  133. kiln_ai/utils/filesystem.py +14 -0
  134. kiln_ai/utils/filesystem_cache.py +60 -0
  135. kiln_ai/utils/litellm.py +94 -0
  136. kiln_ai/utils/lock.py +100 -0
  137. kiln_ai/utils/mime_type.py +38 -0
  138. kiln_ai/utils/open_ai_types.py +94 -0
  139. kiln_ai/utils/pdf_utils.py +38 -0
  140. kiln_ai/utils/project_utils.py +17 -0
  141. kiln_ai/utils/test_async_job_runner.py +151 -35
  142. kiln_ai/utils/test_config.py +138 -1
  143. kiln_ai/utils/test_env.py +142 -0
  144. kiln_ai/utils/test_filesystem_cache.py +316 -0
  145. kiln_ai/utils/test_litellm.py +206 -0
  146. kiln_ai/utils/test_lock.py +185 -0
  147. kiln_ai/utils/test_mime_type.py +66 -0
  148. kiln_ai/utils/test_open_ai_types.py +131 -0
  149. kiln_ai/utils/test_pdf_utils.py +73 -0
  150. kiln_ai/utils/test_uuid.py +111 -0
  151. kiln_ai/utils/test_validation.py +524 -0
  152. kiln_ai/utils/uuid.py +9 -0
  153. kiln_ai/utils/validation.py +90 -0
  154. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
  155. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  156. kiln_ai-0.19.0.dist-info/RECORD +0 -115
  157. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  158. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/extractors/litellm_extractor.py

@@ -0,0 +1,386 @@
+import asyncio
+import hashlib
+import logging
+from pathlib import Path
+from typing import Any, List
+
+import litellm
+from litellm.types.utils import Choices, ModelResponse
+
+from kiln_ai.adapters.extractors.base_extractor import (
+    BaseExtractor,
+    ExtractionInput,
+    ExtractionOutput,
+)
+from kiln_ai.adapters.extractors.encoding import to_base64_url
+from kiln_ai.adapters.ml_model_list import built_in_models_from_provider
+from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
+from kiln_ai.datamodel.datamodel_enums import ModelProviderName
+from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, Kind
+from kiln_ai.utils.filesystem_cache import FilesystemCache
+from kiln_ai.utils.litellm import get_litellm_provider_info
+from kiln_ai.utils.pdf_utils import split_pdf_into_pages
+
+
+def max_pdf_page_concurrency_for_model(model_name: str) -> int:
+    # we assume each batch takes ~5s to complete (likely more in practice)
+    # lowest rate limit is 150 RPM for Tier 1 accounts for gemini-2.5-pro
+    if model_name == "gemini/gemini-2.5-pro":
+        return 2
+    # other models support at least 500 RPM for lowest tier accounts
+    return 5
+
+
+logger = logging.getLogger(__name__)
+
+MIME_TYPES_SUPPORTED = {
+    Kind.DOCUMENT: [
+        "application/pdf",
+        "text/plain",
+        "text/markdown",  # not officially listed, but works
+        "text/html",
+        "text/md",
+        "text/csv",
+    ],
+    Kind.IMAGE: [
+        "image/png",
+        "image/jpeg",
+        "image/jpg",
+    ],
+    Kind.VIDEO: [
+        "video/mp4",
+        "video/mov",  # the correct type is video/quicktime, but Google lists it as video/mov
+        "video/quicktime",
+    ],
+    Kind.AUDIO: [
+        "audio/wav",
+        "audio/mpeg",  # this is the official MP3 mimetype, audio/mp3 is often used but not correct
+        "audio/ogg",
+    ],
+}
+
+
+def encode_file_litellm_format(path: Path, mime_type: str) -> dict[str, Any]:
+    # There are different formats that LiteLLM supports, the docs are scattered
+    # and incomplete:
+    # - https://docs.litellm.ai/docs/completion/document_understanding#base64
+    # - https://docs.litellm.ai/docs/completion/vision#explicitly-specify-image-type
+
+    # this is the most generic format that seems to work for all / most mime types
+    if mime_type in [
+        "application/pdf",
+        "text/csv",
+        "text/html",
+        "text/markdown",
+        "text/plain",
+    ] or any(mime_type.startswith(m) for m in ["video/", "audio/"]):
+        pdf_bytes = path.read_bytes()
+        return {
+            "type": "file",
+            "file": {
+                "file_data": to_base64_url(mime_type, pdf_bytes),
+            },
+        }
+
+    # image has its own format (but also appears to work with the file format)
+    if mime_type.startswith("image/"):
+        image_bytes = path.read_bytes()
+        return {
+            "type": "image_url",
+            "image_url": {
+                "url": to_base64_url(mime_type, image_bytes),
+            },
+        }
+
+    raise ValueError(f"Unsupported MIME type: {mime_type} for {path}")
+
+
+class LitellmExtractor(BaseExtractor):
+    def __init__(
+        self,
+        extractor_config: ExtractorConfig,
+        litellm_core_config: LiteLlmCoreConfig,
+        filesystem_cache: FilesystemCache | None = None,
+    ):
+        if extractor_config.extractor_type != ExtractorType.LITELLM:
+            raise ValueError(
+                f"LitellmExtractor must be initialized with a litellm extractor_type config. Got {extractor_config.extractor_type}"
+            )
+
+        prompt_document = extractor_config.prompt_document()
+        if prompt_document is None or prompt_document == "":
+            raise ValueError(
+                "properties.prompt_document is required for LitellmExtractor"
+            )
+        prompt_video = extractor_config.prompt_video()
+        if prompt_video is None or prompt_video == "":
+            raise ValueError("properties.prompt_video is required for LitellmExtractor")
+        prompt_audio = extractor_config.prompt_audio()
+        if prompt_audio is None or prompt_audio == "":
+            raise ValueError("properties.prompt_audio is required for LitellmExtractor")
+        prompt_image = extractor_config.prompt_image()
+        if prompt_image is None or prompt_image == "":
+            raise ValueError("properties.prompt_image is required for LitellmExtractor")
+
+        self.filesystem_cache = filesystem_cache
+
+        super().__init__(extractor_config)
+        self.prompt_for_kind = {
+            Kind.DOCUMENT: prompt_document,
+            Kind.VIDEO: prompt_video,
+            Kind.AUDIO: prompt_audio,
+            Kind.IMAGE: prompt_image,
+        }
+
+        self.litellm_core_config = litellm_core_config
+
+    def pdf_page_cache_key(self, pdf_path: Path, page_number: int) -> str:
+        """
+        Generate a cache key for a page of a PDF. The PDF path must be the full path to the PDF file,
+        not the path to the page - since page path is temporary and changes on each run.
+        """
+        if self.extractor_config.id is None:
+            raise ValueError("Extractor config ID is required for PDF page cache key")
+
+        raw_key = f"{pdf_path.resolve()}::{page_number}"
+        digest = hashlib.md5(raw_key.encode("utf-8")).hexdigest()
+        return f"{self.extractor_config.id}_{digest}"
+
+    async def get_page_content_from_cache(
+        self, pdf_path: Path, page_number: int
+    ) -> str | None:
+        if self.filesystem_cache is None:
+            return None
+
+        page_bytes = await self.filesystem_cache.get(
+            self.pdf_page_cache_key(pdf_path, page_number)
+        )
+
+        if page_bytes is not None:
+            logger.debug(f"Cache hit for page {page_number} of {pdf_path}")
+            try:
+                return page_bytes.decode("utf-8")
+            except UnicodeDecodeError:
+                logger.warning(
+                    "Cached bytes for page %s of %s are not valid UTF-8; treating as miss.",
+                    page_number,
+                    pdf_path,
+                    exc_info=True,
+                )
+
+        logger.debug(f"Cache miss for page {page_number} of {pdf_path}")
+        return None
+
+    async def _extract_single_pdf_page(
+        self, pdf_path: Path, page_path: Path, prompt: str, page_number: int
+    ) -> str:
+        try:
+            page_input = ExtractionInput(
+                path=str(page_path), mime_type="application/pdf"
+            )
+            completion_kwargs = self._build_completion_kwargs(prompt, page_input)
+            response = await litellm.acompletion(**completion_kwargs)
+        except Exception as e:
+            raise RuntimeError(
+                f"Error extracting page {page_number} in file {page_path}: {e}"
+            ) from e
+
+        if (
+            not isinstance(response, ModelResponse)
+            or not response.choices
+            or len(response.choices) == 0
+            or not isinstance(response.choices[0], Choices)
+        ):
+            raise RuntimeError(
+                f"Expected ModelResponse with Choices for page {page_number}, got {type(response)}."
+            )
+
+        if response.choices[0].message.content is None:
+            raise ValueError(
+                f"No text returned from LiteLLM when extracting page {page_number}"
+            )
+
+        content = response.choices[0].message.content
+        if not content:
+            raise ValueError(
+                f"No text returned from extraction model when extracting page {page_number} for {page_path}"
+            )
+
+        if self.filesystem_cache is not None:
+            # we don't want to fail the whole extraction just because cache write fails
+            # as that would block the whole flow
+            try:
+                logger.debug(f"Caching page {page_number} of {page_path} in cache")
+                await self.filesystem_cache.set(
+                    self.pdf_page_cache_key(pdf_path, page_number),
+                    content.encode("utf-8"),
+                )
+            except Exception:
+                logger.warning(
+                    "Failed to cache page %s of %s; continuing without cache.",
+                    page_number,
+                    page_path,
+                    exc_info=True,
+                )
+
+        return content
+
+    async def _extract_pdf_page_by_page(self, pdf_path: Path, prompt: str) -> str:
+        async with split_pdf_into_pages(pdf_path) as page_paths:
+            page_outcomes: List[str | Exception | None] = [None] * len(page_paths)
+
+            extract_page_jobs: list = []
+            page_indices_for_jobs: list = []  # Track which page index each job corresponds to
+
+            # we extract from each page individually and then combine the results
+            # this ensures the model stays focused on the current page and does not
+            # start summarizing the later pages
+            for i, page_path in enumerate(page_paths):
+                page_content = await self.get_page_content_from_cache(pdf_path, i)
+                if page_content is not None:
+                    page_outcomes[i] = page_content
+                    continue
+
+                extract_page_jobs.append(
+                    self._extract_single_pdf_page(pdf_path, page_path, prompt, i)
+                )
+                page_indices_for_jobs.append(i)
+
+                if (
+                    len(extract_page_jobs)
+                    >= max_pdf_page_concurrency_for_model(self.litellm_model_slug())
+                    or i == len(page_paths) - 1
+                ):
+                    extraction_results = await asyncio.gather(
+                        *extract_page_jobs, return_exceptions=True
+                    )
+
+                    for batch_i, extraction_result in enumerate(extraction_results):
+                        page_index = page_indices_for_jobs[batch_i]
+                        # we let it continue even if there is an error - the success results will be cached
+                        # and can be reused on the next run
+                        if isinstance(extraction_result, Exception):
+                            page_outcomes[page_index] = extraction_result
+                        elif isinstance(extraction_result, str):
+                            page_outcomes[page_index] = extraction_result
+                        else:
+                            raise ValueError(
+                                f"Unexpected type {type(extraction_result)} for page {page_index}"
+                            )
+                    extract_page_jobs.clear()
+                    page_indices_for_jobs.clear()
+
+            exceptions: list[tuple[int, Exception]] = [
+                (page_index, result)
+                for page_index, result in enumerate(page_outcomes)
+                if isinstance(result, Exception)
+            ]
+            if len(exceptions) > 0:
+                msg = f"Error extracting PDF {pdf_path}: "
+                for page_index, exception in exceptions:
+                    msg += f"Page {page_index}: {exception}\n"
+                raise RuntimeError(msg)
+
+            return "\n\n".join(
+                [outcome for outcome in page_outcomes if isinstance(outcome, str)]
+            )
+
+    def _get_kind_from_mime_type(self, mime_type: str) -> Kind | None:
+        for kind, mime_types in MIME_TYPES_SUPPORTED.items():
+            if mime_type in mime_types:
+                return kind
+        return None
+
+    def _build_completion_kwargs(
+        self, prompt: str, extraction_input: ExtractionInput
+    ) -> dict[str, Any]:
+        completion_kwargs = {
+            "model": self.litellm_model_slug(),
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": prompt},
+                        encode_file_litellm_format(
+                            Path(extraction_input.path), extraction_input.mime_type
+                        ),
+                    ],
+                }
+            ],
+        }
+
+        if self.litellm_core_config.base_url:
+            completion_kwargs["base_url"] = self.litellm_core_config.base_url
+
+        if self.litellm_core_config.default_headers:
+            completion_kwargs["default_headers"] = (
+                self.litellm_core_config.default_headers
+            )
+
+        if self.litellm_core_config.additional_body_options:
+            completion_kwargs.update(self.litellm_core_config.additional_body_options)
+
+        return completion_kwargs
+
+    async def _extract(self, extraction_input: ExtractionInput) -> ExtractionOutput:
+        kind = self._get_kind_from_mime_type(extraction_input.mime_type)
+        if kind is None:
+            raise ValueError(
+                f"Unsupported MIME type: {extraction_input.mime_type} for {extraction_input.path}"
+            )
+
+        prompt = self.prompt_for_kind.get(kind)
+        if prompt is None:
+            raise ValueError(f"No prompt found for kind: {kind}")
+
+        # special handling for PDFs - process each page individually
+        if extraction_input.mime_type == "application/pdf":
+            content = await self._extract_pdf_page_by_page(
+                Path(extraction_input.path), prompt
+            )
+            return ExtractionOutput(
+                is_passthrough=False,
+                content=content,
+                content_format=self.extractor_config.output_format,
+            )
+
+        completion_kwargs = self._build_completion_kwargs(prompt, extraction_input)
+
+        response = await litellm.acompletion(**completion_kwargs)
+
+        if (
+            not isinstance(response, ModelResponse)
+            or not response.choices
+            or len(response.choices) == 0
+            or not isinstance(response.choices[0], Choices)
+        ):
+            raise RuntimeError(
+                f"Expected ModelResponse with Choices, got {type(response)}."
+            )
+
+        if response.choices[0].message.content is None:
+            raise ValueError("No text returned from LiteLLM when extracting document")
+
+        return ExtractionOutput(
+            is_passthrough=False,
+            content=response.choices[0].message.content,
+            content_format=self.extractor_config.output_format,
+        )
+
+    def litellm_model_slug(self) -> str:
+        kiln_model_provider = built_in_models_from_provider(
+            ModelProviderName(self.extractor_config.model_provider_name),
+            self.extractor_config.model_name,
+        )
+
+        if kiln_model_provider is None:
+            raise ValueError(
+                f"Model provider {self.extractor_config.model_provider_name} not found in the list of built-in models"
+            )
+
+        # need to translate into LiteLLM model slug
+        litellm_provider_name = get_litellm_provider_info(
+            kiln_model_provider,
+        )
+
+        return litellm_provider_name.litellm_model_id
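
For orientation, below is a minimal usage sketch of the new extractor, pieced together from the hunk above and the test fixtures further down. It is not part of the release: the ExtractorConfig keyword arguments mirror the test fixtures, while the LiteLlmCoreConfig constructor arguments are an assumption (the hunk only reads its base_url, default_headers, and additional_body_options attributes in _build_completion_kwargs).

import asyncio
from pathlib import Path

from kiln_ai.adapters.extractors.base_extractor import ExtractionInput
from kiln_ai.adapters.extractors.litellm_extractor import LitellmExtractor
from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, OutputFormat


async def main() -> None:
    # The prompt_* properties are required by LitellmExtractor.__init__;
    # the values here are placeholders.
    config = ExtractorConfig(
        name="example-extractor",
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        extractor_type=ExtractorType.LITELLM,
        output_format=OutputFormat.MARKDOWN,
        properties={
            "prompt_document": "Transcribe this document as markdown.",
            "prompt_image": "Describe this image.",
            "prompt_video": "Describe this video.",
            "prompt_audio": "Transcribe this audio.",
        },
    )
    # Assumed constructor arguments: only base_url, default_headers and
    # additional_body_options are consumed by _build_completion_kwargs.
    core_config = LiteLlmCoreConfig(
        base_url=None,
        default_headers=None,
        additional_body_options={},
    )
    extractor = LitellmExtractor(config, core_config)

    # PDFs are split and extracted page by page; other supported MIME types
    # are sent through a single litellm.acompletion call.
    output = await extractor.extract(
        ExtractionInput(path=str(Path("report.pdf")), mime_type="application/pdf")
    )
    print(output.content)


if __name__ == "__main__":
    asyncio.run(main())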
kiln_ai/adapters/extractors/test_base_extractor.py

@@ -0,0 +1,244 @@
+from typing import Any
+from unittest.mock import patch
+
+import pytest
+
+from kiln_ai.adapters.extractors.base_extractor import (
+    BaseExtractor,
+    ExtractionInput,
+    ExtractionOutput,
+)
+from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, OutputFormat
+
+
+class MockBaseExtractor(BaseExtractor):
+    async def _extract(self, input: ExtractionInput) -> ExtractionOutput:
+        return ExtractionOutput(
+            is_passthrough=False,
+            content="mock concrete extractor output",
+            content_format=OutputFormat.MARKDOWN,
+        )
+
+
+@pytest.fixture
+def mock_litellm_properties():
+    return {
+        "prompt_document": "mock prompt for document",
+        "prompt_image": "mock prompt for image",
+        "prompt_video": "mock prompt for video",
+        "prompt_audio": "mock prompt for audio",
+    }
+
+
+@pytest.fixture
+def mock_extractor(mock_litellm_properties):
+    return MockBaseExtractor(
+        ExtractorConfig(
+            name="mock",
+            model_provider_name="gemini_api",
+            model_name="gemini-2.0-flash",
+            extractor_type=ExtractorType.LITELLM,
+            output_format=OutputFormat.MARKDOWN,
+            properties=mock_litellm_properties,
+        )
+    )
+
+
+def mock_extractor_with_passthroughs(
+    properties: dict[str, Any],
+    mimetypes: list[OutputFormat],
+    output_format: OutputFormat,
+):
+    return MockBaseExtractor(
+        ExtractorConfig(
+            name="mock",
+            model_provider_name="gemini_api",
+            model_name="gemini-2.0-flash",
+            extractor_type=ExtractorType.LITELLM,
+            passthrough_mimetypes=mimetypes,
+            output_format=output_format,
+            properties=properties,
+        )
+    )
+
+
+def test_should_passthrough(mock_litellm_properties):
+    extractor = mock_extractor_with_passthroughs(
+        mock_litellm_properties,
+        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
+        OutputFormat.TEXT,
+    )
+
+    # should passthrough
+    assert extractor._should_passthrough("text/plain")
+    assert extractor._should_passthrough("text/markdown")
+
+    # should not passthrough
+    assert not extractor._should_passthrough("image/png")
+    assert not extractor._should_passthrough("application/pdf")
+    assert not extractor._should_passthrough("text/html")
+    assert not extractor._should_passthrough("image/jpeg")
+
+
+async def test_extract_passthrough(mock_litellm_properties):
+    """
+    Tests that when a file's MIME type is configured for passthrough, the extractor skips
+    the concrete extraction method and returns the file's contents directly with the
+    correct passthrough output format.
+    """
+    extractor = mock_extractor_with_passthroughs(
+        mock_litellm_properties,
+        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
+        OutputFormat.TEXT,
+    )
+    with (
+        patch.object(
+            extractor,
+            "_extract",
+            return_value=ExtractionOutput(
+                is_passthrough=False,
+                content="mock concrete extractor output",
+                content_format=OutputFormat.TEXT,
+            ),
+        ) as mock_extract,
+        patch(
+            "pathlib.Path.read_text",
+            return_value=b"test content",
+        ),
+    ):
+        result = await extractor.extract(
+            ExtractionInput(
+                path="test.txt",
+                mime_type="text/plain",
+            )
+        )
+
+        # Verify _extract was not called
+        mock_extract.assert_not_called()
+
+        # Verify correct passthrough result
+        assert result.is_passthrough
+        assert result.content == "test content"
+        assert result.content_format == OutputFormat.TEXT
+
+
+@pytest.mark.parametrize(
+    "output_format",
+    [
+        "text/plain",
+        "text/markdown",
+    ],
+)
+async def test_extract_passthrough_output_format(
+    mock_litellm_properties, output_format
+):
+    extractor = mock_extractor_with_passthroughs(
+        mock_litellm_properties,
+        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
+        output_format,
+    )
+    with (
+        patch.object(
+            extractor,
+            "_extract",
+            return_value=ExtractionOutput(
+                is_passthrough=False,
+                content="mock concrete extractor output",
+                content_format=output_format,
+            ),
+        ) as mock_extract,
+        patch(
+            "pathlib.Path.read_text",
+            return_value="test content",
+        ),
+    ):
+        result = await extractor.extract(
+            ExtractionInput(
+                path="test.txt",
+                mime_type="text/plain",
+            )
+        )
+
+        # Verify _extract was not called
+        mock_extract.assert_not_called()
+
+        # Verify correct passthrough result
+        assert result.is_passthrough
+        assert result.content == "test content"
+        assert result.content_format == output_format
+
+
+@pytest.mark.parametrize(
+    "path, mime_type, output_format",
+    [
+        ("test.mp3", "audio/mpeg", OutputFormat.TEXT),
+        ("test.png", "image/png", OutputFormat.TEXT),
+        ("test.pdf", "application/pdf", OutputFormat.TEXT),
+        ("test.txt", "text/plain", OutputFormat.MARKDOWN),
+        ("test.txt", "text/markdown", OutputFormat.MARKDOWN),
+        ("test.html", "text/html", OutputFormat.MARKDOWN),
+    ],
+)
+async def test_extract_non_passthrough(
+    mock_extractor, path: str, mime_type: str, output_format: OutputFormat
+):
+    with (
+        patch.object(
+            mock_extractor,
+            "_extract",
+            return_value=ExtractionOutput(
+                is_passthrough=False,
+                content="mock concrete extractor output",
+                content_format=output_format,
+            ),
+        ) as mock_extract,
+    ):
+        # first we call the base class extract method
+        result = await mock_extractor.extract(
+            ExtractionInput(
+                path=path,
+                mime_type=mime_type,
+            )
+        )
+
+        # then we call the subclass _extract method and add validated mime_type
+        mock_extract.assert_called_once_with(
+            ExtractionInput(
+                path=path,
+                mime_type=mime_type,
+            )
+        )
+
+        assert not result.is_passthrough
+        assert result.content == "mock concrete extractor output"
+        assert result.content_format == output_format
+
+
+async def test_default_output_format(mock_litellm_properties):
+    config = ExtractorConfig(
+        name="mock",
+        model_provider_name="gemini_api",
+        model_name="gemini-2.0-flash",
+        extractor_type=ExtractorType.LITELLM,
+        properties=mock_litellm_properties,
+    )
+    assert config.output_format == OutputFormat.MARKDOWN
+
+
+async def test_extract_failure_from_concrete_extractor(mock_extractor):
+    with patch.object(
+        mock_extractor,
+        "_extract",
+        side_effect=Exception("error from concrete extractor"),
+    ):
+        with pytest.raises(ValueError, match="error from concrete extractor"):
+            await mock_extractor.extract(
+                ExtractionInput(
+                    path="test.txt",
+                    mime_type="text/plain",
+                )
+            )
+
+
+async def test_output_format(mock_extractor):
+    assert mock_extractor.output_format() == OutputFormat.MARKDOWN
kiln_ai/adapters/extractors/test_encoding.py

@@ -0,0 +1,54 @@
+from pathlib import Path
+
+import pytest
+
+from conftest import MockFileFactoryMimeType
+from kiln_ai.adapters.extractors.encoding import from_base64_url, to_base64_url
+
+
+async def test_to_base64_url(mock_file_factory):
+    mock_file = mock_file_factory(MockFileFactoryMimeType.JPEG)
+
+    byte_data = Path(mock_file).read_bytes()
+
+    # encode the byte data
+    base64_url = to_base64_url("image/jpeg", byte_data)
+    assert base64_url.startswith("data:image/jpeg;base64,")
+
+    # decode the base64 url
+    assert from_base64_url(base64_url) == byte_data
+
+
+def test_from_base64_url_invalid_format_no_data_prefix():
+    """Test that from_base64_url raises ValueError when input doesn't start with 'data:'"""
+    with pytest.raises(ValueError, match="Invalid base64 URL format"):
+        from_base64_url("not-a-data-url")
+
+
+def test_from_base64_url_invalid_format_no_comma():
+    """Test that from_base64_url raises ValueError when input doesn't contain a comma"""
+    with pytest.raises(ValueError, match="Invalid base64 URL format"):
+        from_base64_url("data:image/jpeg;base64")
+
+
+def test_from_base64_url_invalid_parts():
+    """Test that from_base64_url raises ValueError when splitting by comma doesn't result in exactly 2 parts"""
+    with pytest.raises(ValueError, match="Invalid base64 URL format"):
+        from_base64_url("data:image/jpeg;base64,part1,part2")
+
+
+def test_from_base64_url_base64_decode_failure():
+    """Test that from_base64_url raises ValueError when base64 decoding fails"""
+    with pytest.raises(ValueError, match="Failed to decode base64 data"):
+        from_base64_url("data:image/jpeg;base64,invalid-base64-data!")
+
+
+def test_from_base64_url_valid_format():
+    """Test that from_base64_url works with valid base64 URL format"""
+    # Create a simple valid base64 URL
+    test_data = b"Hello, World!"
+    base64_encoded = "SGVsbG8sIFdvcmxkIQ=="
+    base64_url = f"data:text/plain;base64,{base64_encoded}"
+
+    result = from_base64_url(base64_url)
+    assert result == test_data
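
The encoding module itself is not shown in this diff (the file list reports kiln_ai/adapters/extractors/encoding.py at +20 lines). For reference, the following is a rough sketch of helpers consistent with the behavior these tests assert, not the packaged implementation.

import base64


def to_base64_url(mime_type: str, data: bytes) -> str:
    # Data-URL style encoding: "data:<mime>;base64,<payload>", matching the
    # prefix asserted in test_to_base64_url above.
    payload = base64.b64encode(data).decode("ascii")
    return f"data:{mime_type};base64,{payload}"


def from_base64_url(base64_url: str) -> bytes:
    # Mirrors the error cases exercised above: must start with "data:", must
    # contain exactly one comma, and the payload must decode cleanly.
    if not base64_url.startswith("data:"):
        raise ValueError("Invalid base64 URL format")
    parts = base64_url.split(",")
    if len(parts) != 2:
        raise ValueError("Invalid base64 URL format")
    try:
        return base64.b64decode(parts[1], validate=True)
    except Exception as e:
        raise ValueError(f"Failed to decode base64 data: {e}") from e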