kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (133)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +494 -0
  38. kiln_ai/adapters/ml_model_list.py +876 -18
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
  41. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  42. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  43. kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
  44. kiln_ai/adapters/ollama_tools.py +69 -12
  45. kiln_ai/adapters/provider_tools.py +190 -46
  46. kiln_ai/adapters/rag/deduplication.py +49 -0
  47. kiln_ai/adapters/rag/progress.py +252 -0
  48. kiln_ai/adapters/rag/rag_runners.py +844 -0
  49. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  50. kiln_ai/adapters/rag/test_progress.py +785 -0
  51. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  52. kiln_ai/adapters/remote_config.py +80 -8
  53. kiln_ai/adapters/test_adapter_registry.py +579 -86
  54. kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
  55. kiln_ai/adapters/test_ml_model_list.py +202 -0
  56. kiln_ai/adapters/test_ollama_tools.py +340 -1
  57. kiln_ai/adapters/test_prompt_builders.py +1 -1
  58. kiln_ai/adapters/test_provider_tools.py +199 -8
  59. kiln_ai/adapters/test_remote_config.py +551 -56
  60. kiln_ai/adapters/vector_store/__init__.py +1 -0
  61. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  62. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  63. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  64. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  65. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  66. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  67. kiln_ai/datamodel/__init__.py +16 -13
  68. kiln_ai/datamodel/basemodel.py +201 -4
  69. kiln_ai/datamodel/chunk.py +158 -0
  70. kiln_ai/datamodel/datamodel_enums.py +27 -0
  71. kiln_ai/datamodel/embedding.py +64 -0
  72. kiln_ai/datamodel/external_tool_server.py +206 -54
  73. kiln_ai/datamodel/extraction.py +317 -0
  74. kiln_ai/datamodel/project.py +33 -1
  75. kiln_ai/datamodel/rag.py +79 -0
  76. kiln_ai/datamodel/task.py +5 -0
  77. kiln_ai/datamodel/task_output.py +41 -11
  78. kiln_ai/datamodel/test_attachment.py +649 -0
  79. kiln_ai/datamodel/test_basemodel.py +270 -14
  80. kiln_ai/datamodel/test_chunk_models.py +317 -0
  81. kiln_ai/datamodel/test_dataset_split.py +1 -1
  82. kiln_ai/datamodel/test_datasource.py +50 -0
  83. kiln_ai/datamodel/test_embedding_models.py +448 -0
  84. kiln_ai/datamodel/test_eval_model.py +6 -6
  85. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  86. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  87. kiln_ai/datamodel/test_extraction_model.py +501 -0
  88. kiln_ai/datamodel/test_rag.py +641 -0
  89. kiln_ai/datamodel/test_task.py +35 -1
  90. kiln_ai/datamodel/test_tool_id.py +187 -1
  91. kiln_ai/datamodel/test_vector_store.py +320 -0
  92. kiln_ai/datamodel/tool_id.py +58 -0
  93. kiln_ai/datamodel/vector_store.py +141 -0
  94. kiln_ai/tools/base_tool.py +12 -3
  95. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  96. kiln_ai/tools/kiln_task_tool.py +158 -0
  97. kiln_ai/tools/mcp_server_tool.py +2 -2
  98. kiln_ai/tools/mcp_session_manager.py +51 -22
  99. kiln_ai/tools/rag_tools.py +164 -0
  100. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  101. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  102. kiln_ai/tools/test_mcp_session_manager.py +187 -227
  103. kiln_ai/tools/test_rag_tools.py +929 -0
  104. kiln_ai/tools/test_tool_registry.py +290 -7
  105. kiln_ai/tools/tool_registry.py +69 -16
  106. kiln_ai/utils/__init__.py +3 -0
  107. kiln_ai/utils/async_job_runner.py +62 -17
  108. kiln_ai/utils/config.py +2 -2
  109. kiln_ai/utils/env.py +15 -0
  110. kiln_ai/utils/filesystem.py +14 -0
  111. kiln_ai/utils/filesystem_cache.py +60 -0
  112. kiln_ai/utils/litellm.py +94 -0
  113. kiln_ai/utils/lock.py +100 -0
  114. kiln_ai/utils/mime_type.py +38 -0
  115. kiln_ai/utils/open_ai_types.py +19 -2
  116. kiln_ai/utils/pdf_utils.py +59 -0
  117. kiln_ai/utils/test_async_job_runner.py +151 -35
  118. kiln_ai/utils/test_env.py +142 -0
  119. kiln_ai/utils/test_filesystem_cache.py +316 -0
  120. kiln_ai/utils/test_litellm.py +206 -0
  121. kiln_ai/utils/test_lock.py +185 -0
  122. kiln_ai/utils/test_mime_type.py +66 -0
  123. kiln_ai/utils/test_open_ai_types.py +88 -12
  124. kiln_ai/utils/test_pdf_utils.py +86 -0
  125. kiln_ai/utils/test_uuid.py +111 -0
  126. kiln_ai/utils/test_validation.py +524 -0
  127. kiln_ai/utils/uuid.py +9 -0
  128. kiln_ai/utils/validation.py +90 -0
  129. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
  130. kiln_ai-0.22.0.dist-info/RECORD +213 -0
  131. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  132. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  133. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/extractors/test_litellm_extractor.py
@@ -0,0 +1,1290 @@
+ from pathlib import Path
+ from unittest.mock import AsyncMock, MagicMock, patch
+
+ import pytest
+ from litellm.types.utils import Choices, ModelResponse
+
+ from conftest import MockFileFactoryMimeType
+ from kiln_ai.adapters.extractors.base_extractor import ExtractionInput, OutputFormat
+ from kiln_ai.adapters.extractors.encoding import to_base64_url
+ from kiln_ai.adapters.extractors.extractor_registry import extractor_adapter_from_type
+ from kiln_ai.adapters.extractors.litellm_extractor import (
+     ExtractorConfig,
+     Kind,
+     LitellmExtractor,
+     encode_file_litellm_format,
+ )
+ from kiln_ai.adapters.ml_model_list import (
+     built_in_models,
+     built_in_models_from_provider,
+ )
+ from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
+ from kiln_ai.datamodel.extraction import ExtractorType
+ from kiln_ai.utils.filesystem_cache import FilesystemCache
+
+ PROMPTS_FOR_KIND: dict[str, str] = {
+     "document": "prompt for documents",
+     "image": "prompt for images",
+     "video": "prompt for videos",
+     "audio": "prompt for audio",
+ }
+
+
+ @pytest.fixture
+ def mock_litellm_extractor():
+     return LitellmExtractor(
+         ExtractorConfig(
+             name="mock",
+             extractor_type=ExtractorType.LITELLM,
+             model_name="gpt_4o",
+             model_provider_name="openai",
+             properties={
+                 "prompt_document": PROMPTS_FOR_KIND["document"],
+                 "prompt_image": PROMPTS_FOR_KIND["image"],
+                 "prompt_video": PROMPTS_FOR_KIND["video"],
+                 "prompt_audio": PROMPTS_FOR_KIND["audio"],
+             },
+         ),
+         litellm_core_config=LiteLlmCoreConfig(
+             base_url="https://test.com",
+             additional_body_options={"api_key": "test-key"},
+             default_headers={},
+         ),
+     )
+
+
+ @pytest.fixture
+ def mock_litellm_core_config():
+     return LiteLlmCoreConfig(
+         base_url="https://test.com",
+         additional_body_options={"api_key": "test-key"},
+         default_headers={},
+     )
+
+
+ @pytest.mark.parametrize(
+     "mime_type, kind",
+     [
+         # documents
+         ("application/pdf", Kind.DOCUMENT),
+         ("text/markdown", Kind.DOCUMENT),
+         ("text/md", Kind.DOCUMENT),
+         ("text/plain", Kind.DOCUMENT),
+         ("text/html", Kind.DOCUMENT),
+         ("text/csv", Kind.DOCUMENT),
+         # images
+         ("image/png", Kind.IMAGE),
+         ("image/jpeg", Kind.IMAGE),
+         ("image/jpg", Kind.IMAGE),
+         # videos
+         ("video/mp4", Kind.VIDEO),
+         ("video/mov", Kind.VIDEO),
+         ("video/quicktime", Kind.VIDEO),
+         # audio
+         ("audio/mpeg", Kind.AUDIO),
+         ("audio/ogg", Kind.AUDIO),
+         ("audio/wav", Kind.AUDIO),
+     ],
+ )
+ def test_get_kind_from_mime_type(mock_litellm_extractor, mime_type, kind):
+     """Test that the kind is correctly inferred from the mime type."""
+     assert mock_litellm_extractor._get_kind_from_mime_type(mime_type) == kind
+
+
+ def test_get_kind_from_mime_type_unsupported(mock_litellm_extractor):
+     assert (
+         mock_litellm_extractor._get_kind_from_mime_type("unsupported/mimetype") is None
+     )
+
+
+ @pytest.mark.parametrize(
+     "mime_type, expected_content",
+     [
+         (MockFileFactoryMimeType.TXT, "Content from text file"),
+         (MockFileFactoryMimeType.MD, "Content from markdown file"),
+         (MockFileFactoryMimeType.HTML, "Content from html file"),
+         (MockFileFactoryMimeType.CSV, "Content from csv file"),
+         (MockFileFactoryMimeType.PNG, "Content from image file"),
+         (MockFileFactoryMimeType.JPG, "Content from image file"),
+         (MockFileFactoryMimeType.MP4, "Content from video file"),
+         (MockFileFactoryMimeType.MP3, "Content from audio file"),
+     ],
+ )
+ async def test_extract_success(
+     mock_file_factory, mock_litellm_extractor, mime_type, expected_content
+ ):
+     """Test successful extraction for non-PDF file types."""
+     # Create a mock file of the specified type
+     test_file = mock_file_factory(mime_type)
+
+     # Mock response for single file extraction
+     mock_response = AsyncMock(spec=ModelResponse)
+     mock_choice = AsyncMock(spec=Choices)
+     mock_message = AsyncMock()
+     mock_message.content = expected_content
+     mock_choice.message = mock_message
+     mock_response.choices = [mock_choice]
+
+     with patch("litellm.acompletion", return_value=mock_response) as mock_acompletion:
+         result = await mock_litellm_extractor.extract(
+             ExtractionInput(
+                 path=str(test_file),
+                 mime_type=mime_type.value,
+             )
+         )
+
+         # Verify that the completion was called once (single file)
+         assert mock_acompletion.call_count == 1
+
+         # Verify the output contains the expected content
+         assert expected_content in result.content
+
+         assert not result.is_passthrough
+         assert result.content_format == OutputFormat.MARKDOWN
+
+
+ def test_build_completion_kwargs_with_all_options(mock_file_factory):
+     """Test that _build_completion_kwargs properly includes all litellm_core_config options."""
+     litellm_core_config = LiteLlmCoreConfig(
+         base_url="https://custom-api.example.com",
+         additional_body_options={"custom_param": "value", "timeout": "30"},
+         default_headers={"Authorization": "Bearer custom-token"},
+     )
+
+     extractor = LitellmExtractor(
+         ExtractorConfig(
+             name="test",
+             extractor_type=ExtractorType.LITELLM,
+             model_name="gpt_4o",
+             model_provider_name="openai",
+             properties={
+                 "prompt_document": "prompt for documents",
+                 "prompt_image": "prompt for images",
+                 "prompt_video": "prompt for videos",
+                 "prompt_audio": "prompt for audio",
+             },
+         ),
+         litellm_core_config=litellm_core_config,
+     )
+
+     extraction_input = ExtractionInput(
+         path=str(mock_file_factory(MockFileFactoryMimeType.PDF)),
+         mime_type="application/pdf",
+     )
+
+     completion_kwargs = extractor._build_completion_kwargs(
+         "test prompt", extraction_input
+     )
+
+     # Verify all completion kwargs are included
+     assert completion_kwargs["base_url"] == "https://custom-api.example.com"
+     assert completion_kwargs["custom_param"] == "value"
+     assert completion_kwargs["timeout"] == "30"
+     assert completion_kwargs["default_headers"] == {
+         "Authorization": "Bearer custom-token"
+     }
+
+     # Verify basic structure is maintained
+     assert "model" in completion_kwargs
+     assert "messages" in completion_kwargs
+
+
+ def test_build_completion_kwargs_with_partial_options(mock_file_factory):
+     """Test that _build_completion_kwargs works when only some options are set."""
+     litellm_core_config = LiteLlmCoreConfig(
+         base_url=None,
+         additional_body_options={"timeout": "30"},
+         default_headers=None,
+     )
+
+     extractor = LitellmExtractor(
+         ExtractorConfig(
+             name="test",
+             extractor_type=ExtractorType.LITELLM,
+             model_name="gpt_4o",
+             model_provider_name="openai",
+             properties={
+                 "prompt_document": "prompt for documents",
+                 "prompt_image": "prompt for images",
+                 "prompt_video": "prompt for videos",
+                 "prompt_audio": "prompt for audio",
+             },
+         ),
+         litellm_core_config=litellm_core_config,
+     )
+
+     extraction_input = ExtractionInput(
+         path=str(mock_file_factory(MockFileFactoryMimeType.PDF)),
+         mime_type="application/pdf",
+     )
+
+     completion_kwargs = extractor._build_completion_kwargs(
+         "test prompt", extraction_input
+     )
+
+     # Verify only the set options are included
+     assert completion_kwargs["timeout"] == "30"
+     assert "base_url" not in completion_kwargs
+     assert "default_headers" not in completion_kwargs
+
+     # Verify basic structure is maintained
+     assert "model" in completion_kwargs
+     assert "messages" in completion_kwargs
+
+
+ def test_build_completion_kwargs_with_empty_options(mock_file_factory):
+     """Test that _build_completion_kwargs works when all options are None/empty."""
+     litellm_core_config = LiteLlmCoreConfig(
+         base_url=None,
+         additional_body_options=None,
+         default_headers=None,
+     )
+
+     extractor = LitellmExtractor(
+         ExtractorConfig(
+             name="test",
+             extractor_type=ExtractorType.LITELLM,
+             model_name="gpt_4o",
+             model_provider_name="openai",
+             properties={
+                 "prompt_document": "prompt for documents",
+                 "prompt_image": "prompt for images",
+                 "prompt_video": "prompt for videos",
+                 "prompt_audio": "prompt for audio",
+             },
+         ),
+         litellm_core_config=litellm_core_config,
+     )
+
+     extraction_input = ExtractionInput(
+         path=str(mock_file_factory(MockFileFactoryMimeType.PDF)),
+         mime_type="application/pdf",
+     )
+
+     completion_kwargs = extractor._build_completion_kwargs(
+         "test prompt", extraction_input
+     )
+
+     # Verify no completion kwargs are included
+     assert "base_url" not in completion_kwargs
+     assert "default_headers" not in completion_kwargs
+
+     # Verify basic structure is maintained
+     assert "model" in completion_kwargs
+     assert "messages" in completion_kwargs
+
+
+ def test_build_completion_kwargs_messages_structure(mock_file_factory):
+     """Test that the messages structure in completion_kwargs is correct."""
+     litellm_core_config = LiteLlmCoreConfig()
+
+     extractor = LitellmExtractor(
+         ExtractorConfig(
+             name="test",
+             extractor_type=ExtractorType.LITELLM,
+             model_name="gpt_4o",
+             model_provider_name="openai",
+             properties={
+                 "prompt_document": "prompt for documents",
+                 "prompt_image": "prompt for images",
+                 "prompt_video": "prompt for videos",
+                 "prompt_audio": "prompt for audio",
+             },
+         ),
+         litellm_core_config=litellm_core_config,
+     )
+
+     extraction_input = ExtractionInput(
+         path=str(mock_file_factory(MockFileFactoryMimeType.PDF)),
+         mime_type="application/pdf",
+     )
+
+     completion_kwargs = extractor._build_completion_kwargs(
+         "test prompt", extraction_input
+     )
+
+     # Verify messages structure
+     messages = completion_kwargs["messages"]
+     assert len(messages) == 1
+     assert messages[0]["role"] == "user"
+
+     content = messages[0]["content"]
+     assert len(content) == 2
+
+     # First content item should be text
+     assert content[0]["type"] == "text"
+     assert content[0]["text"] == "test prompt"
+
+     # Second content item should be file
+     assert content[1]["type"] == "file"
+     assert "file" in content[1]
+     assert "file_data" in content[1]["file"]
+
+
+ async def test_extract_failure_from_litellm(mock_file_factory, mock_litellm_extractor):
+     test_pdf_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+     with (
+         patch("pathlib.Path.read_bytes", return_value=b"test content"),
+         patch("litellm.acompletion", side_effect=Exception("error from litellm")),
+         patch(
+             "kiln_ai.adapters.extractors.litellm_extractor.LitellmExtractor.litellm_model_slug",
+             return_value="provider-name/model-name",
+         ),
+     ):
+         # Mock litellm to raise an exception
+         with pytest.raises(Exception, match="error from litellm"):
+             await mock_litellm_extractor.extract(
+                 extraction_input=ExtractionInput(
+                     path=str(test_pdf_file),
+                     mime_type="application/pdf",
+                 )
+             )
+
+
+ async def test_extract_failure_from_bytes_read(mock_litellm_extractor):
+     with (
+         patch(
+             "mimetypes.guess_type",
+             return_value=("application/pdf", None),
+         ),
+         patch(
+             "pathlib.Path.read_bytes",
+             side_effect=Exception("error from read_bytes"),
+         ),
+         patch(
+             "kiln_ai.adapters.extractors.litellm_extractor.LitellmExtractor.litellm_model_slug",
+             return_value="provider-name/model-name",
+         ),
+         patch(
+             "kiln_ai.adapters.extractors.litellm_extractor.split_pdf_into_pages",
+             side_effect=Exception("error from split_pdf_into_pages"),
+         ),
+     ):
+         # test the extract method
+         with pytest.raises(
+             ValueError,
+             match=r"Error extracting test.pdf: error from split_pdf_into_pages",
+         ):
+             await mock_litellm_extractor.extract(
+                 extraction_input=ExtractionInput(
+                     path="test.pdf",
+                     mime_type="application/pdf",
+                 )
+             )
+
+
+ async def test_extract_failure_unsupported_mime_type(mock_litellm_extractor):
+     # spy on the get mime type
+     with patch(
+         "mimetypes.guess_type",
+         return_value=(None, None),
+     ):
+         with pytest.raises(ValueError, match="Unsupported MIME type"):
+             await mock_litellm_extractor.extract(
+                 extraction_input=ExtractionInput(
+                     path="test.unsupported",
+                     mime_type="unsupported/mimetype",
+                 )
+             )
+
+
+ def test_litellm_model_slug_success(mock_litellm_extractor):
+     """Test that litellm_model_slug returns the correct model slug."""
+     # Mock the built_in_models_from_provider function to return a valid model provider
+     mock_model_provider = AsyncMock()
+     mock_model_provider.name = "test-provider"
+
+     # Mock the get_litellm_provider_info function to return provider info with model ID
+     mock_provider_info = AsyncMock()
+     mock_provider_info.litellm_model_id = "test-provider/test-model"
+
+     with (
+         patch(
+             "kiln_ai.adapters.extractors.litellm_extractor.built_in_models_from_provider",
+             return_value=mock_model_provider,
+         ) as mock_built_in_models,
+         patch(
+             "kiln_ai.adapters.extractors.litellm_extractor.get_litellm_provider_info",
+             return_value=mock_provider_info,
+         ) as mock_get_provider_info,
+     ):
+         result = mock_litellm_extractor.litellm_model_slug
+
+         assert result == "test-provider/test-model"
+
+         # Verify the functions were called with correct arguments
+         mock_built_in_models.assert_called_once()
+         mock_get_provider_info.assert_called_once_with(mock_model_provider)
+
+
+ @pytest.mark.parametrize(
+     "max_parallel_requests, expected_result",
+     [
+         (10, 10),
+         (0, 0),
+         # 5 is the current default, it may change in the future if we have
+         # a better modeling of rate limit constraints
+         (None, 5),
+     ],
+ )
+ def test_litellm_model_max_parallel_requests(
+     mock_litellm_extractor, max_parallel_requests, expected_result
+ ):
+     """Test that max_parallel_requests_for_model returns the provider's limit."""
+     # Mock the built_in_models_from_provider function to return a valid model provider
+     mock_model_provider = MagicMock()
+     mock_model_provider.name = "test-provider"
+     mock_model_provider.max_parallel_requests = max_parallel_requests
+
+     with (
+         patch(
+             "kiln_ai.adapters.extractors.litellm_extractor.built_in_models_from_provider",
+             return_value=mock_model_provider,
+         ) as mock_built_in_models,
+     ):
+         result = mock_litellm_extractor.max_parallel_requests_for_model
+
+         assert result == expected_result
+
+         mock_built_in_models.assert_called_once()
+
+
+ def test_litellm_model_slug_model_provider_not_found(mock_litellm_extractor):
+     """Test that litellm_model_slug raises ValueError when model provider is not found."""
+     with patch(
+         "kiln_ai.adapters.extractors.litellm_extractor.built_in_models_from_provider",
+         return_value=None,
+     ):
+         with pytest.raises(
+             ValueError,
+             match="Model provider openai not found in the list of built-in models",
+         ):
+             mock_litellm_extractor.litellm_model_slug
+
+
+ def test_litellm_model_slug_with_different_provider_names(mock_litellm_core_config):
+     """Test litellm_model_slug with different provider and model combinations."""
+     test_cases = [
+         ("anthropic", "claude-3-sonnet", "anthropic/claude-3-sonnet"),
+         ("openai", "gpt-4", "openai/gpt-4"),
+         ("gemini_api", "gemini-pro", "gemini_api/gemini-pro"),
+     ]
+
+     for provider_name, model_name, expected_slug in test_cases:
+         extractor = LitellmExtractor(
+             ExtractorConfig(
+                 name="test",
+                 extractor_type=ExtractorType.LITELLM,
+                 model_name=model_name,
+                 model_provider_name=provider_name,
+                 properties={
+                     "prompt_document": "test prompt",
+                     "prompt_image": "test prompt",
+                     "prompt_video": "test prompt",
+                     "prompt_audio": "test prompt",
+                 },
+             ),
+             litellm_core_config=mock_litellm_core_config,
+         )
+
+         mock_model_provider = AsyncMock()
+         mock_model_provider.name = provider_name
+
+         mock_provider_info = AsyncMock()
+         mock_provider_info.litellm_model_id = expected_slug
+
+         with (
+             patch(
+                 "kiln_ai.adapters.extractors.litellm_extractor.built_in_models_from_provider",
+                 return_value=mock_model_provider,
+             ),
+             patch(
+                 "kiln_ai.adapters.extractors.litellm_extractor.get_litellm_provider_info",
+                 return_value=mock_provider_info,
+             ),
+         ):
+             result = extractor.litellm_model_slug
+             assert result == expected_slug
+
+
+ def paid_litellm_extractor(model_name: str, provider_name: str):
+     extractor = extractor_adapter_from_type(
+         ExtractorType.LITELLM,
+         ExtractorConfig(
+             name="paid-litellm",
+             extractor_type=ExtractorType.LITELLM,
+             model_provider_name=provider_name,
+             model_name=model_name,
+             properties={
+                 "prompt_document": "Ignore the file and only respond with the word 'document'",
+                 "prompt_image": "Ignore the file and only respond with the word 'image'",
+                 "prompt_video": "Ignore the file and only respond with the word 'video'",
+                 "prompt_audio": "Ignore the file and only respond with the word 'audio'",
+             },
+             passthrough_mimetypes=[OutputFormat.MARKDOWN, OutputFormat.TEXT],
+         ),
+     )
+     return extractor
+
+
+ @pytest.mark.parametrize(
+     "mime_type, expected_encoding",
+     [
+         # documents
+         (MockFileFactoryMimeType.PDF, "generic_file_data"),
+         (MockFileFactoryMimeType.TXT, "generic_file_data"),
+         (MockFileFactoryMimeType.MD, "generic_file_data"),
+         (MockFileFactoryMimeType.HTML, "generic_file_data"),
+         (MockFileFactoryMimeType.CSV, "generic_file_data"),
+         # images
+         (MockFileFactoryMimeType.PNG, "image_data"),
+         (MockFileFactoryMimeType.JPEG, "image_data"),
+         (MockFileFactoryMimeType.JPG, "image_data"),
+         # videos
+         (MockFileFactoryMimeType.MP4, "generic_file_data"),
+         (MockFileFactoryMimeType.MOV, "generic_file_data"),
+         # audio
+         (MockFileFactoryMimeType.MP3, "generic_file_data"),
+         (MockFileFactoryMimeType.OGG, "generic_file_data"),
+         (MockFileFactoryMimeType.WAV, "generic_file_data"),
+     ],
+ )
+ def test_encode_file_litellm_format(mock_file_factory, mime_type, expected_encoding):
+     test_file = mock_file_factory(mime_type)
+     encoded = encode_file_litellm_format(Path(test_file), mime_type)
+
+     # there are two types of ways of including files, image_url is a special case
+     # and it also works with the generic file_data encoding, but LiteLLM docs are
+     # not clear on this, so best to go with the more specific image_url encoding
+     if expected_encoding == "image_data":
+         assert encoded == {
+             "type": "image_url",
+             "image_url": {
+                 "url": to_base64_url(mime_type, Path(test_file).read_bytes()),
+             },
+         }
+     elif expected_encoding == "generic_file_data":
+         assert encoded == {
+             "type": "file",
+             "file": {
+                 "file_data": to_base64_url(mime_type, Path(test_file).read_bytes()),
+             },
+         }
+     else:
+         raise ValueError(f"Unsupported encoding: {expected_encoding}")
+
+
+ def get_all_models_support_doc_extraction(
+     must_support_mime_types: list[str] | None = None,
+ ):
+     model_provider_pairs = []
+     for model in built_in_models:
+         for provider in model.providers:
+             if not provider.model_id:
+                 # it's possible for models to not have an ID (fine-tune only model)
+                 continue
+             if provider.supports_doc_extraction:
+                 if (
+                     provider.multimodal_mime_types is None
+                     or must_support_mime_types is None
+                 ):
+                     model_provider_pairs.append((model.name, provider.name))
+                     continue
+                 # check that the model supports all the mime types
+                 if all(
+                     mime_type in provider.multimodal_mime_types
+                     for mime_type in must_support_mime_types
+                 ):
+                     model_provider_pairs.append((model.name, provider.name))
+     return model_provider_pairs
+
+
+ @pytest.mark.paid
+ @pytest.mark.parametrize(
+     "model_name,provider_name",
+     get_all_models_support_doc_extraction(must_support_mime_types=None),
+ )
+ @pytest.mark.parametrize(
+     "mime_type,expected_substring_in_output",
+     [
+         # documents
+         (MockFileFactoryMimeType.PDF, "document"),
+         (MockFileFactoryMimeType.TXT, "document"),
+         (MockFileFactoryMimeType.MD, "document"),
+         (MockFileFactoryMimeType.HTML, "document"),
+         (MockFileFactoryMimeType.CSV, "document"),
+         # images
+         (MockFileFactoryMimeType.PNG, "image"),
+         (MockFileFactoryMimeType.JPEG, "image"),
+         (MockFileFactoryMimeType.JPG, "image"),
+         # videos
+         (MockFileFactoryMimeType.MP4, "video"),
+         (MockFileFactoryMimeType.MOV, "video"),
+         # audio
+         (MockFileFactoryMimeType.MP3, "audio"),
+         (MockFileFactoryMimeType.OGG, "audio"),
+         (MockFileFactoryMimeType.WAV, "audio"),
+     ],
+ )
+ async def test_extract_document_success(
+     model_name,
+     provider_name,
+     mime_type,
+     expected_substring_in_output,
+     mock_file_factory,
+ ):
+     # get model
+     model = built_in_models_from_provider(provider_name, model_name)
+     assert model is not None
+     if mime_type not in model.multimodal_mime_types:
+         pytest.skip(f"Model {model_name} configured to not support {mime_type}")
+     if (
+         mime_type == MockFileFactoryMimeType.MD
+         or mime_type == MockFileFactoryMimeType.TXT
+     ):
+         pytest.skip(f"Model {model_name} configured to passthrough {mime_type}")
+
+     test_file = mock_file_factory(mime_type)
+     extractor = paid_litellm_extractor(
+         model_name=model_name, provider_name=provider_name
+     )
+     output = await extractor.extract(
+         extraction_input=ExtractionInput(
+             path=str(test_file),
+             mime_type=mime_type,
+         )
+     )
+     assert not output.is_passthrough
+     assert output.content_format == OutputFormat.MARKDOWN
+     assert expected_substring_in_output.lower() in output.content.lower()
+
+
+ async def test_extract_pdf_page_by_page(mock_file_factory, mock_litellm_extractor):
+     """Test that PDFs are processed page by page with page numbers in output."""
+     test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+     # Mock responses for each page (PDF has 2 pages)
+     mock_responses = []
+     for i in range(2):  # PDF has 2 pages
+         mock_response = AsyncMock(spec=ModelResponse)
+         mock_choice = AsyncMock(spec=Choices)
+         mock_message = AsyncMock()
+         mock_message.content = f"Content from page {i + 1}"
+         mock_choice.message = mock_message
+         mock_response.choices = [mock_choice]
+         mock_responses.append(mock_response)
+
+     with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+         result = await mock_litellm_extractor.extract(
+             ExtractionInput(
+                 path=str(test_file),
+                 mime_type="application/pdf",
+             )
+         )
+
+         # Verify that the completion was called multiple times (once per page)
+         assert mock_acompletion.call_count == 2
+
+         # Verify the output contains content from both pages
+         assert "Content from page 1" in result.content
+         assert "Content from page 2" in result.content
+
+         assert not result.is_passthrough
+         assert result.content_format == OutputFormat.MARKDOWN
+
+
+ async def test_extract_pdf_page_by_page_pdf_as_image(
+     mock_file_factory, mock_litellm_extractor, tmp_path
+ ):
+     """Test that PDFs are processed page by page as images if the model requires it."""
+
+     test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+     # Mock responses for each page (PDF has 2 pages)
+     mock_responses = []
+     for i in range(2):  # PDF has 2 pages
+         mock_response = AsyncMock(spec=ModelResponse)
+         mock_choice = AsyncMock(spec=Choices)
+         mock_message = AsyncMock()
+         mock_message.content = f"Content from page {i + 1}"
+         mock_choice.message = mock_message
+         mock_response.choices = [mock_choice]
+         mock_responses.append(mock_response)
+
+     mock_image_path = tmp_path / "img-test_document-mock.png"
+     mock_image_path.write_bytes(b"test image")
+
+     with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+         # this model requires PDFs to be processed as images
+         mock_litellm_extractor.model_provider.multimodal_requires_pdf_as_image = True
+
+         with patch(
+             "kiln_ai.adapters.extractors.litellm_extractor.convert_pdf_to_images",
+             return_value=[mock_image_path],
+         ) as mock_convert:
+             result = await mock_litellm_extractor.extract(
+                 ExtractionInput(
+                     path=str(test_file),
+                     mime_type="application/pdf",
+                 )
+             )
+
+         # Verify image conversion called once per page
+         assert mock_convert.call_count == 2
+
+         # Verify LiteLLM was called with image inputs (not PDF) for each page
+         for call in mock_acompletion.call_args_list:
+             kwargs = call.kwargs
+             content = kwargs["messages"][0]["content"]
+             assert content[1]["type"] == "image_url"
+
+         # Verify that the completion was called multiple times (once per page)
+         assert mock_acompletion.call_count == 2
+
+         # Verify the output contains content from both pages
+         assert "Content from page 1" in result.content
+         assert "Content from page 2" in result.content
+
+         assert not result.is_passthrough
+         assert result.content_format == OutputFormat.MARKDOWN
+
+
+ async def test_convert_pdf_page_to_image_input_success(
+     mock_litellm_extractor, tmp_path
+ ):
+     page_dir = tmp_path / "pages"
+     page_dir.mkdir()
+     page_path = page_dir / "page_1.pdf"
+     page_path.write_bytes(b"%PDF-1.4 test")
+
+     mock_image_path = page_dir / "img-page_1.pdf-0.png"
+     mock_image_path.write_bytes(b"image-bytes")
+
+     with patch(
+         "kiln_ai.adapters.extractors.litellm_extractor.convert_pdf_to_images",
+         return_value=[mock_image_path],
+     ):
+         extraction_input = await mock_litellm_extractor.convert_pdf_page_to_image_input(
+             page_path, 0
+         )
+
+     assert extraction_input.mime_type == "image/png"
+     assert Path(extraction_input.path) == mock_image_path
+
+
+ @pytest.mark.parametrize("returned_count", [0, 2])
+ async def test_convert_pdf_page_to_image_input_error_on_invalid_count(
+     mock_litellm_extractor, tmp_path, returned_count
+ ):
+     page_dir = tmp_path / "pages"
+     page_dir.mkdir()
+     page_path = page_dir / "page_1.pdf"
+     page_path.write_bytes(b"%PDF-1.4 test")
+
+     image_paths = []
+     if returned_count == 2:
+         img1 = page_dir / "img-page_1.pdf-0.png"
+         img2 = page_dir / "img-page_1.pdf-1.png"
+         img1.write_bytes(b"i1")
+         img2.write_bytes(b"i2")
+         image_paths = [img1, img2]
+
+     with patch(
+         "kiln_ai.adapters.extractors.litellm_extractor.convert_pdf_to_images",
+         return_value=image_paths,
+     ):
+         with pytest.raises(ValueError, match=r"Expected 1 image, got "):
+             await mock_litellm_extractor.convert_pdf_page_to_image_input(page_path, 0)
+
+
+ async def test_extract_pdf_page_by_page_error_handling(
+     mock_file_factory, mock_litellm_extractor
+ ):
+     """Test that PDF page processing handles errors gracefully."""
+     test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+     # Mock the first page to succeed, second to fail
+     mock_response1 = AsyncMock(spec=ModelResponse)
+     mock_choice1 = AsyncMock(spec=Choices)
+     mock_message1 = AsyncMock()
+     mock_message1.content = "Content from page 1"
+     mock_choice1.message = mock_message1
+     mock_response1.choices = [mock_choice1]
+
+     with patch(
+         "litellm.acompletion", side_effect=[mock_response1, Exception("API Error")]
+     ) as mock_acompletion:
+         with pytest.raises(Exception, match="API Error"):
+             await mock_litellm_extractor.extract(
+                 ExtractionInput(
+                     path=str(test_file),
+                     mime_type="application/pdf",
+                 )
+             )
+
+         # Verify that the completion was called at least once before failing
+         assert mock_acompletion.call_count >= 1
+
+
+ @pytest.mark.paid
+ @pytest.mark.parametrize(
+     "model_name,provider_name", get_all_models_support_doc_extraction()
+ )
+ async def test_provider_bad_request(tmp_path, model_name, provider_name):
+     # write corrupted PDF file to temp files
+     temp_file = tmp_path / "corrupted_file.pdf"
+     temp_file.write_bytes(b"invalid file")
+
+     extractor = paid_litellm_extractor(
+         model_name=model_name, provider_name=provider_name
+     )
+
+     with pytest.raises(ValueError, match=r"Error extracting .*corrupted_file.pdf: "):
+         await extractor.extract(
+             extraction_input=ExtractionInput(
+                 path=temp_file.as_posix(),
+                 mime_type="application/pdf",
+             )
+         )
+
+
+ # Cache-related tests for PDF processing
+ @pytest.fixture
+ def mock_litellm_extractor_with_cache(tmp_path):
+     """Create a LitellmExtractor with a filesystem cache for testing."""
+     cache_dir = tmp_path / "cache"
+     cache_dir.mkdir()  # Ensure cache directory exists
+     cache = FilesystemCache(cache_dir)
+     return LitellmExtractor(
+         ExtractorConfig(
+             id="test_extractor_123",  # Required for cache key generation
+             name="mock_with_cache",
+             extractor_type=ExtractorType.LITELLM,
+             model_name="gpt_4o",
+             model_provider_name="openai",
+             properties={
+                 "prompt_document": PROMPTS_FOR_KIND["document"],
+                 "prompt_image": PROMPTS_FOR_KIND["image"],
+                 "prompt_video": PROMPTS_FOR_KIND["video"],
+                 "prompt_audio": PROMPTS_FOR_KIND["audio"],
+             },
+         ),
+         litellm_core_config=LiteLlmCoreConfig(
+             base_url="https://test.com",
+             additional_body_options={"api_key": "test-key"},
+             default_headers={},
+         ),
+         filesystem_cache=cache,
+     )
+
+
+ @pytest.fixture
+ def mock_litellm_extractor_without_cache():
+     """Create a LitellmExtractor without a filesystem cache for testing."""
+     return LitellmExtractor(
+         ExtractorConfig(
+             id="test_extractor_456",  # Required for cache key generation
+             name="mock_without_cache",
+             extractor_type=ExtractorType.LITELLM,
+             model_name="gpt_4o",
+             model_provider_name="openai",
+             properties={
+                 "prompt_document": PROMPTS_FOR_KIND["document"],
+                 "prompt_image": PROMPTS_FOR_KIND["image"],
+                 "prompt_video": PROMPTS_FOR_KIND["video"],
+                 "prompt_audio": PROMPTS_FOR_KIND["audio"],
+             },
+         ),
+         litellm_core_config=LiteLlmCoreConfig(
+             base_url="https://test.com",
+             additional_body_options={"api_key": "test-key"},
+             default_headers={},
+         ),
+         filesystem_cache=None,  # Explicitly no cache
+     )
+
+
+ def test_pdf_page_cache_key_generation(mock_litellm_extractor_with_cache):
+     """Test that PDF page cache keys are generated correctly."""
+     pdf_path = Path("test_document.pdf")
+     page_number = 0
+
+     cache_key = mock_litellm_extractor_with_cache.pdf_page_cache_key(
+         pdf_path, page_number
+     )
+
+     # Should include extractor ID and a hash of the PDF name and page number
+     assert cache_key.startswith("test_extractor_123_")
+     assert len(cache_key) > len("test_extractor_123_")  # Should have hash suffix
+
+     # Same PDF and page should generate same key
+     cache_key2 = mock_litellm_extractor_with_cache.pdf_page_cache_key(
+         pdf_path, page_number
+     )
+     assert cache_key == cache_key2
+
+     # Different page should generate different key
+     cache_key3 = mock_litellm_extractor_with_cache.pdf_page_cache_key(pdf_path, 1)
+     assert cache_key != cache_key3
+
+
+ def test_pdf_page_cache_key_requires_extractor_id():
+     """Test that PDF page cache key generation requires extractor ID."""
+     extractor_config = ExtractorConfig(
+         id=None,  # No ID
+         name="mock",
+         extractor_type=ExtractorType.LITELLM,
+         model_name="gpt_4o",
+         model_provider_name="openai",
+         properties={
+             "prompt_document": PROMPTS_FOR_KIND["document"],
+             "prompt_image": PROMPTS_FOR_KIND["image"],
+             "prompt_video": PROMPTS_FOR_KIND["video"],
+             "prompt_audio": PROMPTS_FOR_KIND["audio"],
+         },
+     )
+
+     extractor = LitellmExtractor(
+         extractor_config,
+         LiteLlmCoreConfig(
+             base_url="https://test.com",
+             additional_body_options={"api_key": "test-key"},
+             default_headers={},
+         ),
+     )
+
+     with pytest.raises(
+         ValueError, match="Extractor config ID is required for PDF page cache key"
+     ):
+         extractor.pdf_page_cache_key(Path("test.pdf"), 0)
+
+
+ async def test_extract_pdf_with_cache_storage(
+     mock_file_factory, mock_litellm_extractor_with_cache
+ ):
+     """Test that PDF extraction stores content in cache when cache is available."""
+     test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+     # Mock responses for each page (PDF has 2 pages)
+     mock_responses = []
+     for i in range(2):
+         mock_response = AsyncMock(spec=ModelResponse)
+         mock_choice = AsyncMock(spec=Choices)
+         mock_message = AsyncMock()
+         mock_message.content = f"Content from page {i + 1}"
+         mock_choice.message = mock_message
+         mock_response.choices = [mock_choice]
+         mock_responses.append(mock_response)
+
+     with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+         result = await mock_litellm_extractor_with_cache.extract(
+             ExtractionInput(
+                 path=str(test_file),
+                 mime_type="application/pdf",
+             )
+         )
+
+         # Verify that the completion was called for each page
+         assert mock_acompletion.call_count == 2
+
+         # Verify content is stored in cache - note that order is not guaranteed since
+         # we batch the page extraction requests in parallel
+         pdf_path = Path(test_file)
+         cached_contents = []
+         for i in range(2):
+             cached_content = (
+                 await mock_litellm_extractor_with_cache.get_page_content_from_cache(
+                     pdf_path, i
+                 )
+             )
+             assert cached_content is not None
+             cached_contents.append(cached_content)
+         assert set(cached_contents) == {"Content from page 1", "Content from page 2"}
+
+         # Verify the output contains content from both pages
+         assert "Content from page 1" in result.content
+         assert "Content from page 2" in result.content
+
+
+ async def test_extract_pdf_with_cache_retrieval(
+     mock_file_factory, mock_litellm_extractor_with_cache
+ ):
+     """Test that PDF extraction retrieves content from cache when available."""
+     test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+     pdf_path = Path(test_file)
+
+     # Pre-populate cache with content
+     for i in range(2):
+         cache_key = mock_litellm_extractor_with_cache.pdf_page_cache_key(pdf_path, i)
+         await mock_litellm_extractor_with_cache.filesystem_cache.set(
+             cache_key, f"Cached content from page {i + 1}".encode("utf-8")
+         )
+
+     # Mock responses (should not be called due to cache hits)
+     mock_responses = []
+     for i in range(2):
+         mock_response = AsyncMock(spec=ModelResponse)
+         mock_choice = AsyncMock(spec=Choices)
+         mock_message = AsyncMock()
+         mock_message.content = f"Fresh content from page {i + 1}"
+         mock_choice.message = mock_message
+         mock_response.choices = [mock_choice]
+         mock_responses.append(mock_response)
+
+     with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+         result = await mock_litellm_extractor_with_cache.extract(
+             ExtractionInput(
+                 path=str(test_file),
+                 mime_type="application/pdf",
+             )
+         )
+
+         # Verify that litellm.acompletion was NOT called (cache hits)
+         assert mock_acompletion.call_count == 0
+
+         # Verify the output contains cached content, not fresh content
+         assert "Cached content from page 1" in result.content
+         assert "Cached content from page 2" in result.content
+         assert "Fresh content from page 1" not in result.content
+         assert "Fresh content from page 2" not in result.content
+
+
+ async def test_extract_pdf_without_cache(
+     mock_file_factory, mock_litellm_extractor_without_cache
+ ):
+     """Test that PDF extraction works normally when no cache is provided."""
+     test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+     # Mock responses for each page (PDF has 2 pages)
+     mock_responses = []
+     for i in range(2):
+         mock_response = AsyncMock(spec=ModelResponse)
+         mock_choice = AsyncMock(spec=Choices)
+         mock_message = AsyncMock()
+         mock_message.content = f"Content from page {i + 1}"
+         mock_choice.message = mock_message
+         mock_response.choices = [mock_choice]
+         mock_responses.append(mock_response)
+
+     with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+         result = await mock_litellm_extractor_without_cache.extract(
+             ExtractionInput(
+                 path=str(test_file),
+                 mime_type="application/pdf",
+             )
+         )
+
+         # Verify that the completion was called for each page
+         assert mock_acompletion.call_count == 2
+
+         # Verify the output contains content from both pages
+         assert "Content from page 1" in result.content
+         assert "Content from page 2" in result.content
+
+
+ async def test_extract_pdf_mixed_cache_hits_and_misses(
+     mock_file_factory, mock_litellm_extractor_with_cache
+ ):
+     """Test PDF extraction with some pages cached and others not."""
+     test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+     pdf_path = Path(test_file)
+
+     # Pre-populate cache with only page 0 content
+     cache_key = mock_litellm_extractor_with_cache.pdf_page_cache_key(pdf_path, 0)
+     await mock_litellm_extractor_with_cache.filesystem_cache.set(
+         cache_key, "Cached content from page 1".encode("utf-8")
+     )
+
+     # Mock responses for page 1 only (page 0 should hit cache)
+     mock_response = AsyncMock(spec=ModelResponse)
+     mock_choice = AsyncMock(spec=Choices)
+     mock_message = AsyncMock()
+     mock_message.content = "Fresh content from page 2"
+     mock_choice.message = mock_message
+     mock_response.choices = [mock_choice]
+
+     with patch("litellm.acompletion", return_value=mock_response) as mock_acompletion:
+         result = await mock_litellm_extractor_with_cache.extract(
+             ExtractionInput(
+                 path=str(test_file),
+                 mime_type="application/pdf",
+             )
+         )
+
+         # Verify that litellm.acompletion was called only once (for page 1)
+         assert mock_acompletion.call_count == 1
+
+         # Verify the output contains both cached and fresh content
+         assert "Cached content from page 1" in result.content
+         assert "Fresh content from page 2" in result.content
+
+
+ async def test_extract_pdf_cache_write_failure_does_not_throw(
+     mock_file_factory, mock_litellm_extractor_with_cache
+ ):
+     """Test that PDF extraction continues successfully even when cache write fails."""
+     test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+     # Mock responses for each page (PDF has 2 pages)
+     mock_responses = []
+     for i in range(2):
+         mock_response = AsyncMock(spec=ModelResponse)
+         mock_choice = AsyncMock(spec=Choices)
+         mock_message = AsyncMock()
+         mock_message.content = f"Content from page {i + 1}"
+         mock_choice.message = mock_message
+         mock_response.choices = [mock_choice]
+         mock_responses.append(mock_response)
+
+     # Mock the cache set method to raise an exception
+     with patch.object(
+         mock_litellm_extractor_with_cache.filesystem_cache,
+         "set",
+         side_effect=Exception("Cache write failed"),
+     ) as mock_cache_set:
+         with patch(
+             "litellm.acompletion", side_effect=mock_responses
+         ) as mock_acompletion:
+             # This should not raise an exception despite cache write failures
+             result = await mock_litellm_extractor_with_cache.extract(
+                 ExtractionInput(
+                     path=str(test_file),
+                     mime_type="application/pdf",
+                 )
+             )
+
+             # Verify that the completion was called for each page
+             assert mock_acompletion.call_count == 2
+
+             # Verify that cache.set was called for each page (and failed)
+             assert mock_cache_set.call_count == 2
+
+             # Verify the output contains content from both pages despite cache failures
+             assert "Content from page 1" in result.content
+             assert "Content from page 2" in result.content
+
+             # Verify the extraction completed successfully
+             assert not result.is_passthrough
+             assert result.content_format == OutputFormat.MARKDOWN
+
+
+ async def test_extract_pdf_cache_decode_failure_does_not_throw(
+     mock_file_factory, mock_litellm_extractor_with_cache
+ ):
+     """Test that PDF extraction continues successfully even when cache decode fails."""
+     test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+     pdf_path = Path(test_file)
+
+     # Pre-populate cache with invalid UTF-8 bytes that will cause decode failure
+     for i in range(2):
+         cache_key = mock_litellm_extractor_with_cache.pdf_page_cache_key(pdf_path, i)
+         # Use bytes that are not valid UTF-8 (e.g., some binary data)
+         invalid_utf8_bytes = b"\xff\xfe\x00\x00"  # Invalid UTF-8 sequence
+         await mock_litellm_extractor_with_cache.filesystem_cache.set(
+             cache_key, invalid_utf8_bytes
+         )
+
+     # Mock responses for each page (PDF has 2 pages) - should be called due to decode failures
+     mock_responses = []
+     for i in range(2):
+         mock_response = AsyncMock(spec=ModelResponse)
+         mock_choice = AsyncMock(spec=Choices)
+         mock_message = AsyncMock()
+         mock_message.content = f"Content from page {i + 1}"
+         mock_choice.message = mock_message
+         mock_response.choices = [mock_choice]
+         mock_responses.append(mock_response)
+
+     with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+         # This should not raise an exception despite cache decode failures
+         result = await mock_litellm_extractor_with_cache.extract(
+             ExtractionInput(
+                 path=str(test_file),
+                 mime_type="application/pdf",
+             )
+         )
+
+         # Verify that the completion was called for each page (due to decode failures)
+         assert mock_acompletion.call_count == 2
+
+         # Verify the output contains content from both pages despite cache decode failures
+         assert "Content from page 1" in result.content
+         assert "Content from page 2" in result.content
+
+         # Verify the extraction completed successfully
+         assert not result.is_passthrough
+         assert result.content_format == OutputFormat.MARKDOWN
+
+
+ async def test_extract_pdf_parallel_processing_error_handling(
+     mock_file_factory, mock_litellm_extractor_with_cache
+ ):
+     """Test that parallel processing handles errors correctly."""
+     test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+
+     # Mock first page to succeed, second to fail
+     mock_response1 = AsyncMock(spec=ModelResponse)
+     mock_choice1 = AsyncMock(spec=Choices)
+     mock_message1 = AsyncMock()
+     mock_message1.content = "Success from page 1"
+     mock_choice1.message = mock_message1
+     mock_response1.choices = [mock_choice1]
+
+     with patch(
+         "litellm.acompletion",
+         side_effect=[mock_response1, Exception("API Error on page 2")],
+     ) as mock_acompletion:
+         with pytest.raises(ValueError, match=r".*API Error on page 2"):
+             await mock_litellm_extractor_with_cache.extract(
+                 ExtractionInput(
+                     path=str(test_file),
+                     mime_type="application/pdf",
+                 )
+             )
+
+         # Verify that both pages were attempted
+         assert mock_acompletion.call_count == 2
+
+
+ async def test_extract_pdf_parallel_processing_all_cached(
+     mock_file_factory, mock_litellm_extractor_with_cache
+ ):
+     """Test parallel processing when all pages are cached."""
+     test_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+     pdf_path = Path(test_file)
+
+     # Pre-populate cache with both pages
+     for i in range(2):
+         cache_key = mock_litellm_extractor_with_cache.pdf_page_cache_key(pdf_path, i)
+         await mock_litellm_extractor_with_cache.filesystem_cache.set(
+             cache_key, f"Cached content from page {i + 1}".encode("utf-8")
+         )
+
+     # Mock responses (should not be called due to cache hits)
+     mock_responses = []
+     for i in range(2):
+         mock_response = AsyncMock(spec=ModelResponse)
+         mock_choice = AsyncMock(spec=Choices)
+         mock_message = AsyncMock()
+         mock_message.content = f"Fresh content from page {i + 1}"
+         mock_choice.message = mock_message
+         mock_response.choices = [mock_choice]
+         mock_responses.append(mock_response)
+
+     with patch("litellm.acompletion", side_effect=mock_responses) as mock_acompletion:
+         result = await mock_litellm_extractor_with_cache.extract(
+             ExtractionInput(
+                 path=str(test_file),
+                 mime_type="application/pdf",
+             )
+         )
+
+         # Verify that no API calls were made (all pages cached)
+         assert mock_acompletion.call_count == 0
+
+         # Verify the output contains cached content
+         assert "Cached content from page 1" in result.content
+         assert "Cached content from page 2" in result.content
+         assert "Fresh content from page 1" not in result.content
+         assert "Fresh content from page 2" not in result.content
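For orientation, here is a minimal sketch of the extractor API these tests exercise, assembled from the imports and fixtures in the diff above. It is not part of the package diff; the model choice, endpoint, API key, prompts, and the mock.pdf path are all placeholder assumptions for illustration.

import asyncio

from kiln_ai.adapters.extractors.base_extractor import ExtractionInput
from kiln_ai.adapters.extractors.litellm_extractor import (
    ExtractorConfig,
    LitellmExtractor,
)
from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
from kiln_ai.datamodel.extraction import ExtractorType


async def main():
    # Mirrors the mock_litellm_extractor fixture: one prompt per file kind
    # (document, image, video, audio), routed through LiteLLM.
    extractor = LitellmExtractor(
        ExtractorConfig(
            name="example",  # placeholder name
            extractor_type=ExtractorType.LITELLM,
            model_name="gpt_4o",
            model_provider_name="openai",
            properties={
                "prompt_document": "Transcribe this document as markdown.",
                "prompt_image": "Describe this image.",
                "prompt_video": "Describe this video.",
                "prompt_audio": "Transcribe this audio.",
            },
        ),
        litellm_core_config=LiteLlmCoreConfig(
            base_url="https://example.com",  # placeholder endpoint
            additional_body_options={"api_key": "..."},  # placeholder key
            default_headers={},
        ),
    )
    # PDFs are split and extracted page by page (optionally cached per page),
    # as the tests above verify; "mock.pdf" is a hypothetical input file.
    output = await extractor.extract(
        ExtractionInput(path="mock.pdf", mime_type="application/pdf")
    )
    print(output.content_format, output.content)


asyncio.run(main())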