kiln-ai 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/extractors/litellm_extractor.py +52 -32
- kiln_ai/adapters/extractors/test_litellm_extractor.py +169 -71
- kiln_ai/adapters/ml_embedding_model_list.py +330 -28
- kiln_ai/adapters/ml_model_list.py +503 -23
- kiln_ai/adapters/model_adapters/litellm_adapter.py +39 -8
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +78 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +6 -9
- kiln_ai/adapters/test_ml_embedding_model_list.py +89 -279
- kiln_ai/adapters/test_ml_model_list.py +0 -10
- kiln_ai/adapters/vector_store/lancedb_adapter.py +24 -70
- kiln_ai/adapters/vector_store/lancedb_helpers.py +101 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +9 -16
- kiln_ai/adapters/vector_store/test_lancedb_helpers.py +142 -0
- kiln_ai/adapters/vector_store_loaders/__init__.py +0 -0
- kiln_ai/adapters/vector_store_loaders/test_lancedb_loader.py +282 -0
- kiln_ai/adapters/vector_store_loaders/test_vector_store_loader.py +544 -0
- kiln_ai/adapters/vector_store_loaders/vector_store_loader.py +91 -0
- kiln_ai/datamodel/basemodel.py +31 -3
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +14 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +3 -3
- kiln_ai/datamodel/test_basemodel.py +269 -13
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_model.py +31 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +106 -1
- kiln_ai/datamodel/tool_id.py +49 -0
- kiln_ai/tools/base_tool.py +30 -6
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +162 -0
- kiln_ai/tools/mcp_server_tool.py +7 -5
- kiln_ai/tools/mcp_session_manager.py +50 -24
- kiln_ai/tools/rag_tools.py +17 -6
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +186 -226
- kiln_ai/tools/test_rag_tools.py +86 -5
- kiln_ai/tools/test_tool_registry.py +199 -5
- kiln_ai/tools/tool_registry.py +49 -17
- kiln_ai/utils/filesystem.py +4 -4
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +21 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +14 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/METADATA +79 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/RECORD +53 -45
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import hashlib
|
|
3
3
|
import logging
|
|
4
|
+
from functools import cached_property
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from typing import Any, List
|
|
6
7
|
|
|
@@ -13,23 +14,16 @@ from kiln_ai.adapters.extractors.base_extractor import (
|
|
|
13
14
|
ExtractionOutput,
|
|
14
15
|
)
|
|
15
16
|
from kiln_ai.adapters.extractors.encoding import to_base64_url
|
|
16
|
-
from kiln_ai.adapters.ml_model_list import
|
|
17
|
+
from kiln_ai.adapters.ml_model_list import (
|
|
18
|
+
KilnModelProvider,
|
|
19
|
+
built_in_models_from_provider,
|
|
20
|
+
)
|
|
17
21
|
from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
|
|
18
22
|
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
|
|
19
23
|
from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, Kind
|
|
20
24
|
from kiln_ai.utils.filesystem_cache import FilesystemCache
|
|
21
25
|
from kiln_ai.utils.litellm import get_litellm_provider_info
|
|
22
|
-
from kiln_ai.utils.pdf_utils import split_pdf_into_pages
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
def max_pdf_page_concurrency_for_model(model_name: str) -> int:
|
|
26
|
-
# we assume each batch takes ~5s to complete (likely more in practice)
|
|
27
|
-
# lowest rate limit is 150 RPM for Tier 1 accounts for gemini-2.5-pro
|
|
28
|
-
if model_name == "gemini/gemini-2.5-pro":
|
|
29
|
-
return 2
|
|
30
|
-
# other models support at least 500 RPM for lowest tier accounts
|
|
31
|
-
return 5
|
|
32
|
-
|
|
26
|
+
from kiln_ai.utils.pdf_utils import convert_pdf_to_images, split_pdf_into_pages
|
|
33
27
|
|
|
34
28
|
logger = logging.getLogger(__name__)
|
|
35
29
|
|
|
@@ -74,11 +68,11 @@ def encode_file_litellm_format(path: Path, mime_type: str) -> dict[str, Any]:
|
|
|
74
68
|
"text/markdown",
|
|
75
69
|
"text/plain",
|
|
76
70
|
] or any(mime_type.startswith(m) for m in ["video/", "audio/"]):
|
|
77
|
-
|
|
71
|
+
file_bytes = path.read_bytes()
|
|
78
72
|
return {
|
|
79
73
|
"type": "file",
|
|
80
74
|
"file": {
|
|
81
|
-
"file_data": to_base64_url(mime_type,
|
|
75
|
+
"file_data": to_base64_url(mime_type, file_bytes),
|
|
82
76
|
},
|
|
83
77
|
}
|
|
84
78
|
|
|
@@ -101,6 +95,7 @@ class LitellmExtractor(BaseExtractor):
|
|
|
101
95
|
extractor_config: ExtractorConfig,
|
|
102
96
|
litellm_core_config: LiteLlmCoreConfig,
|
|
103
97
|
filesystem_cache: FilesystemCache | None = None,
|
|
98
|
+
default_max_parallel_requests: int = 5,
|
|
104
99
|
):
|
|
105
100
|
if extractor_config.extractor_type != ExtractorType.LITELLM:
|
|
106
101
|
raise ValueError(
|
|
@@ -133,6 +128,7 @@ class LitellmExtractor(BaseExtractor):
|
|
|
133
128
|
}
|
|
134
129
|
|
|
135
130
|
self.litellm_core_config = litellm_core_config
|
|
131
|
+
self.default_max_parallel_requests = default_max_parallel_requests
|
|
136
132
|
|
|
137
133
|
def pdf_page_cache_key(self, pdf_path: Path, page_number: int) -> str:
|
|
138
134
|
"""
|
|
@@ -171,13 +167,35 @@ class LitellmExtractor(BaseExtractor):
|
|
|
171
167
|
logger.debug(f"Cache miss for page {page_number} of {pdf_path}")
|
|
172
168
|
return None
|
|
173
169
|
|
|
170
|
+
async def convert_pdf_page_to_image_input(
|
|
171
|
+
self, page_path: Path, page_number: int
|
|
172
|
+
) -> ExtractionInput:
|
|
173
|
+
image_paths = await convert_pdf_to_images(page_path, page_path.parent)
|
|
174
|
+
if len(image_paths) != 1:
|
|
175
|
+
raise ValueError(
|
|
176
|
+
f"Expected 1 image, got {len(image_paths)} for page {page_number} in {page_path}"
|
|
177
|
+
)
|
|
178
|
+
image_path = image_paths[0]
|
|
179
|
+
page_input = ExtractionInput(path=str(image_path), mime_type="image/png")
|
|
180
|
+
return page_input
|
|
181
|
+
|
|
174
182
|
async def _extract_single_pdf_page(
|
|
175
|
-
self,
|
|
183
|
+
self,
|
|
184
|
+
pdf_path: Path,
|
|
185
|
+
page_path: Path,
|
|
186
|
+
prompt: str,
|
|
187
|
+
page_number: int,
|
|
176
188
|
) -> str:
|
|
177
189
|
try:
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
190
|
+
if self.model_provider.multimodal_requires_pdf_as_image:
|
|
191
|
+
page_input = await self.convert_pdf_page_to_image_input(
|
|
192
|
+
page_path, page_number
|
|
193
|
+
)
|
|
194
|
+
else:
|
|
195
|
+
page_input = ExtractionInput(
|
|
196
|
+
path=str(page_path), mime_type="application/pdf"
|
|
197
|
+
)
|
|
198
|
+
|
|
181
199
|
completion_kwargs = self._build_completion_kwargs(prompt, page_input)
|
|
182
200
|
response = await litellm.acompletion(**completion_kwargs)
|
|
183
201
|
except Exception as e:
|
|
@@ -201,11 +219,6 @@ class LitellmExtractor(BaseExtractor):
|
|
|
201
219
|
)
|
|
202
220
|
|
|
203
221
|
content = response.choices[0].message.content
|
|
204
|
-
if not content:
|
|
205
|
-
raise ValueError(
|
|
206
|
-
f"No text returned from extraction model when extracting page {page_number} for {page_path}"
|
|
207
|
-
)
|
|
208
|
-
|
|
209
222
|
if self.filesystem_cache is not None:
|
|
210
223
|
# we don't want to fail the whole extraction just because cache write fails
|
|
211
224
|
# as that would block the whole flow
|
|
@@ -242,13 +255,14 @@ class LitellmExtractor(BaseExtractor):
|
|
|
242
255
|
continue
|
|
243
256
|
|
|
244
257
|
extract_page_jobs.append(
|
|
245
|
-
self._extract_single_pdf_page(
|
|
258
|
+
self._extract_single_pdf_page(
|
|
259
|
+
pdf_path, page_path, prompt, page_number=i
|
|
260
|
+
)
|
|
246
261
|
)
|
|
247
262
|
page_indices_for_jobs.append(i)
|
|
248
263
|
|
|
249
264
|
if (
|
|
250
|
-
len(extract_page_jobs)
|
|
251
|
-
>= max_pdf_page_concurrency_for_model(self.litellm_model_slug())
|
|
265
|
+
len(extract_page_jobs) >= self.max_parallel_requests_for_model
|
|
252
266
|
or i == len(page_paths) - 1
|
|
253
267
|
):
|
|
254
268
|
extraction_results = await asyncio.gather(
|
|
@@ -295,7 +309,7 @@ class LitellmExtractor(BaseExtractor):
|
|
|
295
309
|
self, prompt: str, extraction_input: ExtractionInput
|
|
296
310
|
) -> dict[str, Any]:
|
|
297
311
|
completion_kwargs = {
|
|
298
|
-
"model": self.litellm_model_slug
|
|
312
|
+
"model": self.litellm_model_slug,
|
|
299
313
|
"messages": [
|
|
300
314
|
{
|
|
301
315
|
"role": "user",
|
|
@@ -367,20 +381,26 @@ class LitellmExtractor(BaseExtractor):
|
|
|
367
381
|
content_format=self.extractor_config.output_format,
|
|
368
382
|
)
|
|
369
383
|
|
|
370
|
-
|
|
384
|
+
@cached_property
|
|
385
|
+
def model_provider(self) -> KilnModelProvider:
|
|
371
386
|
kiln_model_provider = built_in_models_from_provider(
|
|
372
387
|
ModelProviderName(self.extractor_config.model_provider_name),
|
|
373
388
|
self.extractor_config.model_name,
|
|
374
389
|
)
|
|
375
|
-
|
|
376
390
|
if kiln_model_provider is None:
|
|
377
391
|
raise ValueError(
|
|
378
392
|
f"Model provider {self.extractor_config.model_provider_name} not found in the list of built-in models"
|
|
379
393
|
)
|
|
394
|
+
return kiln_model_provider
|
|
395
|
+
|
|
396
|
+
@cached_property
|
|
397
|
+
def max_parallel_requests_for_model(self) -> int:
|
|
398
|
+
value = self.model_provider.max_parallel_requests
|
|
399
|
+
return value if value is not None else self.default_max_parallel_requests
|
|
380
400
|
|
|
381
|
-
|
|
401
|
+
@cached_property
|
|
402
|
+
def litellm_model_slug(self) -> str:
|
|
382
403
|
litellm_provider_name = get_litellm_provider_info(
|
|
383
|
-
|
|
404
|
+
self.model_provider,
|
|
384
405
|
)
|
|
385
|
-
|
|
386
406
|
return litellm_provider_name.litellm_model_id
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from pathlib import Path
|
|
2
|
-
from unittest.mock import AsyncMock, patch
|
|
2
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
3
3
|
|
|
4
4
|
import pytest
|
|
5
5
|
from litellm.types.utils import Choices, ModelResponse
|
|
@@ -7,13 +7,17 @@ from litellm.types.utils import Choices, ModelResponse
|
|
|
7
7
|
from conftest import MockFileFactoryMimeType
|
|
8
8
|
from kiln_ai.adapters.extractors.base_extractor import ExtractionInput, OutputFormat
|
|
9
9
|
from kiln_ai.adapters.extractors.encoding import to_base64_url
|
|
10
|
+
from kiln_ai.adapters.extractors.extractor_registry import extractor_adapter_from_type
|
|
10
11
|
from kiln_ai.adapters.extractors.litellm_extractor import (
|
|
11
12
|
ExtractorConfig,
|
|
12
13
|
Kind,
|
|
13
14
|
LitellmExtractor,
|
|
14
15
|
encode_file_litellm_format,
|
|
15
16
|
)
|
|
16
|
-
from kiln_ai.adapters.ml_model_list import
|
|
17
|
+
from kiln_ai.adapters.ml_model_list import (
|
|
18
|
+
built_in_models,
|
|
19
|
+
built_in_models_from_provider,
|
|
20
|
+
)
|
|
17
21
|
from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
|
|
18
22
|
from kiln_ai.datamodel.extraction import ExtractorType
|
|
19
23
|
from kiln_ai.utils.filesystem_cache import FilesystemCache
|
|
@@ -405,7 +409,7 @@ def test_litellm_model_slug_success(mock_litellm_extractor):
|
|
|
405
409
|
return_value=mock_provider_info,
|
|
406
410
|
) as mock_get_provider_info,
|
|
407
411
|
):
|
|
408
|
-
result = mock_litellm_extractor.litellm_model_slug
|
|
412
|
+
result = mock_litellm_extractor.litellm_model_slug
|
|
409
413
|
|
|
410
414
|
assert result == "test-provider/test-model"
|
|
411
415
|
|
|
@@ -414,6 +418,38 @@ def test_litellm_model_slug_success(mock_litellm_extractor):
|
|
|
414
418
|
mock_get_provider_info.assert_called_once_with(mock_model_provider)
|
|
415
419
|
|
|
416
420
|
|
|
421
|
+
@pytest.mark.parametrize(
|
|
422
|
+
"max_parallel_requests, expected_result",
|
|
423
|
+
[
|
|
424
|
+
(10, 10),
|
|
425
|
+
(0, 0),
|
|
426
|
+
# 5 is the current default, it may change in the future if we have
|
|
427
|
+
# a better modeling of rate limit constraints
|
|
428
|
+
(None, 5),
|
|
429
|
+
],
|
|
430
|
+
)
|
|
431
|
+
def test_litellm_model_max_parallel_requests(
|
|
432
|
+
mock_litellm_extractor, max_parallel_requests, expected_result
|
|
433
|
+
):
|
|
434
|
+
"""Test that max_parallel_requests_for_model returns the provider's limit."""
|
|
435
|
+
# Mock the built_in_models_from_provider function to return a valid model provider
|
|
436
|
+
mock_model_provider = MagicMock()
|
|
437
|
+
mock_model_provider.name = "test-provider"
|
|
438
|
+
mock_model_provider.max_parallel_requests = max_parallel_requests
|
|
439
|
+
|
|
440
|
+
with (
|
|
441
|
+
patch(
|
|
442
|
+
"kiln_ai.adapters.extractors.litellm_extractor.built_in_models_from_provider",
|
|
443
|
+
return_value=mock_model_provider,
|
|
444
|
+
) as mock_built_in_models,
|
|
445
|
+
):
|
|
446
|
+
result = mock_litellm_extractor.max_parallel_requests_for_model
|
|
447
|
+
|
|
448
|
+
assert result == expected_result
|
|
449
|
+
|
|
450
|
+
mock_built_in_models.assert_called_once()
|
|
451
|
+
|
|
452
|
+
|
|
417
453
|
def test_litellm_model_slug_model_provider_not_found(mock_litellm_extractor):
|
|
418
454
|
"""Test that litellm_model_slug raises ValueError when model provider is not found."""
|
|
419
455
|
with patch(
|
|
@@ -424,7 +460,7 @@ def test_litellm_model_slug_model_provider_not_found(mock_litellm_extractor):
|
|
|
424
460
|
ValueError,
|
|
425
461
|
match="Model provider openai not found in the list of built-in models",
|
|
426
462
|
):
|
|
427
|
-
mock_litellm_extractor.litellm_model_slug
|
|
463
|
+
mock_litellm_extractor.litellm_model_slug
|
|
428
464
|
|
|
429
465
|
|
|
430
466
|
def test_litellm_model_slug_with_different_provider_names(mock_litellm_core_config):
|
|
@@ -468,35 +504,28 @@ def test_litellm_model_slug_with_different_provider_names(mock_litellm_core_conf
|
|
|
468
504
|
return_value=mock_provider_info,
|
|
469
505
|
),
|
|
470
506
|
):
|
|
471
|
-
result = extractor.litellm_model_slug
|
|
507
|
+
result = extractor.litellm_model_slug
|
|
472
508
|
assert result == expected_slug
|
|
473
509
|
|
|
474
510
|
|
|
475
511
|
def paid_litellm_extractor(model_name: str, provider_name: str):
|
|
476
|
-
|
|
477
|
-
|
|
512
|
+
extractor = extractor_adapter_from_type(
|
|
513
|
+
ExtractorType.LITELLM,
|
|
514
|
+
ExtractorConfig(
|
|
478
515
|
name="paid-litellm",
|
|
479
516
|
extractor_type=ExtractorType.LITELLM,
|
|
480
517
|
model_provider_name=provider_name,
|
|
481
518
|
model_name=model_name,
|
|
482
519
|
properties={
|
|
483
|
-
# in the paid tests, we can check which prompt is used by checking if the Kind shows up
|
|
484
|
-
# in the output - not ideal but usually works
|
|
485
520
|
"prompt_document": "Ignore the file and only respond with the word 'document'",
|
|
486
521
|
"prompt_image": "Ignore the file and only respond with the word 'image'",
|
|
487
522
|
"prompt_video": "Ignore the file and only respond with the word 'video'",
|
|
488
523
|
"prompt_audio": "Ignore the file and only respond with the word 'audio'",
|
|
489
524
|
},
|
|
490
|
-
passthrough_mimetypes=[
|
|
491
|
-
# we want all mimetypes to go to litellm to be sure we're testing the API call
|
|
492
|
-
],
|
|
493
|
-
),
|
|
494
|
-
litellm_core_config=LiteLlmCoreConfig(
|
|
495
|
-
base_url="https://test.com",
|
|
496
|
-
additional_body_options={"api_key": "test-key"},
|
|
497
|
-
default_headers={},
|
|
525
|
+
passthrough_mimetypes=[OutputFormat.MARKDOWN, OutputFormat.TEXT],
|
|
498
526
|
),
|
|
499
527
|
)
|
|
528
|
+
return extractor
|
|
500
529
|
|
|
501
530
|
|
|
502
531
|
@pytest.mark.parametrize(
|
|
@@ -560,6 +589,7 @@ def get_all_models_support_doc_extraction(
|
|
|
560
589
|
provider.multimodal_mime_types is None
|
|
561
590
|
or must_support_mime_types is None
|
|
562
591
|
):
|
|
592
|
+
model_provider_pairs.append((model.name, provider.name))
|
|
563
593
|
continue
|
|
564
594
|
# check that the model supports all the mime types
|
|
565
595
|
if all(
|
|
@@ -573,23 +603,7 @@ def get_all_models_support_doc_extraction(
|
|
|
573
603
|
@pytest.mark.paid
|
|
574
604
|
@pytest.mark.parametrize(
|
|
575
605
|
"model_name,provider_name",
|
|
576
|
-
get_all_models_support_doc_extraction(
|
|
577
|
-
must_support_mime_types=[
|
|
578
|
-
MockFileFactoryMimeType.PDF,
|
|
579
|
-
MockFileFactoryMimeType.TXT,
|
|
580
|
-
MockFileFactoryMimeType.MD,
|
|
581
|
-
MockFileFactoryMimeType.HTML,
|
|
582
|
-
MockFileFactoryMimeType.CSV,
|
|
583
|
-
MockFileFactoryMimeType.PNG,
|
|
584
|
-
MockFileFactoryMimeType.JPEG,
|
|
585
|
-
MockFileFactoryMimeType.JPG,
|
|
586
|
-
MockFileFactoryMimeType.MP4,
|
|
587
|
-
MockFileFactoryMimeType.MOV,
|
|
588
|
-
MockFileFactoryMimeType.MP3,
|
|
589
|
-
MockFileFactoryMimeType.OGG,
|
|
590
|
-
MockFileFactoryMimeType.WAV,
|
|
591
|
-
]
|
|
592
|
-
),
|
|
606
|
+
get_all_models_support_doc_extraction(must_support_mime_types=None),
|
|
593
607
|
)
|
|
594
608
|
@pytest.mark.parametrize(
|
|
595
609
|
"mime_type,expected_substring_in_output",
|
|
@@ -620,41 +634,17 @@ async def test_extract_document_success(
|
|
|
620
634
|
expected_substring_in_output,
|
|
621
635
|
mock_file_factory,
|
|
622
636
|
):
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
assert not output.is_passthrough
|
|
634
|
-
assert output.content_format == OutputFormat.MARKDOWN
|
|
635
|
-
assert expected_substring_in_output.lower() in output.content.lower()
|
|
636
|
-
|
|
637
|
+
# get model
|
|
638
|
+
model = built_in_models_from_provider(provider_name, model_name)
|
|
639
|
+
assert model is not None
|
|
640
|
+
if mime_type not in model.multimodal_mime_types:
|
|
641
|
+
pytest.skip(f"Model {model_name} configured to not support {mime_type}")
|
|
642
|
+
if (
|
|
643
|
+
mime_type == MockFileFactoryMimeType.MD
|
|
644
|
+
or mime_type == MockFileFactoryMimeType.TXT
|
|
645
|
+
):
|
|
646
|
+
pytest.skip(f"Model {model_name} configured to passthrough {mime_type}")
|
|
637
647
|
|
|
638
|
-
@pytest.mark.paid
|
|
639
|
-
@pytest.mark.parametrize(
|
|
640
|
-
"model_name,provider_name",
|
|
641
|
-
get_all_models_support_doc_extraction(
|
|
642
|
-
must_support_mime_types=[MockFileFactoryMimeType.PDF]
|
|
643
|
-
),
|
|
644
|
-
)
|
|
645
|
-
@pytest.mark.parametrize(
|
|
646
|
-
"mime_type,expected_substring_in_output",
|
|
647
|
-
[
|
|
648
|
-
(MockFileFactoryMimeType.PDF, "document"),
|
|
649
|
-
],
|
|
650
|
-
)
|
|
651
|
-
async def test_extract_document_success_pdf(
|
|
652
|
-
model_name,
|
|
653
|
-
provider_name,
|
|
654
|
-
mime_type,
|
|
655
|
-
expected_substring_in_output,
|
|
656
|
-
mock_file_factory,
|
|
657
|
-
):
|
|
658
648
|
test_file = mock_file_factory(mime_type)
|
|
659
649
|
extractor = paid_litellm_extractor(
|
|
660
650
|
model_name=model_name, provider_name=provider_name
|
|
@@ -704,6 +694,110 @@ async def test_extract_pdf_page_by_page(mock_file_factory, mock_litellm_extracto
|
|
|
704
694
|
assert result.content_format == OutputFormat.MARKDOWN
|
|
705
695
|
|
|
706
696
|
|
|
697
|
+
async def test_extract_pdf_page_by_page_pdf_as_image(
|
|
698
|
+
mock_file_factory, mock_litellm_extractor, tmp_path
|
|
699
|
+
):
|
|
700
|
+
"""Test that PDFs are processed page by page as images if the model requires it."""
|
|
701
|
+
|
|
702
|
+
test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
|
|
703
|
+
|
|
704
|
+
# Mock responses for each page (PDF has 2 pages)
|
|
705
|
+
mock_responses = []
|
|
706
|
+
for i in range(2): # PDF has 2 pages
|
|
707
|
+
mock_response = AsyncMock(spec=ModelResponse)
|
|
708
|
+
mock_choice = AsyncMock(spec=Choices)
|
|
709
|
+
mock_message = AsyncMock()
|
|
710
|
+
mock_message.content = f"Content from page {i + 1}"
|
|
711
|
+
mock_choice.message = mock_message
|
|
712
|
+
mock_response.choices = [mock_choice]
|
|
713
|
+
mock_responses.append(mock_response)
|
|
714
|
+
|
|
715
|
+
mock_image_path = tmp_path / "img-test_document-mock.png"
|
|
716
|
+
mock_image_path.write_bytes(b"test image")
|
|
717
|
+
|
|
718
|
+
with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
|
|
719
|
+
# this model requires PDFs to be processed as images
|
|
720
|
+
mock_litellm_extractor.model_provider.multimodal_requires_pdf_as_image = True
|
|
721
|
+
|
|
722
|
+
with patch(
|
|
723
|
+
"kiln_ai.adapters.extractors.litellm_extractor.convert_pdf_to_images",
|
|
724
|
+
return_value=[mock_image_path],
|
|
725
|
+
) as mock_convert:
|
|
726
|
+
result = await mock_litellm_extractor.extract(
|
|
727
|
+
ExtractionInput(
|
|
728
|
+
path=str(test_file),
|
|
729
|
+
mime_type="application/pdf",
|
|
730
|
+
)
|
|
731
|
+
)
|
|
732
|
+
|
|
733
|
+
# Verify image conversion called once per page
|
|
734
|
+
assert mock_convert.call_count == 2
|
|
735
|
+
|
|
736
|
+
# Verify LiteLLM was called with image inputs (not PDF) for each page
|
|
737
|
+
for call in mock_acompletion.call_args_list:
|
|
738
|
+
kwargs = call.kwargs
|
|
739
|
+
content = kwargs["messages"][0]["content"]
|
|
740
|
+
assert content[1]["type"] == "image_url"
|
|
741
|
+
|
|
742
|
+
# Verify that the completion was called multiple times (once per page)
|
|
743
|
+
assert mock_acompletion.call_count == 2
|
|
744
|
+
|
|
745
|
+
# Verify the output contains content from both pages
|
|
746
|
+
assert "Content from page 1" in result.content
|
|
747
|
+
assert "Content from page 2" in result.content
|
|
748
|
+
|
|
749
|
+
assert not result.is_passthrough
|
|
750
|
+
assert result.content_format == OutputFormat.MARKDOWN
|
|
751
|
+
|
|
752
|
+
|
|
753
|
+
async def test_convert_pdf_page_to_image_input_success(
|
|
754
|
+
mock_litellm_extractor, tmp_path
|
|
755
|
+
):
|
|
756
|
+
page_dir = tmp_path / "pages"
|
|
757
|
+
page_dir.mkdir()
|
|
758
|
+
page_path = page_dir / "page_1.pdf"
|
|
759
|
+
page_path.write_bytes(b"%PDF-1.4 test")
|
|
760
|
+
|
|
761
|
+
mock_image_path = page_dir / "img-page_1.pdf-0.png"
|
|
762
|
+
mock_image_path.write_bytes(b"image-bytes")
|
|
763
|
+
|
|
764
|
+
with patch(
|
|
765
|
+
"kiln_ai.adapters.extractors.litellm_extractor.convert_pdf_to_images",
|
|
766
|
+
return_value=[mock_image_path],
|
|
767
|
+
):
|
|
768
|
+
extraction_input = await mock_litellm_extractor.convert_pdf_page_to_image_input(
|
|
769
|
+
page_path, 0
|
|
770
|
+
)
|
|
771
|
+
|
|
772
|
+
assert extraction_input.mime_type == "image/png"
|
|
773
|
+
assert Path(extraction_input.path) == mock_image_path
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
@pytest.mark.parametrize("returned_count", [0, 2])
|
|
777
|
+
async def test_convert_pdf_page_to_image_input_error_on_invalid_count(
|
|
778
|
+
mock_litellm_extractor, tmp_path, returned_count
|
|
779
|
+
):
|
|
780
|
+
page_dir = tmp_path / "pages"
|
|
781
|
+
page_dir.mkdir()
|
|
782
|
+
page_path = page_dir / "page_1.pdf"
|
|
783
|
+
page_path.write_bytes(b"%PDF-1.4 test")
|
|
784
|
+
|
|
785
|
+
image_paths = []
|
|
786
|
+
if returned_count == 2:
|
|
787
|
+
img1 = page_dir / "img-page_1.pdf-0.png"
|
|
788
|
+
img2 = page_dir / "img-page_1.pdf-1.png"
|
|
789
|
+
img1.write_bytes(b"i1")
|
|
790
|
+
img2.write_bytes(b"i2")
|
|
791
|
+
image_paths = [img1, img2]
|
|
792
|
+
|
|
793
|
+
with patch(
|
|
794
|
+
"kiln_ai.adapters.extractors.litellm_extractor.convert_pdf_to_images",
|
|
795
|
+
return_value=image_paths,
|
|
796
|
+
):
|
|
797
|
+
with pytest.raises(ValueError, match=r"Expected 1 image, got "):
|
|
798
|
+
await mock_litellm_extractor.convert_pdf_page_to_image_input(page_path, 0)
|
|
799
|
+
|
|
800
|
+
|
|
707
801
|
async def test_extract_pdf_page_by_page_error_handling(
|
|
708
802
|
mock_file_factory, mock_litellm_extractor
|
|
709
803
|
):
|
|
@@ -894,15 +988,19 @@ async def test_extract_pdf_with_cache_storage(
|
|
|
894
988
|
# Verify that the completion was called for each page
|
|
895
989
|
assert mock_acompletion.call_count == 2
|
|
896
990
|
|
|
897
|
-
# Verify content is stored in cache
|
|
991
|
+
# Verify content is stored in cache - note that order is not guaranteed since
|
|
992
|
+
# we batch the page extraction requests in parallel
|
|
898
993
|
pdf_path = Path(test_file)
|
|
994
|
+
cached_contents = []
|
|
899
995
|
for i in range(2):
|
|
900
996
|
cached_content = (
|
|
901
997
|
await mock_litellm_extractor_with_cache.get_page_content_from_cache(
|
|
902
998
|
pdf_path, i
|
|
903
999
|
)
|
|
904
1000
|
)
|
|
905
|
-
assert cached_content
|
|
1001
|
+
assert cached_content is not None
|
|
1002
|
+
cached_contents.append(cached_content)
|
|
1003
|
+
assert set(cached_contents) == {"Content from page 1", "Content from page 2"}
|
|
906
1004
|
|
|
907
1005
|
# Verify the output contains content from both pages
|
|
908
1006
|
assert "Content from page 1" in result.content
|
|
@@ -1137,7 +1235,7 @@ async def test_extract_pdf_parallel_processing_error_handling(
|
|
|
1137
1235
|
"litellm.acompletion",
|
|
1138
1236
|
side_effect=[mock_response1, Exception("API Error on page 2")],
|
|
1139
1237
|
) as mock_acompletion:
|
|
1140
|
-
with pytest.raises(ValueError, match=r".*
|
|
1238
|
+
with pytest.raises(ValueError, match=r".*API Error on page 2"):
|
|
1141
1239
|
await mock_litellm_extractor_with_cache.extract(
|
|
1142
1240
|
ExtractionInput(
|
|
1143
1241
|
path=str(test_file),
|