kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +8 -2
- kiln_ai/adapters/adapter_registry.py +43 -208
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +9 -3
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +2 -4
- kiln_ai/adapters/eval/test_g_eval.py +4 -5
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +761 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +205 -47
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +657 -85
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +251 -1
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +254 -8
- kiln_ai/adapters/test_remote_config.py +651 -58
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +39 -34
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +28 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +40 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +4 -4
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +320 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +105 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +246 -0
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +562 -0
- kiln_ai/tools/tool_registry.py +85 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +24 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +131 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.19.0.dist-info/RECORD +0 -115
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,1192 @@
+from pathlib import Path
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from litellm.types.utils import Choices, ModelResponse
+
+from conftest import MockFileFactoryMimeType
+from kiln_ai.adapters.extractors.base_extractor import ExtractionInput, OutputFormat
+from kiln_ai.adapters.extractors.encoding import to_base64_url
+from kiln_ai.adapters.extractors.litellm_extractor import (
+    ExtractorConfig,
+    Kind,
+    LitellmExtractor,
+    encode_file_litellm_format,
+)
+from kiln_ai.adapters.ml_model_list import built_in_models
+from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
+from kiln_ai.datamodel.extraction import ExtractorType
+from kiln_ai.utils.filesystem_cache import FilesystemCache
+
+PROMPTS_FOR_KIND: dict[str, str] = {
+    "document": "prompt for documents",
+    "image": "prompt for images",
+    "video": "prompt for videos",
+    "audio": "prompt for audio",
+}
+
+
+@pytest.fixture
+def mock_litellm_extractor():
+    return LitellmExtractor(
+        ExtractorConfig(
+            name="mock",
+            extractor_type=ExtractorType.LITELLM,
+            model_name="gpt_4o",
+            model_provider_name="openai",
+            properties={
+                "prompt_document": PROMPTS_FOR_KIND["document"],
+                "prompt_image": PROMPTS_FOR_KIND["image"],
+                "prompt_video": PROMPTS_FOR_KIND["video"],
+                "prompt_audio": PROMPTS_FOR_KIND["audio"],
+            },
+        ),
+        litellm_core_config=LiteLlmCoreConfig(
+            base_url="https://test.com",
+            additional_body_options={"api_key": "test-key"},
+            default_headers={},
+        ),
+    )
+
+
+@pytest.fixture
+def mock_litellm_core_config():
+    return LiteLlmCoreConfig(
+        base_url="https://test.com",
+        additional_body_options={"api_key": "test-key"},
+        default_headers={},
+    )
+
+
+@pytest.mark.parametrize(
+    "mime_type, kind",
+    [
+        # documents
+        ("application/pdf", Kind.DOCUMENT),
+        ("text/markdown", Kind.DOCUMENT),
+        ("text/md", Kind.DOCUMENT),
+        ("text/plain", Kind.DOCUMENT),
+        ("text/html", Kind.DOCUMENT),
+        ("text/csv", Kind.DOCUMENT),
+        # images
+        ("image/png", Kind.IMAGE),
+        ("image/jpeg", Kind.IMAGE),
+        ("image/jpg", Kind.IMAGE),
+        # videos
+        ("video/mp4", Kind.VIDEO),
+        ("video/mov", Kind.VIDEO),
+        ("video/quicktime", Kind.VIDEO),
+        # audio
+        ("audio/mpeg", Kind.AUDIO),
+        ("audio/ogg", Kind.AUDIO),
+        ("audio/wav", Kind.AUDIO),
+    ],
+)
+def test_get_kind_from_mime_type(mock_litellm_extractor, mime_type, kind):
+    """Test that the kind is correctly inferred from the mime type."""
+    assert mock_litellm_extractor._get_kind_from_mime_type(mime_type) == kind
+
+
+def test_get_kind_from_mime_type_unsupported(mock_litellm_extractor):
+    assert (
+        mock_litellm_extractor._get_kind_from_mime_type("unsupported/mimetype") is None
+    )
+
+
+@pytest.mark.parametrize(
+    "mime_type, expected_content",
+    [
+        (MockFileFactoryMimeType.TXT, "Content from text file"),
+        (MockFileFactoryMimeType.MD, "Content from markdown file"),
+        (MockFileFactoryMimeType.HTML, "Content from html file"),
+        (MockFileFactoryMimeType.CSV, "Content from csv file"),
+        (MockFileFactoryMimeType.PNG, "Content from image file"),
+        (MockFileFactoryMimeType.JPG, "Content from image file"),
+        (MockFileFactoryMimeType.MP4, "Content from video file"),
+        (MockFileFactoryMimeType.MP3, "Content from audio file"),
+    ],
+)
+async def test_extract_success(
+    mock_file_factory, mock_litellm_extractor, mime_type, expected_content
+):
+    """Test successful extraction for non-PDF file types."""
+    # Create a mock file of the specified type
+    test_file = mock_file_factory(mime_type)
+
+    # Mock response for single file extraction
+    mock_response = AsyncMock(spec=ModelResponse)
+    mock_choice = AsyncMock(spec=Choices)
+    mock_message = AsyncMock()
+    mock_message.content = expected_content
+    mock_choice.message = mock_message
+    mock_response.choices = [mock_choice]
+
+    with patch("litellm.acompletion", return_value=mock_response) as mock_acompletion:
+        result = await mock_litellm_extractor.extract(
+            ExtractionInput(
+                path=str(test_file),
+                mime_type=mime_type.value,
+            )
+        )
+
+        # Verify that the completion was called once (single file)
+        assert mock_acompletion.call_count == 1
+
+        # Verify the output contains the expected content
+        assert expected_content in result.content
+
+        assert not result.is_passthrough
+        assert result.content_format == OutputFormat.MARKDOWN
+
+
+def test_build_completion_kwargs_with_all_options(mock_file_factory):
+    """Test that _build_completion_kwargs properly includes all litellm_core_config options."""
+    litellm_core_config = LiteLlmCoreConfig(
+        base_url="https://custom-api.example.com",
+        additional_body_options={"custom_param": "value", "timeout": "30"},
+        default_headers={"Authorization": "Bearer custom-token"},
+    )
+
+    extractor = LitellmExtractor(
+        ExtractorConfig(
+            name="test",
+            extractor_type=ExtractorType.LITELLM,
+            model_name="gpt_4o",
+            model_provider_name="openai",
+            properties={
+                "prompt_document": "prompt for documents",
+                "prompt_image": "prompt for images",
+                "prompt_video": "prompt for videos",
+                "prompt_audio": "prompt for audio",
+            },
+        ),
+        litellm_core_config=litellm_core_config,
+    )
+
+    extraction_input = ExtractionInput(
+        path=str(mock_file_factory(MockFileFactoryMimeType.PDF)),
+        mime_type="application/pdf",
+    )
+
+    completion_kwargs = extractor._build_completion_kwargs(
+        "test prompt", extraction_input
+    )
+
+    # Verify all completion kwargs are included
+    assert completion_kwargs["base_url"] == "https://custom-api.example.com"
+    assert completion_kwargs["custom_param"] == "value"
+    assert completion_kwargs["timeout"] == "30"
+    assert completion_kwargs["default_headers"] == {
+        "Authorization": "Bearer custom-token"
+    }
+
+    # Verify basic structure is maintained
+    assert "model" in completion_kwargs
+    assert "messages" in completion_kwargs
+
+
+def test_build_completion_kwargs_with_partial_options(mock_file_factory):
+    """Test that _build_completion_kwargs works when only some options are set."""
+    litellm_core_config = LiteLlmCoreConfig(
+        base_url=None,
+        additional_body_options={"timeout": "30"},
+        default_headers=None,
+    )
+
+    extractor = LitellmExtractor(
+        ExtractorConfig(
+            name="test",
+            extractor_type=ExtractorType.LITELLM,
+            model_name="gpt_4o",
+            model_provider_name="openai",
+            properties={
+                "prompt_document": "prompt for documents",
+                "prompt_image": "prompt for images",
+                "prompt_video": "prompt for videos",
+                "prompt_audio": "prompt for audio",
+            },
+        ),
+        litellm_core_config=litellm_core_config,
+    )
+
+    extraction_input = ExtractionInput(
+        path=str(mock_file_factory(MockFileFactoryMimeType.PDF)),
+        mime_type="application/pdf",
+    )
+
+    completion_kwargs = extractor._build_completion_kwargs(
+        "test prompt", extraction_input
+    )
+
+    # Verify only the set options are included
+    assert completion_kwargs["timeout"] == "30"
+    assert "base_url" not in completion_kwargs
+    assert "default_headers" not in completion_kwargs
+
+    # Verify basic structure is maintained
+    assert "model" in completion_kwargs
+    assert "messages" in completion_kwargs
+
+
+def test_build_completion_kwargs_with_empty_options(mock_file_factory):
+    """Test that _build_completion_kwargs works when all options are None/empty."""
+    litellm_core_config = LiteLlmCoreConfig(
+        base_url=None,
+        additional_body_options=None,
+        default_headers=None,
+    )
+
+    extractor = LitellmExtractor(
+        ExtractorConfig(
+            name="test",
+            extractor_type=ExtractorType.LITELLM,
+            model_name="gpt_4o",
+            model_provider_name="openai",
+            properties={
+                "prompt_document": "prompt for documents",
+                "prompt_image": "prompt for images",
+                "prompt_video": "prompt for videos",
+                "prompt_audio": "prompt for audio",
+            },
+        ),
+        litellm_core_config=litellm_core_config,
+    )
+
+    extraction_input = ExtractionInput(
+        path=str(mock_file_factory(MockFileFactoryMimeType.PDF)),
+        mime_type="application/pdf",
+    )
+
+    completion_kwargs = extractor._build_completion_kwargs(
+        "test prompt", extraction_input
+    )
+
+    # Verify no completion kwargs are included
+    assert "base_url" not in completion_kwargs
+    assert "default_headers" not in completion_kwargs
+
+    # Verify basic structure is maintained
+    assert "model" in completion_kwargs
+    assert "messages" in completion_kwargs
+
+
+def test_build_completion_kwargs_messages_structure(mock_file_factory):
+    """Test that the messages structure in completion_kwargs is correct."""
+    litellm_core_config = LiteLlmCoreConfig()
+
+    extractor = LitellmExtractor(
+        ExtractorConfig(
+            name="test",
+            extractor_type=ExtractorType.LITELLM,
+            model_name="gpt_4o",
+            model_provider_name="openai",
+            properties={
+                "prompt_document": "prompt for documents",
+                "prompt_image": "prompt for images",
+                "prompt_video": "prompt for videos",
+                "prompt_audio": "prompt for audio",
+            },
+        ),
+        litellm_core_config=litellm_core_config,
+    )
+
+    extraction_input = ExtractionInput(
+        path=str(mock_file_factory(MockFileFactoryMimeType.PDF)),
+        mime_type="application/pdf",
+    )
+
+    completion_kwargs = extractor._build_completion_kwargs(
+        "test prompt", extraction_input
+    )
+
+    # Verify messages structure
+    messages = completion_kwargs["messages"]
+    assert len(messages) == 1
+    assert messages[0]["role"] == "user"
+
+    content = messages[0]["content"]
+    assert len(content) == 2
+
+    # First content item should be text
+    assert content[0]["type"] == "text"
+    assert content[0]["text"] == "test prompt"
+
+    # Second content item should be file
+    assert content[1]["type"] == "file"
+    assert "file" in content[1]
+    assert "file_data" in content[1]["file"]
+
+
+async def test_extract_failure_from_litellm(mock_file_factory, mock_litellm_extractor):
+    test_pdf_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+    with (
+        patch("pathlib.Path.read_bytes", return_value=b"test content"),
+        patch("litellm.acompletion", side_effect=Exception("error from litellm")),
+        patch(
+            "kiln_ai.adapters.extractors.litellm_extractor.LitellmExtractor.litellm_model_slug",
+            return_value="provider-name/model-name",
+        ),
+    ):
+        # Mock litellm to raise an exception
+        with pytest.raises(Exception, match="error from litellm"):
+            await mock_litellm_extractor.extract(
+                extraction_input=ExtractionInput(
+                    path=str(test_pdf_file),
+                    mime_type="application/pdf",
+                )
+            )
+
+
+async def test_extract_failure_from_bytes_read(mock_litellm_extractor):
+    with (
+        patch(
+            "mimetypes.guess_type",
+            return_value=("application/pdf", None),
+        ),
+        patch(
+            "pathlib.Path.read_bytes",
+            side_effect=Exception("error from read_bytes"),
+        ),
+        patch(
+            "kiln_ai.adapters.extractors.litellm_extractor.LitellmExtractor.litellm_model_slug",
+            return_value="provider-name/model-name",
+        ),
+        patch(
+            "kiln_ai.adapters.extractors.litellm_extractor.split_pdf_into_pages",
+            side_effect=Exception("error from split_pdf_into_pages"),
+        ),
+    ):
+        # test the extract method
+        with pytest.raises(
+            ValueError,
+            match=r"Error extracting test.pdf: error from split_pdf_into_pages",
+        ):
+            await mock_litellm_extractor.extract(
+                extraction_input=ExtractionInput(
+                    path="test.pdf",
+                    mime_type="application/pdf",
+                )
+            )
+
+
+async def test_extract_failure_unsupported_mime_type(mock_litellm_extractor):
+    # spy on the get mime type
+    with patch(
+        "mimetypes.guess_type",
+        return_value=(None, None),
+    ):
+        with pytest.raises(ValueError, match="Unsupported MIME type"):
+            await mock_litellm_extractor.extract(
+                extraction_input=ExtractionInput(
+                    path="test.unsupported",
+                    mime_type="unsupported/mimetype",
+                )
+            )
+
+
+def test_litellm_model_slug_success(mock_litellm_extractor):
+    """Test that litellm_model_slug returns the correct model slug."""
+    # Mock the built_in_models_from_provider function to return a valid model provider
+    mock_model_provider = AsyncMock()
+    mock_model_provider.name = "test-provider"
+
+    # Mock the get_litellm_provider_info function to return provider info with model ID
+    mock_provider_info = AsyncMock()
+    mock_provider_info.litellm_model_id = "test-provider/test-model"
+
+    with (
+        patch(
+            "kiln_ai.adapters.extractors.litellm_extractor.built_in_models_from_provider",
+            return_value=mock_model_provider,
+        ) as mock_built_in_models,
+        patch(
+            "kiln_ai.adapters.extractors.litellm_extractor.get_litellm_provider_info",
+            return_value=mock_provider_info,
+        ) as mock_get_provider_info,
+    ):
+        result = mock_litellm_extractor.litellm_model_slug()
+
+        assert result == "test-provider/test-model"
+
+        # Verify the functions were called with correct arguments
+        mock_built_in_models.assert_called_once()
+        mock_get_provider_info.assert_called_once_with(mock_model_provider)
+
+
+def test_litellm_model_slug_model_provider_not_found(mock_litellm_extractor):
+    """Test that litellm_model_slug raises ValueError when model provider is not found."""
+    with patch(
+        "kiln_ai.adapters.extractors.litellm_extractor.built_in_models_from_provider",
+        return_value=None,
+    ):
+        with pytest.raises(
+            ValueError,
+            match="Model provider openai not found in the list of built-in models",
+        ):
+            mock_litellm_extractor.litellm_model_slug()
+
+
+def test_litellm_model_slug_with_different_provider_names(mock_litellm_core_config):
+    """Test litellm_model_slug with different provider and model combinations."""
+    test_cases = [
+        ("anthropic", "claude-3-sonnet", "anthropic/claude-3-sonnet"),
+        ("openai", "gpt-4", "openai/gpt-4"),
+        ("gemini_api", "gemini-pro", "gemini_api/gemini-pro"),
+    ]
+
+    for provider_name, model_name, expected_slug in test_cases:
+        extractor = LitellmExtractor(
+            ExtractorConfig(
+                name="test",
+                extractor_type=ExtractorType.LITELLM,
+                model_name=model_name,
+                model_provider_name=provider_name,
+                properties={
+                    "prompt_document": "test prompt",
+                    "prompt_image": "test prompt",
+                    "prompt_video": "test prompt",
+                    "prompt_audio": "test prompt",
+                },
+            ),
+            litellm_core_config=mock_litellm_core_config,
+        )
+
+        mock_model_provider = AsyncMock()
+        mock_model_provider.name = provider_name
+
+        mock_provider_info = AsyncMock()
+        mock_provider_info.litellm_model_id = expected_slug
+
+        with (
+            patch(
+                "kiln_ai.adapters.extractors.litellm_extractor.built_in_models_from_provider",
+                return_value=mock_model_provider,
+            ),
+            patch(
+                "kiln_ai.adapters.extractors.litellm_extractor.get_litellm_provider_info",
+                return_value=mock_provider_info,
+            ),
+        ):
+            result = extractor.litellm_model_slug()
+            assert result == expected_slug
+
+
+def paid_litellm_extractor(model_name: str, provider_name: str):
+    return LitellmExtractor(
+        extractor_config=ExtractorConfig(
+            name="paid-litellm",
+            extractor_type=ExtractorType.LITELLM,
+            model_provider_name=provider_name,
+            model_name=model_name,
+            properties={
+                # in the paid tests, we can check which prompt is used by checking if the Kind shows up
+                # in the output - not ideal but usually works
+                "prompt_document": "Ignore the file and only respond with the word 'document'",
+                "prompt_image": "Ignore the file and only respond with the word 'image'",
+                "prompt_video": "Ignore the file and only respond with the word 'video'",
+                "prompt_audio": "Ignore the file and only respond with the word 'audio'",
+            },
+            passthrough_mimetypes=[
+                # we want all mimetypes to go to litellm to be sure we're testing the API call
+            ],
+        ),
+        litellm_core_config=LiteLlmCoreConfig(
+            base_url="https://test.com",
+            additional_body_options={"api_key": "test-key"},
+            default_headers={},
+        ),
+    )
+
+
+@pytest.mark.parametrize(
+    "mime_type, expected_encoding",
+    [
+        # documents
+        (MockFileFactoryMimeType.PDF, "generic_file_data"),
+        (MockFileFactoryMimeType.TXT, "generic_file_data"),
+        (MockFileFactoryMimeType.MD, "generic_file_data"),
+        (MockFileFactoryMimeType.HTML, "generic_file_data"),
+        (MockFileFactoryMimeType.CSV, "generic_file_data"),
+        # images
+        (MockFileFactoryMimeType.PNG, "image_data"),
+        (MockFileFactoryMimeType.JPEG, "image_data"),
+        (MockFileFactoryMimeType.JPG, "image_data"),
+        # videos
+        (MockFileFactoryMimeType.MP4, "generic_file_data"),
+        (MockFileFactoryMimeType.MOV, "generic_file_data"),
+        # audio
+        (MockFileFactoryMimeType.MP3, "generic_file_data"),
+        (MockFileFactoryMimeType.OGG, "generic_file_data"),
+        (MockFileFactoryMimeType.WAV, "generic_file_data"),
+    ],
+)
+def test_encode_file_litellm_format(mock_file_factory, mime_type, expected_encoding):
+    test_file = mock_file_factory(mime_type)
+    encoded = encode_file_litellm_format(Path(test_file), mime_type)
+
+    # there are two types of ways of including files, image_url is a special case
+    # and it also works with the generic file_data encoding, but LiteLLM docs are
+    # not clear on this, so best to go with the more specific image_url encoding
+    if expected_encoding == "image_data":
+        assert encoded == {
+            "type": "image_url",
+            "image_url": {
+                "url": to_base64_url(mime_type, Path(test_file).read_bytes()),
+            },
+        }
+    elif expected_encoding == "generic_file_data":
+        assert encoded == {
+            "type": "file",
+            "file": {
+                "file_data": to_base64_url(mime_type, Path(test_file).read_bytes()),
+            },
+        }
+    else:
+        raise ValueError(f"Unsupported encoding: {expected_encoding}")
+
+
+def get_all_models_support_doc_extraction(
+    must_support_mime_types: list[str] | None = None,
+):
+    model_provider_pairs = []
+    for model in built_in_models:
+        for provider in model.providers:
+            if not provider.model_id:
+                # it's possible for models to not have an ID (fine-tune only model)
+                continue
+            if provider.supports_doc_extraction:
+                if (
+                    provider.multimodal_mime_types is None
+                    or must_support_mime_types is None
+                ):
+                    continue
+                # check that the model supports all the mime types
+                if all(
+                    mime_type in provider.multimodal_mime_types
+                    for mime_type in must_support_mime_types
+                ):
+                    model_provider_pairs.append((model.name, provider.name))
+    return model_provider_pairs
+
+
+@pytest.mark.paid
+@pytest.mark.parametrize(
+    "model_name,provider_name",
+    get_all_models_support_doc_extraction(
+        must_support_mime_types=[
+            MockFileFactoryMimeType.PDF,
+            MockFileFactoryMimeType.TXT,
+            MockFileFactoryMimeType.MD,
+            MockFileFactoryMimeType.HTML,
+            MockFileFactoryMimeType.CSV,
+            MockFileFactoryMimeType.PNG,
+            MockFileFactoryMimeType.JPEG,
+            MockFileFactoryMimeType.JPG,
+            MockFileFactoryMimeType.MP4,
+            MockFileFactoryMimeType.MOV,
+            MockFileFactoryMimeType.MP3,
+            MockFileFactoryMimeType.OGG,
+            MockFileFactoryMimeType.WAV,
+        ]
+    ),
+)
+@pytest.mark.parametrize(
+    "mime_type,expected_substring_in_output",
+    [
+        # documents
+        (MockFileFactoryMimeType.PDF, "document"),
+        (MockFileFactoryMimeType.TXT, "document"),
+        (MockFileFactoryMimeType.MD, "document"),
+        (MockFileFactoryMimeType.HTML, "document"),
+        (MockFileFactoryMimeType.CSV, "document"),
+        # images
+        (MockFileFactoryMimeType.PNG, "image"),
+        (MockFileFactoryMimeType.JPEG, "image"),
+        (MockFileFactoryMimeType.JPG, "image"),
+        # videos
+        (MockFileFactoryMimeType.MP4, "video"),
+        (MockFileFactoryMimeType.MOV, "video"),
+        # audio
+        (MockFileFactoryMimeType.MP3, "audio"),
+        (MockFileFactoryMimeType.OGG, "audio"),
+        (MockFileFactoryMimeType.WAV, "audio"),
+    ],
+)
+async def test_extract_document_success(
+    model_name,
+    provider_name,
+    mime_type,
+    expected_substring_in_output,
+    mock_file_factory,
+):
+    test_file = mock_file_factory(mime_type)
+    extractor = paid_litellm_extractor(
+        model_name=model_name, provider_name=provider_name
+    )
+    output = await extractor.extract(
+        extraction_input=ExtractionInput(
+            path=str(test_file),
+            mime_type=mime_type,
+        )
+    )
+    assert not output.is_passthrough
+    assert output.content_format == OutputFormat.MARKDOWN
+    assert expected_substring_in_output.lower() in output.content.lower()
+
+
+@pytest.mark.paid
+@pytest.mark.parametrize(
+    "model_name,provider_name",
+    get_all_models_support_doc_extraction(
+        must_support_mime_types=[MockFileFactoryMimeType.PDF]
+    ),
+)
+@pytest.mark.parametrize(
+    "mime_type,expected_substring_in_output",
+    [
+        (MockFileFactoryMimeType.PDF, "document"),
+    ],
+)
+async def test_extract_document_success_pdf(
+    model_name,
+    provider_name,
+    mime_type,
+    expected_substring_in_output,
+    mock_file_factory,
+):
+    test_file = mock_file_factory(mime_type)
+    extractor = paid_litellm_extractor(
+        model_name=model_name, provider_name=provider_name
+    )
+    output = await extractor.extract(
+        extraction_input=ExtractionInput(
+            path=str(test_file),
+            mime_type=mime_type,
+        )
+    )
+    assert not output.is_passthrough
+    assert output.content_format == OutputFormat.MARKDOWN
+    assert expected_substring_in_output.lower() in output.content.lower()
+
+
+async def test_extract_pdf_page_by_page(mock_file_factory, mock_litellm_extractor):
+    """Test that PDFs are processed page by page with page numbers in output."""
+    test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+    # Mock responses for each page (PDF has 2 pages)
+    mock_responses = []
+    for i in range(2):  # PDF has 2 pages
+        mock_response = AsyncMock(spec=ModelResponse)
+        mock_choice = AsyncMock(spec=Choices)
+        mock_message = AsyncMock()
+        mock_message.content = f"Content from page {i + 1}"
+        mock_choice.message = mock_message
+        mock_response.choices = [mock_choice]
+        mock_responses.append(mock_response)
+
+    with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+        result = await mock_litellm_extractor.extract(
+            ExtractionInput(
+                path=str(test_file),
+                mime_type="application/pdf",
+            )
+        )
+
+        # Verify that the completion was called multiple times (once per page)
+        assert mock_acompletion.call_count == 2
+
+        # Verify the output contains content from both pages
+        assert "Content from page 1" in result.content
+        assert "Content from page 2" in result.content
+
+        assert not result.is_passthrough
+        assert result.content_format == OutputFormat.MARKDOWN
+
+
+async def test_extract_pdf_page_by_page_error_handling(
+    mock_file_factory, mock_litellm_extractor
+):
+    """Test that PDF page processing handles errors gracefully."""
+    test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+    # Mock the first page to succeed, second to fail
+    mock_response1 = AsyncMock(spec=ModelResponse)
+    mock_choice1 = AsyncMock(spec=Choices)
+    mock_message1 = AsyncMock()
+    mock_message1.content = "Content from page 1"
+    mock_choice1.message = mock_message1
+    mock_response1.choices = [mock_choice1]
+
+    with patch(
+        "litellm.acompletion", side_effect=[mock_response1, Exception("API Error")]
+    ) as mock_acompletion:
+        with pytest.raises(Exception, match="API Error"):
+            await mock_litellm_extractor.extract(
+                ExtractionInput(
+                    path=str(test_file),
+                    mime_type="application/pdf",
+                )
+            )
+
+        # Verify that the completion was called at least once before failing
+        assert mock_acompletion.call_count >= 1
+
+
+@pytest.mark.paid
+@pytest.mark.parametrize(
+    "model_name,provider_name", get_all_models_support_doc_extraction()
+)
+async def test_provider_bad_request(tmp_path, model_name, provider_name):
+    # write corrupted PDF file to temp files
+    temp_file = tmp_path / "corrupted_file.pdf"
+    temp_file.write_bytes(b"invalid file")
+
+    extractor = paid_litellm_extractor(
+        model_name=model_name, provider_name=provider_name
+    )
+
+    with pytest.raises(ValueError, match=r"Error extracting .*corrupted_file.pdf: "):
+        await extractor.extract(
+            extraction_input=ExtractionInput(
+                path=temp_file.as_posix(),
+                mime_type="application/pdf",
+            )
+        )
+
+
+# Cache-related tests for PDF processing
+@pytest.fixture
+def mock_litellm_extractor_with_cache(tmp_path):
+    """Create a LitellmExtractor with a filesystem cache for testing."""
+    cache_dir = tmp_path / "cache"
+    cache_dir.mkdir()  # Ensure cache directory exists
+    cache = FilesystemCache(cache_dir)
+    return LitellmExtractor(
+        ExtractorConfig(
+            id="test_extractor_123",  # Required for cache key generation
+            name="mock_with_cache",
+            extractor_type=ExtractorType.LITELLM,
+            model_name="gpt_4o",
+            model_provider_name="openai",
+            properties={
+                "prompt_document": PROMPTS_FOR_KIND["document"],
+                "prompt_image": PROMPTS_FOR_KIND["image"],
+                "prompt_video": PROMPTS_FOR_KIND["video"],
+                "prompt_audio": PROMPTS_FOR_KIND["audio"],
+            },
+        ),
+        litellm_core_config=LiteLlmCoreConfig(
+            base_url="https://test.com",
+            additional_body_options={"api_key": "test-key"},
+            default_headers={},
+        ),
+        filesystem_cache=cache,
+    )
+
+
+@pytest.fixture
+def mock_litellm_extractor_without_cache():
+    """Create a LitellmExtractor without a filesystem cache for testing."""
+    return LitellmExtractor(
+        ExtractorConfig(
+            id="test_extractor_456",  # Required for cache key generation
+            name="mock_without_cache",
+            extractor_type=ExtractorType.LITELLM,
+            model_name="gpt_4o",
+            model_provider_name="openai",
+            properties={
+                "prompt_document": PROMPTS_FOR_KIND["document"],
+                "prompt_image": PROMPTS_FOR_KIND["image"],
+                "prompt_video": PROMPTS_FOR_KIND["video"],
+                "prompt_audio": PROMPTS_FOR_KIND["audio"],
+            },
+        ),
+        litellm_core_config=LiteLlmCoreConfig(
+            base_url="https://test.com",
+            additional_body_options={"api_key": "test-key"},
+            default_headers={},
+        ),
+        filesystem_cache=None,  # Explicitly no cache
+    )
+
+
+def test_pdf_page_cache_key_generation(mock_litellm_extractor_with_cache):
+    """Test that PDF page cache keys are generated correctly."""
+    pdf_path = Path("test_document.pdf")
+    page_number = 0
+
+    cache_key = mock_litellm_extractor_with_cache.pdf_page_cache_key(
+        pdf_path, page_number
+    )
+
+    # Should include extractor ID and a hash of the PDF name and page number
+    assert cache_key.startswith("test_extractor_123_")
+    assert len(cache_key) > len("test_extractor_123_")  # Should have hash suffix
+
+    # Same PDF and page should generate same key
+    cache_key2 = mock_litellm_extractor_with_cache.pdf_page_cache_key(
+        pdf_path, page_number
+    )
+    assert cache_key == cache_key2
+
+    # Different page should generate different key
+    cache_key3 = mock_litellm_extractor_with_cache.pdf_page_cache_key(pdf_path, 1)
+    assert cache_key != cache_key3
+
+
+def test_pdf_page_cache_key_requires_extractor_id():
+    """Test that PDF page cache key generation requires extractor ID."""
+    extractor_config = ExtractorConfig(
+        id=None,  # No ID
+        name="mock",
+        extractor_type=ExtractorType.LITELLM,
+        model_name="gpt_4o",
+        model_provider_name="openai",
+        properties={
+            "prompt_document": PROMPTS_FOR_KIND["document"],
+            "prompt_image": PROMPTS_FOR_KIND["image"],
+            "prompt_video": PROMPTS_FOR_KIND["video"],
+            "prompt_audio": PROMPTS_FOR_KIND["audio"],
+        },
+    )
+
+    extractor = LitellmExtractor(
+        extractor_config,
+        LiteLlmCoreConfig(
+            base_url="https://test.com",
+            additional_body_options={"api_key": "test-key"},
+            default_headers={},
+        ),
+    )
+
+    with pytest.raises(
+        ValueError, match="Extractor config ID is required for PDF page cache key"
+    ):
+        extractor.pdf_page_cache_key(Path("test.pdf"), 0)
+
+
+async def test_extract_pdf_with_cache_storage(
+    mock_file_factory, mock_litellm_extractor_with_cache
+):
+    """Test that PDF extraction stores content in cache when cache is available."""
+    test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+    # Mock responses for each page (PDF has 2 pages)
+    mock_responses = []
+    for i in range(2):
+        mock_response = AsyncMock(spec=ModelResponse)
+        mock_choice = AsyncMock(spec=Choices)
+        mock_message = AsyncMock()
+        mock_message.content = f"Content from page {i + 1}"
+        mock_choice.message = mock_message
+        mock_response.choices = [mock_choice]
+        mock_responses.append(mock_response)
+
+    with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+        result = await mock_litellm_extractor_with_cache.extract(
+            ExtractionInput(
+                path=str(test_file),
+                mime_type="application/pdf",
+            )
+        )
+
+        # Verify that the completion was called for each page
+        assert mock_acompletion.call_count == 2
+
+        # Verify content is stored in cache
+        pdf_path = Path(test_file)
+        for i in range(2):
+            cached_content = (
+                await mock_litellm_extractor_with_cache.get_page_content_from_cache(
+                    pdf_path, i
+                )
+            )
+            assert cached_content == f"Content from page {i + 1}"
+
+        # Verify the output contains content from both pages
+        assert "Content from page 1" in result.content
+        assert "Content from page 2" in result.content
+
+
+async def test_extract_pdf_with_cache_retrieval(
+    mock_file_factory, mock_litellm_extractor_with_cache
+):
+    """Test that PDF extraction retrieves content from cache when available."""
+    test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+    pdf_path = Path(test_file)
+
+    # Pre-populate cache with content
+    for i in range(2):
+        cache_key = mock_litellm_extractor_with_cache.pdf_page_cache_key(pdf_path, i)
+        await mock_litellm_extractor_with_cache.filesystem_cache.set(
+            cache_key, f"Cached content from page {i + 1}".encode("utf-8")
+        )
+
+    # Mock responses (should not be called due to cache hits)
+    mock_responses = []
+    for i in range(2):
+        mock_response = AsyncMock(spec=ModelResponse)
+        mock_choice = AsyncMock(spec=Choices)
+        mock_message = AsyncMock()
+        mock_message.content = f"Fresh content from page {i + 1}"
+        mock_choice.message = mock_message
+        mock_response.choices = [mock_choice]
+        mock_responses.append(mock_response)
+
+    with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+        result = await mock_litellm_extractor_with_cache.extract(
+            ExtractionInput(
+                path=str(test_file),
+                mime_type="application/pdf",
+            )
+        )
+
+        # Verify that litellm.acompletion was NOT called (cache hits)
+        assert mock_acompletion.call_count == 0
+
+        # Verify the output contains cached content, not fresh content
+        assert "Cached content from page 1" in result.content
+        assert "Cached content from page 2" in result.content
+        assert "Fresh content from page 1" not in result.content
+        assert "Fresh content from page 2" not in result.content
+
+
+async def test_extract_pdf_without_cache(
+    mock_file_factory, mock_litellm_extractor_without_cache
+):
+    """Test that PDF extraction works normally when no cache is provided."""
+    test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+    # Mock responses for each page (PDF has 2 pages)
+    mock_responses = []
+    for i in range(2):
+        mock_response = AsyncMock(spec=ModelResponse)
+        mock_choice = AsyncMock(spec=Choices)
+        mock_message = AsyncMock()
+        mock_message.content = f"Content from page {i + 1}"
+        mock_choice.message = mock_message
+        mock_response.choices = [mock_choice]
+        mock_responses.append(mock_response)
+
+    with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+        result = await mock_litellm_extractor_without_cache.extract(
+            ExtractionInput(
+                path=str(test_file),
+                mime_type="application/pdf",
+            )
+        )
+
+        # Verify that the completion was called for each page
+        assert mock_acompletion.call_count == 2
+
+        # Verify the output contains content from both pages
+        assert "Content from page 1" in result.content
+        assert "Content from page 2" in result.content
+
+
+async def test_extract_pdf_mixed_cache_hits_and_misses(
+    mock_file_factory, mock_litellm_extractor_with_cache
+):
+    """Test PDF extraction with some pages cached and others not."""
+    test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+    pdf_path = Path(test_file)
+
+    # Pre-populate cache with only page 0 content
+    cache_key = mock_litellm_extractor_with_cache.pdf_page_cache_key(pdf_path, 0)
+    await mock_litellm_extractor_with_cache.filesystem_cache.set(
+        cache_key, "Cached content from page 1".encode("utf-8")
+    )
+
+    # Mock responses for page 1 only (page 0 should hit cache)
+    mock_response = AsyncMock(spec=ModelResponse)
+    mock_choice = AsyncMock(spec=Choices)
+    mock_message = AsyncMock()
+    mock_message.content = "Fresh content from page 2"
+    mock_choice.message = mock_message
+    mock_response.choices = [mock_choice]
+
+    with patch("litellm.acompletion", return_value=mock_response) as mock_acompletion:
+        result = await mock_litellm_extractor_with_cache.extract(
+            ExtractionInput(
+                path=str(test_file),
+                mime_type="application/pdf",
+            )
+        )
+
+        # Verify that litellm.acompletion was called only once (for page 1)
+        assert mock_acompletion.call_count == 1
+
+        # Verify the output contains both cached and fresh content
+        assert "Cached content from page 1" in result.content
+        assert "Fresh content from page 2" in result.content
+
+
+async def test_extract_pdf_cache_write_failure_does_not_throw(
+    mock_file_factory, mock_litellm_extractor_with_cache
+):
+    """Test that PDF extraction continues successfully even when cache write fails."""
+    test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+    # Mock responses for each page (PDF has 2 pages)
+    mock_responses = []
+    for i in range(2):
+        mock_response = AsyncMock(spec=ModelResponse)
+        mock_choice = AsyncMock(spec=Choices)
+        mock_message = AsyncMock()
+        mock_message.content = f"Content from page {i + 1}"
+        mock_choice.message = mock_message
+        mock_response.choices = [mock_choice]
+        mock_responses.append(mock_response)
+
+    # Mock the cache set method to raise an exception
+    with patch.object(
+        mock_litellm_extractor_with_cache.filesystem_cache,
+        "set",
+        side_effect=Exception("Cache write failed"),
+    ) as mock_cache_set:
+        with patch(
+            "litellm.acompletion", side_effect=mock_responses
+        ) as mock_acompletion:
+            # This should not raise an exception despite cache write failures
+            result = await mock_litellm_extractor_with_cache.extract(
+                ExtractionInput(
+                    path=str(test_file),
+                    mime_type="application/pdf",
+                )
+            )
+
+            # Verify that the completion was called for each page
+            assert mock_acompletion.call_count == 2
+
+            # Verify that cache.set was called for each page (and failed)
+            assert mock_cache_set.call_count == 2
+
+            # Verify the output contains content from both pages despite cache failures
+            assert "Content from page 1" in result.content
+            assert "Content from page 2" in result.content
+
+            # Verify the extraction completed successfully
+            assert not result.is_passthrough
+            assert result.content_format == OutputFormat.MARKDOWN
+
+
+async def test_extract_pdf_cache_decode_failure_does_not_throw(
+    mock_file_factory, mock_litellm_extractor_with_cache
+):
+    """Test that PDF extraction continues successfully even when cache decode fails."""
+    test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+    pdf_path = Path(test_file)
+
+    # Pre-populate cache with invalid UTF-8 bytes that will cause decode failure
+    for i in range(2):
+        cache_key = mock_litellm_extractor_with_cache.pdf_page_cache_key(pdf_path, i)
+        # Use bytes that are not valid UTF-8 (e.g., some binary data)
+        invalid_utf8_bytes = b"\xff\xfe\x00\x00"  # Invalid UTF-8 sequence
+        await mock_litellm_extractor_with_cache.filesystem_cache.set(
+            cache_key, invalid_utf8_bytes
+        )
+
+    # Mock responses for each page (PDF has 2 pages) - should be called due to decode failures
+    mock_responses = []
+    for i in range(2):
+        mock_response = AsyncMock(spec=ModelResponse)
+        mock_choice = AsyncMock(spec=Choices)
+        mock_message = AsyncMock()
+        mock_message.content = f"Content from page {i + 1}"
+        mock_choice.message = mock_message
+        mock_response.choices = [mock_choice]
+        mock_responses.append(mock_response)
+
+    with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+        # This should not raise an exception despite cache decode failures
+        result = await mock_litellm_extractor_with_cache.extract(
+            ExtractionInput(
+                path=str(test_file),
+                mime_type="application/pdf",
+            )
+        )
+
+        # Verify that the completion was called for each page (due to decode failures)
+        assert mock_acompletion.call_count == 2
+
+        # Verify the output contains content from both pages despite cache decode failures
+        assert "Content from page 1" in result.content
+        assert "Content from page 2" in result.content
+
+        # Verify the extraction completed successfully
+        assert not result.is_passthrough
+        assert result.content_format == OutputFormat.MARKDOWN
+
+
+async def test_extract_pdf_parallel_processing_error_handling(
+    mock_file_factory, mock_litellm_extractor_with_cache
+):
+    """Test that parallel processing handles errors correctly."""
+    test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+    # Mock first page to succeed, second to fail
+    mock_response1 = AsyncMock(spec=ModelResponse)
+    mock_choice1 = AsyncMock(spec=Choices)
+    mock_message1 = AsyncMock()
+    mock_message1.content = "Success from page 1"
+    mock_choice1.message = mock_message1
+    mock_response1.choices = [mock_choice1]
+
+    with patch(
+        "litellm.acompletion",
+        side_effect=[mock_response1, Exception("API Error on page 2")],
+    ) as mock_acompletion:
+        with pytest.raises(ValueError, match=r".*Page 1:.*API Error on page 2"):
+            await mock_litellm_extractor_with_cache.extract(
+                ExtractionInput(
+                    path=str(test_file),
+                    mime_type="application/pdf",
+                )
+            )
+
+        # Verify that both pages were attempted
+        assert mock_acompletion.call_count == 2
+
+
+async def test_extract_pdf_parallel_processing_all_cached(
+    mock_file_factory, mock_litellm_extractor_with_cache
+):
+    """Test parallel processing when all pages are cached."""
+    test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+    pdf_path = Path(test_file)
+
+    # Pre-populate cache with both pages
+    for i in range(2):
+        cache_key = mock_litellm_extractor_with_cache.pdf_page_cache_key(pdf_path, i)
+        await mock_litellm_extractor_with_cache.filesystem_cache.set(
+            cache_key, f"Cached content from page {i + 1}".encode("utf-8")
+        )
+
+    # Mock responses (should not be called due to cache hits)
+    mock_responses = []
+    for i in range(2):
+        mock_response = AsyncMock(spec=ModelResponse)
+        mock_choice = AsyncMock(spec=Choices)
+        mock_message = AsyncMock()
+        mock_message.content = f"Fresh content from page {i + 1}"
+        mock_choice.message = mock_message
+        mock_response.choices = [mock_choice]
+        mock_responses.append(mock_response)
+
+    with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+        result = await mock_litellm_extractor_with_cache.extract(
+            ExtractionInput(
+                path=str(test_file),
+                mime_type="application/pdf",
+            )
+        )
+
+        # Verify that no API calls were made (all pages cached)
+        assert mock_acompletion.call_count == 0
+
+        # Verify the output contains cached content
+        assert "Cached content from page 1" in result.content
+        assert "Cached content from page 2" in result.content
+        assert "Fresh content from page 1" not in result.content
+        assert "Fresh content from page 2" not in result.content