kiln-ai 0.20.1__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +382 -4
- kiln_ai/adapters/model_adapters/litellm_adapter.py +7 -69
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +1 -1
- kiln_ai/adapters/model_adapters/test_structured_output.py +3 -1
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +212 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +1 -1
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_tool_id.py +81 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +22 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/mcp_session_manager.py +4 -1
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_mcp_session_manager.py +1 -1
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +91 -2
- kiln_ai/tools/tool_registry.py +21 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +7 -1
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
from unittest.mock import patch
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from kiln_ai.adapters.extractors.base_extractor import (
|
|
7
|
+
BaseExtractor,
|
|
8
|
+
ExtractionInput,
|
|
9
|
+
ExtractionOutput,
|
|
10
|
+
)
|
|
11
|
+
from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, OutputFormat
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MockBaseExtractor(BaseExtractor):
    """Minimal concrete extractor used to exercise ``BaseExtractor`` behavior.

    Its ``_extract`` hook ignores the input and always returns the same
    non-passthrough markdown payload, so tests can tell whether the base class
    delegated to the subclass or short-circuited.
    """

    async def _extract(self, input: ExtractionInput) -> ExtractionOutput:
        """Return a fixed, non-passthrough markdown result regardless of ``input``."""
        canned_output = ExtractionOutput(
            content="mock concrete extractor output",
            content_format=OutputFormat.MARKDOWN,
            is_passthrough=False,
        )
        return canned_output
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@pytest.fixture
def mock_litellm_properties():
    """LiteLLM extractor prompt properties: one mock prompt per media kind."""
    properties = dict(
        prompt_document="mock prompt for document",
        prompt_image="mock prompt for image",
        prompt_video="mock prompt for video",
        prompt_audio="mock prompt for audio",
    )
    return properties
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@pytest.fixture
def mock_extractor(mock_litellm_properties):
    """A ``MockBaseExtractor`` over a markdown-output Gemini LiteLLM config.

    The config does not set ``passthrough_mimetypes`` explicitly.
    """
    config = ExtractorConfig(
        name="mock",
        extractor_type=ExtractorType.LITELLM,
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        output_format=OutputFormat.MARKDOWN,
        properties=mock_litellm_properties,
    )
    return MockBaseExtractor(config)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def mock_extractor_with_passthroughs(
    properties: dict[str, Any],
    mimetypes: list[OutputFormat],
    output_format: OutputFormat,
):
    """Build a ``MockBaseExtractor`` whose config declares passthrough formats.

    Args:
        properties: LiteLLM prompt properties for the config.
        mimetypes: Formats whose MIME types should bypass ``_extract``.
        output_format: The output format declared on the config.

    Returns:
        A ``MockBaseExtractor`` wrapping the constructed ``ExtractorConfig``.
    """
    passthrough_config = ExtractorConfig(
        name="mock",
        extractor_type=ExtractorType.LITELLM,
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        properties=properties,
        output_format=output_format,
        passthrough_mimetypes=mimetypes,
    )
    return MockBaseExtractor(passthrough_config)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def test_should_passthrough(mock_litellm_properties):
    """Only MIME types matching the configured passthrough formats pass through."""
    extractor = mock_extractor_with_passthroughs(
        mock_litellm_properties,
        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
        OutputFormat.TEXT,
    )

    passthrough_types = ["text/plain", "text/markdown"]
    extracted_types = ["image/png", "application/pdf", "text/html", "image/jpeg"]

    for mime_type in passthrough_types:
        assert extractor._should_passthrough(mime_type)

    for mime_type in extracted_types:
        assert not extractor._should_passthrough(mime_type)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
