kiln-ai 0.20.1__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (117)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  38. kiln_ai/adapters/ml_model_list.py +382 -4
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +7 -69
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +1 -1
  41. kiln_ai/adapters/model_adapters/test_structured_output.py +3 -1
  42. kiln_ai/adapters/ollama_tools.py +69 -12
  43. kiln_ai/adapters/provider_tools.py +190 -46
  44. kiln_ai/adapters/rag/deduplication.py +49 -0
  45. kiln_ai/adapters/rag/progress.py +252 -0
  46. kiln_ai/adapters/rag/rag_runners.py +844 -0
  47. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  48. kiln_ai/adapters/rag/test_progress.py +785 -0
  49. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  50. kiln_ai/adapters/remote_config.py +80 -8
  51. kiln_ai/adapters/test_adapter_registry.py +579 -86
  52. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  53. kiln_ai/adapters/test_ml_model_list.py +212 -0
  54. kiln_ai/adapters/test_ollama_tools.py +340 -1
  55. kiln_ai/adapters/test_prompt_builders.py +1 -1
  56. kiln_ai/adapters/test_provider_tools.py +199 -8
  57. kiln_ai/adapters/test_remote_config.py +551 -56
  58. kiln_ai/adapters/vector_store/__init__.py +1 -0
  59. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  60. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  61. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  62. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  63. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  64. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  65. kiln_ai/datamodel/__init__.py +16 -13
  66. kiln_ai/datamodel/basemodel.py +170 -1
  67. kiln_ai/datamodel/chunk.py +158 -0
  68. kiln_ai/datamodel/datamodel_enums.py +27 -0
  69. kiln_ai/datamodel/embedding.py +64 -0
  70. kiln_ai/datamodel/extraction.py +303 -0
  71. kiln_ai/datamodel/project.py +33 -1
  72. kiln_ai/datamodel/rag.py +79 -0
  73. kiln_ai/datamodel/test_attachment.py +649 -0
  74. kiln_ai/datamodel/test_basemodel.py +1 -1
  75. kiln_ai/datamodel/test_chunk_models.py +317 -0
  76. kiln_ai/datamodel/test_dataset_split.py +1 -1
  77. kiln_ai/datamodel/test_embedding_models.py +448 -0
  78. kiln_ai/datamodel/test_eval_model.py +6 -6
  79. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  80. kiln_ai/datamodel/test_extraction_model.py +470 -0
  81. kiln_ai/datamodel/test_rag.py +641 -0
  82. kiln_ai/datamodel/test_tool_id.py +81 -0
  83. kiln_ai/datamodel/test_vector_store.py +320 -0
  84. kiln_ai/datamodel/tool_id.py +22 -0
  85. kiln_ai/datamodel/vector_store.py +141 -0
  86. kiln_ai/tools/mcp_session_manager.py +4 -1
  87. kiln_ai/tools/rag_tools.py +157 -0
  88. kiln_ai/tools/test_mcp_session_manager.py +1 -1
  89. kiln_ai/tools/test_rag_tools.py +848 -0
  90. kiln_ai/tools/test_tool_registry.py +91 -2
  91. kiln_ai/tools/tool_registry.py +21 -0
  92. kiln_ai/utils/__init__.py +3 -0
  93. kiln_ai/utils/async_job_runner.py +62 -17
  94. kiln_ai/utils/config.py +2 -2
  95. kiln_ai/utils/env.py +15 -0
  96. kiln_ai/utils/filesystem.py +14 -0
  97. kiln_ai/utils/filesystem_cache.py +60 -0
  98. kiln_ai/utils/litellm.py +94 -0
  99. kiln_ai/utils/lock.py +100 -0
  100. kiln_ai/utils/mime_type.py +38 -0
  101. kiln_ai/utils/pdf_utils.py +38 -0
  102. kiln_ai/utils/test_async_job_runner.py +151 -35
  103. kiln_ai/utils/test_env.py +142 -0
  104. kiln_ai/utils/test_filesystem_cache.py +316 -0
  105. kiln_ai/utils/test_litellm.py +206 -0
  106. kiln_ai/utils/test_lock.py +185 -0
  107. kiln_ai/utils/test_mime_type.py +66 -0
  108. kiln_ai/utils/test_pdf_utils.py +73 -0
  109. kiln_ai/utils/test_uuid.py +111 -0
  110. kiln_ai/utils/test_validation.py +524 -0
  111. kiln_ai/utils/uuid.py +9 -0
  112. kiln_ai/utils/validation.py +90 -0
  113. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +7 -1
  114. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  115. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  116. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  117. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/extractors/test_base_extractor.py
@@ -0,0 +1,244 @@
+from typing import Any
+from unittest.mock import patch
+
+import pytest
+
+from kiln_ai.adapters.extractors.base_extractor import (
+    BaseExtractor,
+    ExtractionInput,
+    ExtractionOutput,
+)
+from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, OutputFormat
+
+
+class MockBaseExtractor(BaseExtractor):
+    async def _extract(self, input: ExtractionInput) -> ExtractionOutput:
+        return ExtractionOutput(
+            is_passthrough=False,
+            content="mock concrete extractor output",
+            content_format=OutputFormat.MARKDOWN,
+        )
+
+
+@pytest.fixture
+def mock_litellm_properties():
+    return {
+        "prompt_document": "mock prompt for document",
+        "prompt_image": "mock prompt for image",
+        "prompt_video": "mock prompt for video",
+        "prompt_audio": "mock prompt for audio",
+    }
+
+
+@pytest.fixture
+def mock_extractor(mock_litellm_properties):
+    return MockBaseExtractor(
+        ExtractorConfig(
+            name="mock",
+            model_provider_name="gemini_api",
+            model_name="gemini-2.0-flash",
+            extractor_type=ExtractorType.LITELLM,
+            output_format=OutputFormat.MARKDOWN,
+            properties=mock_litellm_properties,
+        )
+    )
+
+
+def mock_extractor_with_passthroughs(
+    properties: dict[str, Any],
+    mimetypes: list[OutputFormat],
+    output_format: OutputFormat,
+):
+    return MockBaseExtractor(
+        ExtractorConfig(
+            name="mock",
+            model_provider_name="gemini_api",
+            model_name="gemini-2.0-flash",
+            extractor_type=ExtractorType.LITELLM,
+            passthrough_mimetypes=mimetypes,
+            output_format=output_format,
+            properties=properties,
+        )
+    )
+
+
+def test_should_passthrough(mock_litellm_properties):
+    extractor = mock_extractor_with_passthroughs(
+        mock_litellm_properties,
+        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
+        OutputFormat.TEXT,
+    )
+
+    # should passthrough
+    assert extractor._should_passthrough("text/plain")
+    assert extractor._should_passthrough("text/markdown")
+
+    # should not passthrough
+    assert not extractor._should_passthrough("image/png")
+    assert not extractor._should_passthrough("application/pdf")
+    assert not extractor._should_passthrough("text/html")
+    assert not extractor._should_passthrough("image/jpeg")
+
+
+async def test_extract_passthrough(mock_litellm_properties):
+    """
+    Tests that when a file's MIME type is configured for passthrough, the extractor skips
+    the concrete extraction method and returns the file's contents directly with the
+    correct passthrough output format.
+    """
+    extractor = mock_extractor_with_passthroughs(
+        mock_litellm_properties,
+        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
+        OutputFormat.TEXT,
+    )
+    with (
+        patch.object(
+            extractor,
+            "_extract",
+            return_value=ExtractionOutput(
+                is_passthrough=False,
+                content="mock concrete extractor output",
+                content_format=OutputFormat.TEXT,
+            ),
+        ) as mock_extract,
+        patch(
+            "pathlib.Path.read_text",
+            return_value=b"test content",
+        ),
+    ):
+        result = await extractor.extract(
+            ExtractionInput(
+                path="test.txt",
+                mime_type="text/plain",
+            )
+        )
+
+        # Verify _extract was not called
+        mock_extract.assert_not_called()
+
+        # Verify correct passthrough result
+        assert result.is_passthrough
+        assert result.content == "test content"
+        assert result.content_format == OutputFormat.TEXT
+
+
+@pytest.mark.parametrize(
+    "output_format",
+    [
+        "text/plain",
+        "text/markdown",
+    ],
+)
+async def test_extract_passthrough_output_format(
+    mock_litellm_properties, output_format
+):
+    extractor = mock_extractor_with_passthroughs(
+        mock_litellm_properties,
+        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
+        output_format,
+    )
+    with (
+        patch.object(
+            extractor,
+            "_extract",
+            return_value=ExtractionOutput(
+                is_passthrough=False,
+                content="mock concrete extractor output",
+                content_format=output_format,
+            ),
+        ) as mock_extract,
+        patch(
+            "pathlib.Path.read_text",
+            return_value="test content",
+        ),
+    ):
+        result = await extractor.extract(
+            ExtractionInput(
+                path="test.txt",
+                mime_type="text/plain",
+            )
+        )
+
+        # Verify _extract was not called
+        mock_extract.assert_not_called()
+
+        # Verify correct passthrough result
+        assert result.is_passthrough
+        assert result.content == "test content"
+        assert result.content_format == output_format
+
+
+@pytest.mark.parametrize(
+    "path, mime_type, output_format",
+    [
+        ("test.mp3", "audio/mpeg", OutputFormat.TEXT),
+        ("test.png", "image/png", OutputFormat.TEXT),
+        ("test.pdf", "application/pdf", OutputFormat.TEXT),
+        ("test.txt", "text/plain", OutputFormat.MARKDOWN),
+        ("test.txt", "text/markdown", OutputFormat.MARKDOWN),
+        ("test.html", "text/html", OutputFormat.MARKDOWN),
+    ],
+)
+async def test_extract_non_passthrough(
+    mock_extractor, path: str, mime_type: str, output_format: OutputFormat
+):
+    with (
+        patch.object(
+            mock_extractor,
+            "_extract",
+            return_value=ExtractionOutput(
+                is_passthrough=False,
+                content="mock concrete extractor output",
+                content_format=output_format,
+            ),
+        ) as mock_extract,
+    ):
+        # first we call the base class extract method
+        result = await mock_extractor.extract(
+            ExtractionInput(
+                path=path,
+                mime_type=mime_type,
+            )
+        )
+
+        # then we call the subclass _extract method and add validated mime_type
+        mock_extract.assert_called_once_with(
+            ExtractionInput(
+                path=path,
+                mime_type=mime_type,
+            )
+        )
+
+        assert not result.is_passthrough
+        assert result.content == "mock concrete extractor output"
+        assert result.content_format == output_format
+
+
+async def test_default_output_format(mock_litellm_properties):
+    config = ExtractorConfig(
+        name="mock",
+        model_provider_name="gemini_api",
+        model_name="gemini-2.0-flash",
+        extractor_type=ExtractorType.LITELLM,
+        properties=mock_litellm_properties,
+    )
+    assert config.output_format == OutputFormat.MARKDOWN
+
+
+async def test_extract_failure_from_concrete_extractor(mock_extractor):
+    with patch.object(
+        mock_extractor,
+        "_extract",
+        side_effect=Exception("error from concrete extractor"),
+    ):
+        with pytest.raises(ValueError, match="error from concrete extractor"):
+            await mock_extractor.extract(
+                ExtractionInput(
+                    path="test.txt",
+                    mime_type="text/plain",
+                )
+            )
+
+
+async def test_output_format(mock_extractor):
+    assert mock_extractor.output_format() == OutputFormat.MARKDOWN
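The tests above pin down the BaseExtractor contract added in this release: a concrete subclass implements _extract, and extract() either delegates to it or, for MIME types listed in passthrough_mimetypes, returns the file contents directly. A minimal usage sketch inferred only from these tests; the PlainTextExtractor class and the sample file are illustrative, not part of the package:

import asyncio
from pathlib import Path

from kiln_ai.adapters.extractors.base_extractor import (
    BaseExtractor,
    ExtractionInput,
    ExtractionOutput,
)
from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, OutputFormat


class PlainTextExtractor(BaseExtractor):
    # Hypothetical subclass for illustration; a real extractor would call a model here.
    async def _extract(self, input: ExtractionInput) -> ExtractionOutput:
        return ExtractionOutput(
            is_passthrough=False,
            content=f"extracted text for {input.path}",
            content_format=OutputFormat.MARKDOWN,
        )


async def main() -> None:
    extractor = PlainTextExtractor(
        ExtractorConfig(
            name="example",
            model_provider_name="gemini_api",
            model_name="gemini-2.0-flash",
            extractor_type=ExtractorType.LITELLM,
            output_format=OutputFormat.MARKDOWN,
            passthrough_mimetypes=[OutputFormat.TEXT],
            properties={
                "prompt_document": "Extract the text from the document",
                "prompt_image": "Extract the text from the image",
                "prompt_video": "Extract the text from the video",
                "prompt_audio": "Extract the text from the audio",
            },
        )
    )
    # text/plain is listed in passthrough_mimetypes, so _extract is skipped and the
    # file contents come back unchanged (the tests above mock Path.read_text for this).
    Path("notes.txt").write_text("hello from a plain-text file")
    result = await extractor.extract(
        ExtractionInput(path="notes.txt", mime_type="text/plain")
    )
    print(result.is_passthrough, result.content_format, result.content)


if __name__ == "__main__":
    asyncio.run(main())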
kiln_ai/adapters/extractors/test_encoding.py
@@ -0,0 +1,54 @@
+from pathlib import Path
+
+import pytest
+
+from conftest import MockFileFactoryMimeType
+from kiln_ai.adapters.extractors.encoding import from_base64_url, to_base64_url
+
+
+async def test_to_base64_url(mock_file_factory):
+    mock_file = mock_file_factory(MockFileFactoryMimeType.JPEG)
+
+    byte_data = Path(mock_file).read_bytes()
+
+    # encode the byte data
+    base64_url = to_base64_url("image/jpeg", byte_data)
+    assert base64_url.startswith("data:image/jpeg;base64,")
+
+    # decode the base64 url
+    assert from_base64_url(base64_url) == byte_data
+
+
+def test_from_base64_url_invalid_format_no_data_prefix():
+    """Test that from_base64_url raises ValueError when input doesn't start with 'data:'"""
+    with pytest.raises(ValueError, match="Invalid base64 URL format"):
+        from_base64_url("not-a-data-url")
+
+
+def test_from_base64_url_invalid_format_no_comma():
+    """Test that from_base64_url raises ValueError when input doesn't contain a comma"""
+    with pytest.raises(ValueError, match="Invalid base64 URL format"):
+        from_base64_url("data:image/jpeg;base64")
+
+
+def test_from_base64_url_invalid_parts():
+    """Test that from_base64_url raises ValueError when splitting by comma doesn't result in exactly 2 parts"""
+    with pytest.raises(ValueError, match="Invalid base64 URL format"):
+        from_base64_url("data:image/jpeg;base64,part1,part2")
+
+
+def test_from_base64_url_base64_decode_failure():
+    """Test that from_base64_url raises ValueError when base64 decoding fails"""
+    with pytest.raises(ValueError, match="Failed to decode base64 data"):
+        from_base64_url("data:image/jpeg;base64,invalid-base64-data!")
+
+
+def test_from_base64_url_valid_format():
+    """Test that from_base64_url works with valid base64 URL format"""
+    # Create a simple valid base64 URL
+    test_data = b"Hello, World!"
+    base64_encoded = "SGVsbG8sIFdvcmxkIQ=="
+    base64_url = f"data:text/plain;base64,{base64_encoded}"
+
+    result = from_base64_url(base64_url)
+    assert result == test_data
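The encoding helpers exercised above round-trip raw bytes through a base64 data URL. A short sketch based only on what these tests show:

# Round-trip sketch for the new encoding helpers (based on the tests above).
from kiln_ai.adapters.extractors.encoding import from_base64_url, to_base64_url

payload = b"Hello, World!"
data_url = to_base64_url("text/plain", payload)
# e.g. "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ=="
assert data_url.startswith("data:text/plain;base64,")
assert from_base64_url(data_url) == payload
# Malformed inputs (missing "data:" prefix, missing comma, bad base64) raise ValueError.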
kiln_ai/adapters/extractors/test_extractor_registry.py
@@ -0,0 +1,181 @@
+from unittest.mock import patch
+
+import pytest
+
+from kiln_ai.adapters.extractors.extractor_registry import extractor_adapter_from_type
+from kiln_ai.adapters.extractors.litellm_extractor import LitellmExtractor
+from kiln_ai.adapters.ml_model_list import ModelProviderName
+from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
+from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType
+
+
+@pytest.fixture
+def mock_provider_configs():
+    with patch("kiln_ai.utils.config.Config.shared") as mock_config:
+        mock_config.return_value.open_ai_api_key = "test-openai-key"
+        mock_config.return_value.gemini_api_key = "test-gemini-key"
+        mock_config.return_value.anthropic_api_key = "test-anthropic-key"
+        mock_config.return_value.bedrock_access_key = "test-amazon-bedrock-key"
+        mock_config.return_value.bedrock_secret_key = "test-amazon-bedrock-secret-key"
+        mock_config.return_value.fireworks_api_key = "test-fireworks-key"
+        mock_config.return_value.groq_api_key = "test-groq-key"
+        mock_config.return_value.huggingface_api_key = "test-huggingface-key"
+        yield mock_config
+
+
+def test_extractor_adapter_from_type(mock_provider_configs):
+    extractor = extractor_adapter_from_type(
+        ExtractorType.LITELLM,
+        ExtractorConfig(
+            name="test-extractor",
+            extractor_type=ExtractorType.LITELLM,
+            model_provider_name="gemini_api",
+            model_name="gemini-2.0-flash",
+            properties={
+                "prompt_document": "Extract the text from the document",
+                "prompt_image": "Extract the text from the image",
+                "prompt_video": "Extract the text from the video",
+                "prompt_audio": "Extract the text from the audio",
+            },
+        ),
+    )
+    assert isinstance(extractor, LitellmExtractor)
+    assert extractor.extractor_config.model_name == "gemini-2.0-flash"
+    assert extractor.extractor_config.model_provider_name == "gemini_api"
+
+
+@patch(
+    "kiln_ai.adapters.extractors.extractor_registry.lite_llm_core_config_for_provider"
+)
+def test_extractor_adapter_from_type_uses_litellm_core_config(
+    mock_get_litellm_core_config,
+):
+    """Test that extractor receives auth details from provider_tools."""
+    mock_litellm_core_config = LiteLlmCoreConfig(
+        base_url="https://test.com",
+        additional_body_options={"api_key": "test-key"},
+        default_headers={},
+    )
+    mock_get_litellm_core_config.return_value = mock_litellm_core_config
+
+    extractor = extractor_adapter_from_type(
+        ExtractorType.LITELLM,
+        ExtractorConfig(
+            name="test-extractor",
+            extractor_type=ExtractorType.LITELLM,
+            model_provider_name="openai",
+            model_name="gpt-4",
+            properties={
+                "prompt_document": "Extract the text from the document",
+                "prompt_image": "Extract the text from the image",
+                "prompt_video": "Extract the text from the video",
+                "prompt_audio": "Extract the text from the audio",
+            },
+        ),
+    )
+
+    assert isinstance(extractor, LitellmExtractor)
+    assert extractor.litellm_core_config == mock_litellm_core_config
+    mock_get_litellm_core_config.assert_called_once_with(ModelProviderName.openai)
+
+
+def test_extractor_adapter_from_type_invalid_provider():
+    """Test that invalid model provider names raise a clear error."""
+    with pytest.raises(
+        ValueError, match="Unsupported model provider name: invalid_provider"
+    ):
+        extractor_adapter_from_type(
+            ExtractorType.LITELLM,
+            ExtractorConfig(
+                name="test-extractor",
+                extractor_type=ExtractorType.LITELLM,
+                model_provider_name="invalid_provider",
+                model_name="some-model",
+                properties={
+                    "prompt_document": "Extract the text from the document",
+                    "prompt_image": "Extract the text from the image",
+                    "prompt_video": "Extract the text from the video",
+                    "prompt_audio": "Extract the text from the audio",
+                },
+            ),
+        )
+
+
+def test_extractor_adapter_from_type_invalid():
+    with pytest.raises(ValueError, match="Unhandled enum value: fake_type"):
+        extractor_adapter_from_type(
+            "fake_type",
+            ExtractorConfig(
+                name="test-extractor",
+                extractor_type=ExtractorType.LITELLM,
+                model_provider_name="invalid_provider",
+                model_name="some-model",
+                properties={
+                    "prompt_document": "Extract the text from the document",
+                    "prompt_image": "Extract the text from the image",
+                    "prompt_video": "Extract the text from the video",
+                    "prompt_audio": "Extract the text from the audio",
+                },
+            ),
+        )
+
+
+@pytest.mark.parametrize(
+    "provider_name",
+    [
+        "openai",
+        "anthropic",
+        "gemini_api",
+        "amazon_bedrock",
+        "fireworks_ai",
+        "groq",
+        "huggingface",
+    ],
+)
+def test_extractor_adapter_from_type_different_providers(
+    provider_name, mock_provider_configs
+):
+    """Test that different providers work correctly."""
+    extractor = extractor_adapter_from_type(
+        ExtractorType.LITELLM,
+        ExtractorConfig(
+            name="test-extractor",
+            extractor_type=ExtractorType.LITELLM,
+            model_provider_name=provider_name,
+            model_name="test-model",
+            properties={
+                "prompt_document": "Extract the text from the document",
+                "prompt_image": "Extract the text from the image",
+                "prompt_video": "Extract the text from the video",
+                "prompt_audio": "Extract the text from the audio",
+            },
+        ),
+    )
+
+    assert isinstance(extractor, LitellmExtractor)
+    assert extractor.extractor_config.model_provider_name == provider_name
+
+
+def test_extractor_adapter_from_type_no_config_found(mock_provider_configs):
+    with patch(
+        "kiln_ai.adapters.extractors.extractor_registry.lite_llm_core_config_for_provider"
+    ) as mock_lite_llm_core_config_for_provider:
+        mock_lite_llm_core_config_for_provider.return_value = None
+        with pytest.raises(
+            ValueError, match="No configuration found for core provider: openai"
+        ):
+            extractor_adapter_from_type(
+                ExtractorType.LITELLM,
+                ExtractorConfig(
+                    name="test-extractor",
+                    extractor_type=ExtractorType.LITELLM,
+                    model_provider_name="openai",
+                    model_name="gpt-4",
+                    properties={
+                        "prompt_document": "Extract the text from the document",
+                        "prompt_image": "Extract the text from the image",
+                        "prompt_video": "Extract the text from the video",
+                        "prompt_audio": "Extract the text from the audio",
+                    },
+                ),
+            )
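The registry tests above show the lookup path: extractor_adapter_from_type takes an ExtractorType and an ExtractorConfig and returns a LitellmExtractor wired with the provider's LiteLLM core config. A sketch of that call, assuming a Gemini API key is already configured (with no provider configuration the registry raises, as the no-config test shows):

# Sketch of resolving an extractor through the registry (mirrors the tests above).
from kiln_ai.adapters.extractors.extractor_registry import extractor_adapter_from_type
from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType

extractor = extractor_adapter_from_type(
    ExtractorType.LITELLM,
    ExtractorConfig(
        name="docs-extractor",
        extractor_type=ExtractorType.LITELLM,
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        properties={
            "prompt_document": "Extract the text from the document",
            "prompt_image": "Extract the text from the image",
            "prompt_video": "Extract the text from the video",
            "prompt_audio": "Extract the text from the audio",
        },
    ),
)
print(type(extractor).__name__)  # LitellmExtractor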
kiln_ai/adapters/extractors/test_extractor_runner.py
@@ -0,0 +1,181 @@
+from unittest.mock import AsyncMock
+
+import pytest
+
+from conftest import MockFileFactoryMimeType
+from kiln_ai.adapters.extractors.extractor_runner import ExtractorRunner
+from kiln_ai.datamodel.basemodel import KilnAttachmentModel
+from kiln_ai.datamodel.extraction import (
+    Document,
+    Extraction,
+    ExtractionSource,
+    ExtractorConfig,
+    ExtractorType,
+    FileInfo,
+    Kind,
+    OutputFormat,
+)
+from kiln_ai.datamodel.project import Project
+
+
+@pytest.fixture
+def mock_project(tmp_path):
+    project = Project(
+        name="test",
+        description="test",
+        path=tmp_path / "project.kiln",
+    )
+    project.save_to_file()
+    return project
+
+
+@pytest.fixture
+def mock_extractor_config(mock_project):
+    extractor_config = ExtractorConfig(
+        name="test",
+        description="test",
+        output_format=OutputFormat.TEXT,
+        passthrough_mimetypes=[],
+        extractor_type=ExtractorType.LITELLM,
+        model_provider_name="gemini_api",
+        model_name="gemini-2.0-flash",
+        parent=mock_project,
+        properties={
+            "prompt_document": "Extract the text from the document",
+            "prompt_image": "Extract the text from the image",
+            "prompt_video": "Extract the text from the video",
+            "prompt_audio": "Extract the text from the audio",
+        },
+    )
+    extractor_config.save_to_file()
+    return extractor_config
+
+
+@pytest.fixture
+def mock_document(mock_project, mock_file_factory) -> Document:
+    test_pdf_file = mock_file_factory(MockFileFactoryMimeType.PDF)
+    document = Document(
+        name="test",
+        description="test",
+        kind=Kind.DOCUMENT,
+        original_file=FileInfo(
+            filename="test.pdf",
+            size=100,
+            mime_type="application/pdf",
+            attachment=KilnAttachmentModel.from_file(test_pdf_file),
+        ),
+        parent=mock_project,
+    )
+    document.save_to_file()
+    return document
+
+
+@pytest.fixture
+def mock_extractor_runner(mock_extractor_config, mock_document):
+    return ExtractorRunner(
+        extractor_configs=[mock_extractor_config],
+        documents=[mock_document],
+    )
+
+
+# Test with and without concurrency
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_extractor_runner_status_updates(
+    mock_extractor_runner, concurrency
+):
+    # Real async testing!
+
+    job_count = 50
+    # Job objects are not the right type, but since we're mocking run_job, it doesn't matter
+    jobs = [{} for _ in range(job_count)]
+
+    # Mock collect_tasks to return our fake jobs
+    mock_extractor_runner.collect_jobs = lambda: jobs
+
+    # Mock run_job to return True immediately
+    mock_extractor_runner.run_job = AsyncMock(return_value=True)
+
+    # Expect the status updates in order, and 1 for each job
+    expected_completed_count = 0
+    async for progress in mock_extractor_runner.run(concurrency=concurrency):
+        assert progress.complete == expected_completed_count
+        expected_completed_count += 1
+        assert progress.errors == 0
+        assert progress.total == job_count
+
+    # Verify last status update was complete
+    assert expected_completed_count == job_count + 1
+
+    # Verify run_job was called for each job
+    assert mock_extractor_runner.run_job.call_count == job_count
+
+
+def test_collect_jobs_excludes_already_run_extraction(
+    mock_extractor_runner, mock_document, mock_extractor_config
+):
+    """Test that already run documents are excluded"""
+    Extraction(
+        parent=mock_document,
+        source=ExtractionSource.PROCESSED,
+        extractor_config_id="other-extractor-config-id",
+        output=KilnAttachmentModel.from_data("test extraction output", "text/plain"),
+    ).save_to_file()
+
+    # should get the one job, since the document was not already extracted with this extractor config
+    jobs = mock_extractor_runner.collect_jobs()
+    assert len(jobs) == 1
+    assert jobs[0].doc.id == mock_document.id
+    assert jobs[0].extractor_config.id == mock_extractor_config.id
+
+    # Create an extraction for this document
+    Extraction(
+        parent=mock_document,
+        source=ExtractionSource.PROCESSED,
+        extractor_config_id=mock_extractor_config.id,
+        output=KilnAttachmentModel.from_data("test extraction output", "text/plain"),
+    ).save_to_file()
+
+    jobs = mock_extractor_runner.collect_jobs()
+
+    # should now get no jobs since the document was already extracted with this extractor config
+    assert len(jobs) == 0
+
+
+def test_collect_jobs_multiple_extractor_configs(
+    mock_extractor_runner,
+    mock_document,
+    mock_extractor_config,
+    mock_project,
+):
+    """Test handling multiple extractor configs"""
+    second_config = ExtractorConfig(
+        name="test2",
+        description="test2",
+        output_format=OutputFormat.TEXT,
+        passthrough_mimetypes=[],
+        extractor_type=ExtractorType.LITELLM,
+        parent=mock_project,
+        model_provider_name="gemini_api",
+        model_name="gemini-2.0-flash",
+        properties={
+            "prompt_document": "Extract the text from the document",
+            "prompt_image": "Extract the text from the image",
+            "prompt_video": "Extract the text from the video",
+            "prompt_audio": "Extract the text from the audio",
+        },
+    )
+    second_config.save_to_file()
+
+    runner = ExtractorRunner(
+        extractor_configs=[mock_extractor_config, second_config],
+        documents=[mock_document],
+    )
+    jobs = runner.collect_jobs()
+
+    # Should get 2 jobs, one for each config
+    assert len(jobs) == 2
+    assert {job.extractor_config.id for job in jobs} == {
+        second_config.id,
+        mock_extractor_config.id,
+    }
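The runner tests above exercise the batch API: ExtractorRunner pairs each document with each extractor config, skips pairs that already have an extraction, and run() yields progress updates asynchronously. A sketch of driving it over already-saved documents and configs; the run_extractions helper is illustrative, not part of the package:

from kiln_ai.adapters.extractors.extractor_runner import ExtractorRunner


async def run_extractions(extractor_configs, documents, concurrency: int = 5) -> None:
    runner = ExtractorRunner(
        extractor_configs=extractor_configs,
        documents=documents,
    )
    # collect_jobs() pairs each document with each config, skipping documents that
    # already have an extraction for that config (see the tests above).
    print(f"{len(runner.collect_jobs())} jobs to run")
    async for progress in runner.run(concurrency=concurrency):
        print(f"{progress.complete}/{progress.total} complete, {progress.errors} errors")


# With a saved project, e.g.:
# asyncio.run(run_extractions([extractor_config], [document]))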