async def test_extract_passthrough(mock_litellm_properties):
    """
    Tests that when a file's MIME type is configured for passthrough, the extractor skips
    the concrete extraction method and returns the file's contents directly with the
    correct passthrough output format.
    """
    extractor = mock_extractor_with_passthroughs(
        mock_litellm_properties,
        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
        OutputFormat.TEXT,
    )
    with (
        patch.object(
            extractor,
            "_extract",
            return_value=ExtractionOutput(
                is_passthrough=False,
                content="mock concrete extractor output",
                content_format=OutputFormat.TEXT,
            ),
        ) as mock_extract,
        # Path.read_text returns str, so the mock returns str too, mirroring
        # the real API (and the sibling output-format test).
        patch(
            "pathlib.Path.read_text",
            return_value="test content",
        ),
    ):
        result = await extractor.extract(
            ExtractionInput(
                path="test.txt",
                mime_type="text/plain",
            )
        )

    # Verify _extract was not called
    mock_extract.assert_not_called()

    # Verify correct passthrough result
    assert result.is_passthrough
    assert result.content == "test content"
    assert result.content_format == OutputFormat.TEXT
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@pytest.mark.parametrize(
    "output_format",
    [
        OutputFormat.TEXT,
        OutputFormat.MARKDOWN,
    ],
)
async def test_extract_passthrough_output_format(
    mock_litellm_properties, output_format
):
    """A passthrough result reports the output format declared on the config.

    Parametrized with ``OutputFormat`` members rather than raw MIME strings, to
    match how the rest of this module builds configs and outputs.
    """
    extractor = mock_extractor_with_passthroughs(
        mock_litellm_properties,
        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
        output_format,
    )
    with (
        patch.object(
            extractor,
            "_extract",
            return_value=ExtractionOutput(
                is_passthrough=False,
                content="mock concrete extractor output",
                content_format=output_format,
            ),
        ) as mock_extract,
        patch(
            "pathlib.Path.read_text",
            return_value="test content",
        ),
    ):
        result = await extractor.extract(
            ExtractionInput(
                path="test.txt",
                mime_type="text/plain",
            )
        )

    # Verify _extract was not called
    mock_extract.assert_not_called()

    # Verify correct passthrough result
    assert result.is_passthrough
    assert result.content == "test content"
    assert result.content_format == output_format
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
@pytest.mark.parametrize(
    "path, mime_type, output_format",
    [
        ("test.mp3", "audio/mpeg", OutputFormat.TEXT),
        ("test.png", "image/png", OutputFormat.TEXT),
        ("test.pdf", "application/pdf", OutputFormat.TEXT),
        ("test.txt", "text/plain", OutputFormat.MARKDOWN),
        ("test.txt", "text/markdown", OutputFormat.MARKDOWN),
        ("test.html", "text/html", OutputFormat.MARKDOWN),
    ],
)
async def test_extract_non_passthrough(
    mock_extractor, path: str, mime_type: str, output_format: OutputFormat
):
    """Inputs not configured for passthrough are delegated to the subclass ``_extract``."""
    concrete_output = ExtractionOutput(
        is_passthrough=False,
        content="mock concrete extractor output",
        content_format=output_format,
    )

    with patch.object(
        mock_extractor, "_extract", return_value=concrete_output
    ) as mock_extract:
        # Go through the public base-class entry point.
        result = await mock_extractor.extract(
            ExtractionInput(path=path, mime_type=mime_type)
        )

        # The subclass hook receives an input equal to what was passed in.
        mock_extract.assert_called_once_with(
            ExtractionInput(path=path, mime_type=mime_type)
        )

        assert not result.is_passthrough
        assert result.content == "mock concrete extractor output"
        assert result.content_format == output_format
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def test_default_output_format(mock_litellm_properties):
    """An ``ExtractorConfig`` built without ``output_format`` defaults to markdown.

    This is a plain synchronous test: nothing here is awaited, so it does not
    need to be a coroutine or run on an event loop.
    """
    config = ExtractorConfig(
        name="mock",
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        extractor_type=ExtractorType.LITELLM,
        properties=mock_litellm_properties,
    )
    assert config.output_format == OutputFormat.MARKDOWN
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
async def test_extract_failure_from_concrete_extractor(mock_extractor):
    """An exception raised in ``_extract`` surfaces from ``extract`` as a ``ValueError``
    whose message contains the original error text."""
    hook_error = Exception("error from concrete extractor")

    with patch.object(mock_extractor, "_extract", side_effect=hook_error):
        with pytest.raises(ValueError, match="error from concrete extractor"):
            text_input = ExtractionInput(path="test.txt", mime_type="text/plain")
            await mock_extractor.extract(text_input)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def test_output_format(mock_extractor):
    """``output_format()`` reports markdown for the markdown-configured fixture.

    Synchronous on purpose: nothing is awaited, so no coroutine is needed.
    """
    assert mock_extractor.output_format() == OutputFormat.MARKDOWN
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from conftest import MockFileFactoryMimeType
|
|
6
|
+
from kiln_ai.adapters.extractors.encoding import from_base64_url, to_base64_url
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def test_to_base64_url(mock_file_factory):
    """Encoding a JPEG's bytes to a data URL and decoding it again round-trips.

    Synchronous on purpose: the encode/decode helpers are called without
    ``await``, so the test does not need to be a coroutine.
    """
    mock_file = mock_file_factory(MockFileFactoryMimeType.JPEG)

    byte_data = Path(mock_file).read_bytes()

    # encode the byte data
    base64_url = to_base64_url("image/jpeg", byte_data)
    assert base64_url.startswith("data:image/jpeg;base64,")

    # decode the base64 url
    assert from_base64_url(base64_url) == byte_data
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def test_from_base64_url_invalid_format_no_data_prefix():
    """A string that does not begin with ``data:`` is rejected as malformed."""
    not_a_data_url = "not-a-data-url"
    with pytest.raises(ValueError, match="Invalid base64 URL format"):
        from_base64_url(not_a_data_url)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def test_from_base64_url_invalid_format_no_comma():
    """A data URL missing the comma before its payload is rejected as malformed."""
    header_only = "data:image/jpeg;base64"
    with pytest.raises(ValueError, match="Invalid base64 URL format"):
        from_base64_url(header_only)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def test_from_base64_url_invalid_parts():
    """Test that from_base64_url raises ValueError when splitting by comma doesn't result in exactly 2 parts"""
    # The input keeps a valid "data:" prefix and adds an extra comma. A
    # prefix-less input such as ",part2" splits into exactly two parts and is
    # already rejected for lacking the prefix, so it would not exercise the
    # part-count check this test is about.
    with pytest.raises(ValueError, match="Invalid base64 URL format"):
        from_base64_url("data:image/jpeg;base64,part1,part2")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def test_from_base64_url_base64_decode_failure():
    """Test that from_base64_url raises ValueError when base64 decoding fails"""
    # Use a well-formed data URL (``data:`` prefix, single comma) around a
    # payload that is not valid base64. The failure then comes from decoding
    # rather than from the URL-format checks. A bare string without the
    # "data:" prefix is rejected as "Invalid base64 URL format" instead.
    with pytest.raises(ValueError, match="Failed to decode base64 data"):
        from_base64_url("data:image/jpeg;base64,invalid-base64-data!")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def test_from_base64_url_valid_format():
    """A well-formed text/plain data URL decodes back to its original bytes."""
    expected_bytes = b"Hello, World!"
    payload = "SGVsbG8sIFdvcmxkIQ=="
    data_url = "data:text/plain;base64," + payload

    assert from_base64_url(data_url) == expected_bytes
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
from unittest.mock import patch
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from kiln_ai.adapters.extractors.extractor_registry import extractor_adapter_from_type
|
|
6
|
+
from kiln_ai.adapters.extractors.litellm_extractor import LitellmExtractor
|
|
7
|
+
from kiln_ai.adapters.ml_model_list import ModelProviderName
|
|
8
|
+
from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
|
|
9
|
+
from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@pytest.fixture
def mock_provider_configs():
    """Patch ``Config.shared`` so each provider's credential attribute is a fake key."""
    fake_credentials = {
        "open_ai_api_key": "test-openai-key",
        "gemini_api_key": "test-gemini-key",
        "anthropic_api_key": "test-anthropic-key",
        "bedrock_access_key": "test-amazon-bedrock-key",
        "bedrock_secret_key": "test-amazon-bedrock-secret-key",
        "fireworks_api_key": "test-fireworks-key",
        "groq_api_key": "test-groq-key",
        "huggingface_api_key": "test-huggingface-key",
    }
    with patch("kiln_ai.utils.config.Config.shared") as mock_config:
        shared_instance = mock_config.return_value
        for attribute, fake_value in fake_credentials.items():
            setattr(shared_instance, attribute, fake_value)
        yield mock_config
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_extractor_adapter_from_type(mock_provider_configs):
    """The LiteLLM extractor type yields a ``LitellmExtractor`` that keeps the given config."""
    prompts = {
        "prompt_document": "Extract the text from the document",
        "prompt_image": "Extract the text from the image",
        "prompt_video": "Extract the text from the video",
        "prompt_audio": "Extract the text from the audio",
    }
    gemini_config = ExtractorConfig(
        name="test-extractor",
        extractor_type=ExtractorType.LITELLM,
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        properties=prompts,
    )

    extractor = extractor_adapter_from_type(ExtractorType.LITELLM, gemini_config)

    assert isinstance(extractor, LitellmExtractor)
    held_config = extractor.extractor_config
    assert held_config.model_name == "gemini-2.0-flash"
    assert held_config.model_provider_name == "gemini_api"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@patch(
    "kiln_ai.adapters.extractors.extractor_registry.lite_llm_core_config_for_provider"
)
def test_extractor_adapter_from_type_uses_litellm_core_config(
    mock_get_litellm_core_config,
):
    """The registry looks up the provider's LiteLLM core config and hands it to the extractor."""
    expected_core_config = LiteLlmCoreConfig(
        base_url="https://test.com",
        additional_body_options={"api_key": "test-key"},
        default_headers={},
    )
    mock_get_litellm_core_config.return_value = expected_core_config

    prompts = {
        "prompt_document": "Extract the text from the document",
        "prompt_image": "Extract the text from the image",
        "prompt_video": "Extract the text from the video",
        "prompt_audio": "Extract the text from the audio",
    }
    openai_config = ExtractorConfig(
        name="test-extractor",
        extractor_type=ExtractorType.LITELLM,
        model_provider_name="openai",
        model_name="gpt-4",
        properties=prompts,
    )

    extractor = extractor_adapter_from_type(ExtractorType.LITELLM, openai_config)

    assert isinstance(extractor, LitellmExtractor)
    assert extractor.litellm_core_config == expected_core_config
    mock_get_litellm_core_config.assert_called_once_with(ModelProviderName.openai)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def test_extractor_adapter_from_type_invalid_provider():
    """An unknown ``model_provider_name`` is rejected with a descriptive ``ValueError``."""
    prompts = {
        "prompt_document": "Extract the text from the document",
        "prompt_image": "Extract the text from the image",
        "prompt_video": "Extract the text from the video",
        "prompt_audio": "Extract the text from the audio",
    }
    expected_message = "Unsupported model provider name: invalid_provider"

    with pytest.raises(ValueError, match=expected_message):
        bogus_config = ExtractorConfig(
            name="test-extractor",
            extractor_type=ExtractorType.LITELLM,
            model_provider_name="invalid_provider",
            model_name="some-model",
            properties=prompts,
        )
        extractor_adapter_from_type(ExtractorType.LITELLM, bogus_config)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def test_extractor_adapter_from_type_invalid():
    """An extractor type outside ``ExtractorType`` is reported as an unhandled enum value."""
    prompts = {
        "prompt_document": "Extract the text from the document",
        "prompt_image": "Extract the text from the image",
        "prompt_video": "Extract the text from the video",
        "prompt_audio": "Extract the text from the audio",
    }

    with pytest.raises(ValueError, match="Unhandled enum value: fake_type"):
        config = ExtractorConfig(
            name="test-extractor",
            extractor_type=ExtractorType.LITELLM,
            model_provider_name="invalid_provider",
            model_name="some-model",
            properties=prompts,
        )
        extractor_adapter_from_type("fake_type", config)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@pytest.mark.parametrize(
    "provider_name",
    [
        "openai",
        "anthropic",
        "gemini_api",
        "amazon_bedrock",
        "fireworks_ai",
        "groq",
        "huggingface",
    ],
)
def test_extractor_adapter_from_type_different_providers(
    provider_name, mock_provider_configs
):
    """Each listed provider yields a ``LitellmExtractor`` whose config keeps that provider."""
    prompts = {
        "prompt_document": "Extract the text from the document",
        "prompt_image": "Extract the text from the image",
        "prompt_video": "Extract the text from the video",
        "prompt_audio": "Extract the text from the audio",
    }
    provider_config = ExtractorConfig(
        name="test-extractor",
        extractor_type=ExtractorType.LITELLM,
        model_provider_name=provider_name,
        model_name="test-model",
        properties=prompts,
    )

    extractor = extractor_adapter_from_type(ExtractorType.LITELLM, provider_config)

    assert isinstance(extractor, LitellmExtractor)
    assert extractor.extractor_config.model_provider_name == provider_name
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def test_extractor_adapter_from_type_no_config_found(mock_provider_configs):
    """When no LiteLLM core config exists for the provider, a ``ValueError`` names it."""
    prompts = {
        "prompt_document": "Extract the text from the document",
        "prompt_image": "Extract the text from the image",
        "prompt_video": "Extract the text from the video",
        "prompt_audio": "Extract the text from the audio",
    }
    lookup_target = (
        "kiln_ai.adapters.extractors.extractor_registry.lite_llm_core_config_for_provider"
    )

    with patch(lookup_target) as mock_core_config_lookup:
        # Simulate a provider with no resolvable core configuration.
        mock_core_config_lookup.return_value = None

        with pytest.raises(
            ValueError, match="No configuration found for core provider: openai"
        ):
            openai_config = ExtractorConfig(
                name="test-extractor",
                extractor_type=ExtractorType.LITELLM,
                model_provider_name="openai",
                model_name="gpt-4",
                properties=prompts,
            )
            extractor_adapter_from_type(ExtractorType.LITELLM, openai_config)
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
from unittest.mock import AsyncMock
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from conftest import MockFileFactoryMimeType
|
|
6
|
+
from kiln_ai.adapters.extractors.extractor_runner import ExtractorRunner
|
|
7
|
+
from kiln_ai.datamodel.basemodel import KilnAttachmentModel
|
|
8
|
+
from kiln_ai.datamodel.extraction import (
|
|
9
|
+
Document,
|
|
10
|
+
Extraction,
|
|
11
|
+
ExtractionSource,
|
|
12
|
+
ExtractorConfig,
|
|
13
|
+
ExtractorType,
|
|
14
|
+
FileInfo,
|
|
15
|
+
Kind,
|
|
16
|
+
OutputFormat,
|
|
17
|
+
)
|
|
18
|
+
from kiln_ai.datamodel.project import Project
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@pytest.fixture
def mock_project(tmp_path):
    """A ``Project`` saved as ``project.kiln`` inside the test's temporary directory."""
    project_file = tmp_path / "project.kiln"
    project = Project(name="test", description="test", path=project_file)
    project.save_to_file()
    return project
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@pytest.fixture
def mock_extractor_config(mock_project):
    """A saved text-output Gemini LiteLLM ``ExtractorConfig`` under ``mock_project``.

    It is configured with an empty ``passthrough_mimetypes`` list.
    """
    prompt_properties = {
        "prompt_document": "Extract the text from the document",
        "prompt_image": "Extract the text from the image",
        "prompt_video": "Extract the text from the video",
        "prompt_audio": "Extract the text from the audio",
    }
    config = ExtractorConfig(
        name="test",
        description="test",
        extractor_type=ExtractorType.LITELLM,
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        output_format=OutputFormat.TEXT,
        passthrough_mimetypes=[],
        properties=prompt_properties,
        parent=mock_project,
    )
    config.save_to_file()
    return config
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@pytest.fixture
def mock_document(mock_project, mock_file_factory) -> Document:
    """A saved PDF ``Document`` under ``mock_project`` backed by a fixture PDF attachment."""
    pdf_path = mock_file_factory(MockFileFactoryMimeType.PDF)
    file_info = FileInfo(
        filename="test.pdf",
        size=100,
        mime_type="application/pdf",
        attachment=KilnAttachmentModel.from_file(pdf_path),
    )
    document = Document(
        name="test",
        description="test",
        kind=Kind.DOCUMENT,
        original_file=file_info,
        parent=mock_project,
    )
    document.save_to_file()
    return document
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@pytest.fixture
def mock_extractor_runner(mock_extractor_config, mock_document):
    """An ``ExtractorRunner`` pairing the single mock config with the single mock document."""
    configs = [mock_extractor_config]
    docs = [mock_document]
    return ExtractorRunner(extractor_configs=configs, documents=docs)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Test with and without concurrency
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_extractor_runner_status_updates(
    mock_extractor_runner, concurrency
):
    """Progress updates arrive in order: an initial ``complete == 0``, then one per job.

    Both ``collect_jobs`` and ``run_job`` are replaced on the runner instance, so
    only the progress-reporting loop of ``run`` is exercised.
    """
    job_count = 50
    # Job objects are not the right type, but since we're mocking run_job, it doesn't matter
    jobs = [{} for _ in range(job_count)]

    # Mock collect_jobs to return our fake jobs
    mock_extractor_runner.collect_jobs = lambda: jobs

    # Mock run_job to return True immediately
    mock_extractor_runner.run_job = AsyncMock(return_value=True)

    # Expect the status updates in order, and 1 for each job
    expected_completed_count = 0
    async for progress in mock_extractor_runner.run(concurrency=concurrency):
        assert progress.complete == expected_completed_count
        expected_completed_count += 1
        assert progress.errors == 0
        assert progress.total == job_count

    # One initial (complete == 0) update plus one update per finished job
    assert expected_completed_count == job_count + 1

    # Verify run_job was called for each job
    assert mock_extractor_runner.run_job.call_count == job_count
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def test_collect_jobs_excludes_already_run_extraction(
    mock_extractor_runner, mock_document, mock_extractor_config
):
    """A document already extracted with the runner's config yields no job.

    Extractions recorded under a different config id do not suppress the job.
    """

    def save_extraction(config_id):
        # Persist a processed extraction of mock_document attributed to config_id.
        Extraction(
            parent=mock_document,
            source=ExtractionSource.PROCESSED,
            extractor_config_id=config_id,
            output=KilnAttachmentModel.from_data("test extraction output", "text/plain"),
        ).save_to_file()

    # An extraction from an unrelated config leaves this config's job pending.
    save_extraction("other-extractor-config-id")
    pending = mock_extractor_runner.collect_jobs()
    assert len(pending) == 1
    only_job = pending[0]
    assert only_job.doc.id == mock_document.id
    assert only_job.extractor_config.id == mock_extractor_config.id

    # After this config has produced an extraction, nothing remains to run.
    save_extraction(mock_extractor_config.id)
    remaining = mock_extractor_runner.collect_jobs()
    assert len(remaining) == 0
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def test_collect_jobs_multiple_extractor_configs(
    mock_extractor_runner,
    mock_document,
    mock_extractor_config,
    mock_project,
):
    """With two configs and one unextracted document, one job is collected per config."""
    prompts = {
        "prompt_document": "Extract the text from the document",
        "prompt_image": "Extract the text from the image",
        "prompt_video": "Extract the text from the video",
        "prompt_audio": "Extract the text from the audio",
    }
    extra_config = ExtractorConfig(
        name="test2",
        description="test2",
        extractor_type=ExtractorType.LITELLM,
        model_provider_name="gemini_api",
        model_name="gemini-2.0-flash",
        output_format=OutputFormat.TEXT,
        passthrough_mimetypes=[],
        properties=prompts,
        parent=mock_project,
    )
    extra_config.save_to_file()

    multi_runner = ExtractorRunner(
        documents=[mock_document],
        extractor_configs=[mock_extractor_config, extra_config],
    )
    collected = multi_runner.collect_jobs()

    assert len(collected) == 2
    seen_config_ids = {job.extractor_config.id for job in collected}
    assert seen_config_ids == {mock_extractor_config.id, extra_config.id}
|