kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic; see the package's registry page for more details.
- kiln_ai/adapters/__init__.py +8 -2
- kiln_ai/adapters/adapter_registry.py +43 -208
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +9 -3
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +2 -4
- kiln_ai/adapters/eval/test_g_eval.py +4 -5
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +761 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +205 -47
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +657 -85
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +251 -1
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +254 -8
- kiln_ai/adapters/test_remote_config.py +651 -58
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +39 -34
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +28 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +40 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +4 -4
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +320 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +105 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +246 -0
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +562 -0
- kiln_ai/tools/tool_registry.py +85 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +24 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +131 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.19.0.dist-info/RECORD +0 -115
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import tempfile
import uuid
from pathlib import Path

import pytest

from kiln_ai.datamodel.basemodel import KilnAttachmentModel
from kiln_ai.datamodel.chunk import Chunk, ChunkedDocument, ChunkerConfig, ChunkerType
from kiln_ai.datamodel.embedding import ChunkEmbeddings, Embedding, EmbeddingConfig
from kiln_ai.datamodel.extraction import (
    Document,
    Extraction,
    ExtractionSource,
    ExtractorConfig,
    ExtractorType,
    FileInfo,
    Kind,
)
from kiln_ai.datamodel.project import Project
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@pytest.fixture
def mock_project(tmp_path):
    """Create a Project in a fresh, uniquely named subdirectory of ``tmp_path`` and save it.

    Returns the saved Project so tests can parent other models to it.
    """
    root_dir = tmp_path / str(uuid.uuid4())
    root_dir.mkdir()

    saved_project = Project(
        name="Test Project",
        description="Test description",
        path=root_dir / "project.kiln",
    )
    saved_project.save_to_file()
    return saved_project
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TestIntegration:
    """Integration tests for the chunk module."""

    def test_full_workflow(self):
        """Build a ChunkerConfig and a ChunkedDocument in memory and check their fields."""
        # Create chunker properties
        properties = {"chunk_size": 256, "chunk_overlap": 10}

        # Create chunker config (in memory only: no parent, never saved)
        config = ChunkerConfig(
            name="test-chunker",
            description="A test chunker configuration",
            chunker_type=ChunkerType.FIXED_WINDOW,
            properties=properties,
        )

        # Write the attachment source inside a temporary directory rather than using
        # NamedTemporaryFile(delete=True): write_bytes() closes (and therefore flushes)
        # the file before it is read, and the path can be reopened on every platform.
        with tempfile.TemporaryDirectory() as tmp_dir:
            attachment_path = Path(tmp_dir) / "test.txt"
            attachment_path.write_bytes(b"test content")

            # Create attachment
            attachment = KilnAttachmentModel.from_file(attachment_path)

            # Create chunks
            chunk1 = Chunk(content=attachment)
            chunk2 = Chunk(content=attachment)

            # Create chunk document
            doc = ChunkedDocument(
                chunks=[chunk1, chunk2],
                chunker_config_id=config.id,
            )

            # Verify the complete structure
            assert config.name == "test-chunker"
            assert config.chunker_type == ChunkerType.FIXED_WINDOW
            assert config.chunk_size() == 256
            assert config.chunk_overlap() == 10
            assert len(doc.chunks) == 2
            assert doc.chunks[0].content == attachment
            assert doc.chunks[1].content == attachment

    def test_serialization(self, mock_project):
        """Test that models can be serialized and deserialized."""
        properties = {"chunk_size": 512, "chunk_overlap": 20}
        config = ChunkerConfig(
            name="serialization-test",
            chunker_type=ChunkerType.FIXED_WINDOW,
            properties=properties,
            parent=mock_project,
        )

        # Save to file
        config.save_to_file()

        # Load from file
        config_restored = ChunkerConfig.load_from_file(config.path)

        assert config_restored.name == config.name
        assert config_restored.chunker_type == config.chunker_type
        assert config_restored.chunk_size() == config.chunk_size()
        assert config_restored.chunk_overlap() == config.chunk_overlap()
        assert config_restored.parent_project().id == mock_project.id

    def test_enum_serialization(self):
        """Test that ChunkerType enum serializes correctly."""
        config = ChunkerConfig(
            name="enum-test",
            chunker_type=ChunkerType.FIXED_WINDOW,
            properties={"chunk_size": 512, "chunk_overlap": 20},
        )

        config_dict = config.model_dump()
        assert config_dict["chunker_type"] == "fixed_window"

        config_restored = ChunkerConfig.model_validate(config_dict)
        assert config_restored.chunker_type == ChunkerType.FIXED_WINDOW

    def test_relationships(self, mock_project):
        """Persist config -> document -> extraction -> chunks -> embeddings and walk the links."""
        # Create a chunker config
        config = ChunkerConfig(
            name="test-chunker",
            chunker_type=ChunkerType.FIXED_WINDOW,
            properties={"chunk_size": 512, "chunk_overlap": 20},
            parent=mock_project,
        )
        config.save_to_file()

        # Extraction.extractor_config_id must reference an extractor config,
        # not the chunker config, so create a real one.
        extractor_config = ExtractorConfig(
            name="test-extractor",
            description="Test extractor config",
            model_provider_name="gemini_api",
            model_name="gemini-2.0-flash",
            extractor_type=ExtractorType.LITELLM,
            properties={
                "prompt_document": "Transcribe the document.",
                "prompt_audio": "Transcribe the audio.",
                "prompt_video": "Transcribe the video.",
                "prompt_image": "Describe the image.",
            },
            parent=mock_project,
        )
        extractor_config.save_to_file()

        # Dummy file used as attachment. A temporary directory (instead of
        # NamedTemporaryFile(delete=True)) guarantees the bytes are on disk and the
        # path can be reopened when the attachments are copied on save.
        with tempfile.TemporaryDirectory() as tmp_dir:
            attachment_path = Path(tmp_dir) / "test.txt"
            attachment_path.write_bytes(b"test content")

            # Create a document
            document = Document(
                name="test-document",
                description="Test document",
                parent=mock_project,
                original_file=FileInfo(
                    filename="test.txt",
                    size=100,
                    mime_type="text/plain",
                    attachment=KilnAttachmentModel.from_file(attachment_path),
                ),
                kind=Kind.DOCUMENT,
            )
            document.save_to_file()

            # Create an extraction
            extraction = Extraction(
                source=ExtractionSource.PROCESSED,
                extractor_config_id=extractor_config.id,
                output=KilnAttachmentModel.from_file(attachment_path),
                parent=document,
            )
            extraction.save_to_file()

            # Create three distinct chunks; `[Chunk(...)] * 3` would alias a single
            # Chunk object three times.
            chunks = [
                Chunk(content=KilnAttachmentModel.from_file(attachment_path))
                for _ in range(3)
            ]

            chunked_document = ChunkedDocument(
                parent=extraction,
                chunks=chunks,
                chunker_config_id=config.id,
            )
            chunked_document.save_to_file()

            assert len(chunked_document.chunks) == 3

            # Check that the chunked document is associated with the correct extraction
            assert chunked_document.parent_extraction().id == extraction.id

            found_chunked_documents = extraction.chunked_documents()
            assert len(found_chunked_documents) == 1
            assert found_chunked_documents[0].id == chunked_document.id

            # The saved chunk attachments should have a filename prefixed with content_
            for chunk in chunked_document.chunks:
                filename = chunk.content.path.name
                assert filename.startswith("content_")

            # Create an embedding config
            embedding_config = EmbeddingConfig(
                name="test-embedding-config",
                description="Test embedding config",
                parent=mock_project,
                model_name="openai_text_embedding_3_small",
                model_provider_name="openai",
                properties={},
            )
            embedding_config.save_to_file()

            # Create chunk embeddings (one vector per chunk)
            chunk_embeddings = ChunkEmbeddings(
                parent=chunked_document,
                embedding_config_id=embedding_config.id,
                embeddings=[Embedding(vector=[1.0] * 1536) for _ in range(3)],
            )
            chunk_embeddings.save_to_file()

            retrieved_chunk_embeddings = chunked_document.chunk_embeddings()
            assert isinstance(retrieved_chunk_embeddings, list)
            assert len(retrieved_chunk_embeddings) == 1
            assert retrieved_chunk_embeddings[0].id == chunk_embeddings.id

            # Check the project has the embedding config and the chunker config
            assert (
                mock_project.embedding_configs(readonly=True)[0].id
                == embedding_config.id
            )
            assert mock_project.chunker_configs(readonly=True)[0].id == config.id
|
|
@@ -0,0 +1,470 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import uuid
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from kiln_ai.datamodel.basemodel import KilnAttachmentModel
|
|
7
|
+
from kiln_ai.datamodel.extraction import (
|
|
8
|
+
Document,
|
|
9
|
+
Extraction,
|
|
10
|
+
ExtractionSource,
|
|
11
|
+
ExtractorConfig,
|
|
12
|
+
ExtractorType,
|
|
13
|
+
FileInfo,
|
|
14
|
+
Kind,
|
|
15
|
+
OutputFormat,
|
|
16
|
+
get_kind_from_mime_type,
|
|
17
|
+
validate_prompt,
|
|
18
|
+
)
|
|
19
|
+
from kiln_ai.datamodel.project import Project
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@pytest.fixture
def valid_extractor_config_data():
    """Keyword arguments for a fully populated LiteLLM ExtractorConfig.

    A new dict is returned on every call, so tests may mutate it freely.
    """
    prompts = {
        "prompt_document": "Transcribe the document.",
        "prompt_audio": "Transcribe the audio.",
        "prompt_video": "Transcribe the video.",
        "prompt_image": "Describe the image.",
    }
    return {
        "name": "Test Extractor Config",
        "description": "Test description",
        "extractor_type": ExtractorType.LITELLM,
        "model_provider_name": "gemini_api",
        "model_name": "gemini-2.0-flash",
        "properties": prompts,
    }
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@pytest.fixture
def valid_extractor_config(valid_extractor_config_data):
    """An unsaved ExtractorConfig built from ``valid_extractor_config_data``."""
    config = ExtractorConfig(**valid_extractor_config_data)
    return config
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def test_extractor_config_kind_coercion(valid_extractor_config):
    """Each per-kind prompt accessor returns the matching prompt from ``properties``."""
    accessor_expectations = [
        (valid_extractor_config.prompt_document, "Transcribe the document."),
        (valid_extractor_config.prompt_audio, "Transcribe the audio."),
        (valid_extractor_config.prompt_video, "Transcribe the video."),
        (valid_extractor_config.prompt_image, "Describe the image."),
    ]
    for accessor, expected_prompt in accessor_expectations:
        assert accessor() == expected_prompt
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def test_extractor_config_description_empty(valid_extractor_config_data):
    """A ``None`` description is accepted and kept as ``None``."""
    data_without_description = {**valid_extractor_config_data, "description": None}
    config = ExtractorConfig(**data_without_description)
    assert config.description is None
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def test_extractor_config_valid(valid_extractor_config):
    """Constructor arguments round-trip, and output_format is MARKDOWN when not given."""
    cfg = valid_extractor_config

    assert cfg.name == "Test Extractor Config"
    assert cfg.description == "Test description"
    assert cfg.extractor_type == ExtractorType.LITELLM
    assert cfg.output_format == OutputFormat.MARKDOWN
    assert cfg.model_provider_name == "gemini_api"
    assert cfg.model_name == "gemini-2.0-flash"

    expected_properties = {
        "prompt_document": "Transcribe the document.",
        "prompt_audio": "Transcribe the audio.",
        "prompt_video": "Transcribe the video.",
        "prompt_image": "Describe the image.",
    }
    for key, expected_value in expected_properties.items():
        assert cfg.properties[key] == expected_value
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def test_extractor_config_missing_model_name(valid_extractor_config):
    """Assigning ``None`` to model_name or model_provider_name is rejected on assignment."""
    for required_field in ("model_name", "model_provider_name"):
        with pytest.raises(ValueError):
            setattr(valid_extractor_config, required_field, None)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def test_extractor_config_missing_prompts(valid_extractor_config):
    """Replacing ``properties`` with an empty dict is rejected.

    The prompts are required, not defaulted, so removing all of them must fail
    validation with a ValueError.
    """
    with pytest.raises(ValueError):
        valid_extractor_config.properties = {}
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def test_extractor_config_invalid_json(
    valid_extractor_config, valid_extractor_config_data
):
    """An arbitrary Python object inside ``properties`` is rejected on assignment.

    The test expects the error text to report multiple ("validation errors")
    failures for ExtractorConfig.
    """

    class InvalidClass:
        pass

    # Start from the valid prompts and append one value of an arbitrary class.
    bad_properties = dict(valid_extractor_config_data["properties"])
    bad_properties["invalid_key"] = InvalidClass()

    with pytest.raises(ValueError, match="validation errors for ExtractorConfig"):
        valid_extractor_config.properties = bad_properties
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def test_extractor_config_invalid_prompt(valid_extractor_config):
    """A non-string prompt is rejected with an error naming the offending key."""
    properties_with_int_prompt = {
        "prompt_document": 123,
        "prompt_audio": "Transcribe the audio.",
        "prompt_video": "Transcribe the video.",
        "prompt_image": "Describe the image.",
    }
    with pytest.raises(ValueError, match="prompt_document must be a string"):
        valid_extractor_config.properties = properties_with_int_prompt
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def test_extractor_config_missing_single_prompt(valid_extractor_config):
    """Omitting just one of the four prompts (here ``prompt_image``) is rejected."""
    properties_missing_image = {
        "prompt_document": "Transcribe the document.",
        "prompt_audio": "Transcribe the audio.",
        "prompt_video": "Transcribe the video.",
    }
    with pytest.raises(ValueError):
        valid_extractor_config.properties = properties_missing_image
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def test_extractor_config_invalid_config_type(valid_extractor_config):
    """Assigning an unknown extractor_type string is rejected."""
    unknown_type = "invalid_type"
    with pytest.raises(ValueError):
        valid_extractor_config.extractor_type = unknown_type
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@pytest.mark.parametrize(
    "passthrough_mimetypes",
    [
        [OutputFormat.TEXT],
        [OutputFormat.MARKDOWN],
        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
    ],
)
def test_valid_passthrough_mimetypes(
    valid_extractor_config_data, passthrough_mimetypes
):
    """Lists of OutputFormat members are accepted and stored as passthrough_mimetypes."""
    config = ExtractorConfig(
        **{
            **valid_extractor_config_data,
            "passthrough_mimetypes": passthrough_mimetypes,
        }
    )
    assert config.passthrough_mimetypes == passthrough_mimetypes
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
@pytest.mark.parametrize(
    "passthrough_mimetypes",
    [
        ["invalid_format"],
        ["another_invalid"],
        [OutputFormat.TEXT, "invalid_format"],
    ],
)
def test_invalid_passthrough_mimetypes(
    valid_extractor_config_data, passthrough_mimetypes
):
    """Any entry that is not an OutputFormat value makes construction fail."""
    bad_data = {
        **valid_extractor_config_data,
        "passthrough_mimetypes": passthrough_mimetypes,
    }
    with pytest.raises(ValueError):
        ExtractorConfig(**bad_data)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def test_validate_prompt_valid():
    """A non-empty string prompt passes validate_prompt without raising."""
    prompt_text, field_name = "Transcribe the document.", "prompt_document"
    validate_prompt(prompt_text, field_name)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
@pytest.mark.parametrize(
    ("value", "expected_error"),
    [
        (123, "prompt_document must be a string"),
        ("", "prompt_document cannot be empty"),
    ],
)
def test_validate_prompt_errors(value, expected_error):
    """Non-string and empty prompts raise ValueError with a message naming the field."""
    field_name = "prompt_document"
    with pytest.raises(ValueError, match=expected_error):
        validate_prompt(value, field_name)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
@pytest.fixture
def mock_project(tmp_path):
    """A saved Project stored under a unique directory inside ``tmp_path``."""
    project_dir = tmp_path / str(uuid.uuid4())
    project_dir.mkdir()

    project_file = project_dir / "project.kiln"
    result = Project(
        name="Test Project",
        description="Test description",
        path=project_file,
    )
    result.save_to_file()
    return result
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@pytest.fixture
def mock_extractor_config_factory(mock_project):
    """Return a factory that builds and saves a uniquely named ExtractorConfig.

    Each call creates a LiteLLM extractor config parented to ``mock_project``.
    """

    def _create_mock_extractor_config():
        config = ExtractorConfig(
            name=f"Test Extractor Config {uuid.uuid4()!s}",
            description="Test description",
            model_provider_name="gemini_api",
            model_name="gemini-2.0-flash",
            extractor_type=ExtractorType.LITELLM,
            properties={
                "prompt_document": "Transcribe the document.",
                "prompt_audio": "Transcribe the audio.",
                "prompt_video": "Transcribe the video.",
                "prompt_image": "Describe the image.",
            },
            parent=mock_project,
        )
        config.save_to_file()
        return config

    return _create_mock_extractor_config
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@pytest.fixture
def mock_attachment_factory(tmp_path):
    """Return a factory that writes a small uniquely named text file and wraps it.

    Each call writes the text ``"test"`` to ``tmp_path/test_<uuid>.txt`` and
    returns ``KilnAttachmentModel.from_file`` for that path.
    """

    def _create_mock_attachment():
        file_path = tmp_path / f"test_{uuid.uuid4()!s}.txt"
        file_path.write_text("test")
        return KilnAttachmentModel.from_file(file_path)

    return _create_mock_attachment
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
@pytest.fixture
def mock_document_factory(mock_project, mock_attachment_factory):
    """Return a factory that builds and saves a uniquely named plain-text Document.

    Each document is parented to ``mock_project`` and gets a fresh attachment.
    """

    def _create_mock_document():
        doc_name = f"Test Document {uuid.uuid4()!s}"
        doc_description = f"Test description {uuid.uuid4()!s}"
        file_info = FileInfo(
            filename=f"test_{doc_name}.txt",
            size=100,
            mime_type="text/plain",
            attachment=mock_attachment_factory(),
        )
        document = Document(
            name=doc_name,
            description=doc_description,
            kind=Kind.DOCUMENT,
            original_file=file_info,
            parent=mock_project,
        )
        document.save_to_file()
        return document

    return _create_mock_document
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def test_relationships(
    mock_project,
    mock_extractor_config_factory,
    mock_document_factory,
    mock_attachment_factory,
):
    """Extractor configs and documents belong to the project; extractions belong to documents."""
    # Create extractor configs
    assert len(mock_project.extractor_configs()) == 0

    created_configs = [mock_extractor_config_factory() for _ in range(3)]

    # Check the project returns exactly the configs we created (not just three of something)
    extractor_configs = mock_project.extractor_configs()
    assert len(extractor_configs) == 3
    assert {c.id for c in extractor_configs} == {c.id for c in created_configs}

    # Check extractor configs are associated with the correct project
    for extractor_config in extractor_configs:
        assert extractor_config.parent_project().id == mock_project.id

    # Create documents
    assert len(mock_project.documents()) == 0

    created_documents = [mock_document_factory() for _ in range(5)]

    # Check the project returns exactly the documents we created
    documents = mock_project.documents()
    assert len(documents) == 5
    assert {d.id for d in documents} == {d.id for d in created_documents}

    # Check documents are associated with the correct project
    for document in documents:
        assert document.parent_project().id == mock_project.id

    # Create one extraction per extractor config for the first 3 documents only
    documents_with_extractions = documents[:3]
    documents_without_extractions = documents[3:]
    for document in documents_with_extractions:
        for extractor_config in extractor_configs:
            extraction = Extraction(
                source=ExtractionSource.PROCESSED,
                extractor_config_id=extractor_config.id,
                output=mock_attachment_factory(),
                parent=document,
            )
            extraction.save_to_file()

    # Check extractions are associated with the correct document
    for document in documents_with_extractions:
        document_extractions = document.extractions()
        assert len(document_extractions) == 3
        for extraction in document_extractions:
            assert extraction.parent_document().id == document.id

    # Check no extractions for the last 2 documents
    for document in documents_without_extractions:
        assert len(document.extractions()) == 0

    # Check we can retrieve a document by id
    document_0 = Document.from_id_and_parent_path(documents[0].id, mock_project.path)
    assert document_0 is not None
    assert document_0.parent_project().id == mock_project.id
    assert document_0.id == documents[0].id

    # Check we can retrieve extractions for that document
    document_0_extractions = document_0.extractions()
    assert document_0_extractions is not None
    assert len(document_0_extractions) == 3
    for extraction in document_0_extractions:
        assert extraction.parent_document().id == document_0.id

    # Check all documents can be retrieved, in the same order, via both APIs
    all_documents = Document.all_children_of_parent_path(mock_project.path)
    assert (
        [d.id for d in all_documents]
        == [d.id for d in mock_project.documents()]
        == [d.id for d in documents]
    )

    for document_retrieved, document_original in zip(all_documents, documents):
        assert document_retrieved.parent_project().id == mock_project.id
        assert document_retrieved.id == document_original.id
|
340
|
+
|
|
341
|
+
|
|
342
|
+
@pytest.fixture
def valid_document(mock_document_factory):
    """A single saved Document produced by ``mock_document_factory``."""
    document = mock_document_factory()
    return document
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
@pytest.mark.parametrize(
    ("tags", "expected_tags"),
    [
        (["test", "document"], ["test", "document"]),
        (["test", "document", "new"], ["test", "document", "new"]),
        ([], []),
    ],
)
def test_document_tags(valid_document, tags, expected_tags):
    """Valid tag lists, including an empty list, are stored as assigned."""
    doc = valid_document
    doc.tags = tags
    assert doc.tags == expected_tags
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def test_invalid_tags(valid_document):
    """Empty-string tags and tags containing spaces are rejected on assignment."""
    with pytest.raises(ValueError, match="Tags cannot be empty strings"):
        valid_document.tags = ["test", ""]

    # `match` is a regular expression (re.search); escape the literal periods so
    # they do not match arbitrary characters.
    with pytest.raises(
        ValueError, match=r"Tags cannot contain spaces\. Try underscores\."
    ):
        valid_document.tags = ["test", "document new"]
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
@pytest.mark.parametrize(
    "filename, mime_type",
    [
        ("file.pdf", "application/pdf"),
        ("file.txt", "text/plain"),
        ("file.md", "text/markdown"),
        ("file.html", "text/html"),
        ("file.png", "image/png"),
        ("file.jpeg", "image/jpeg"),
        ("file.mp4", "video/mp4"),
        ("file.mov", "video/quicktime"),
        ("file.wav", "audio/wav"),
        ("file.mp3", "audio/mpeg"),
        ("file.ogg", "audio/ogg"),
    ],
)
def test_document_valid_mime_type(
    mock_project, mock_attachment_factory, filename, mime_type
):
    """A Document accepts each supported MIME type and keeps it on original_file."""
    # Derive the kind from the MIME type so image/video/audio files are not
    # mislabelled as Kind.DOCUMENT in the test data.
    kind = get_kind_from_mime_type(mime_type)
    document = Document(
        name="Test Document",
        description="Test description",
        kind=kind,
        original_file=FileInfo(
            filename=filename,
            size=100,
            mime_type=mime_type,
            attachment=mock_attachment_factory(),
        ),
        parent=mock_project,
    )
    assert document.original_file.mime_type == mime_type
    assert document.kind == kind
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
@pytest.mark.parametrize(
    "filename, mime_type",
    [
        # these are a handful of mime types not currently supported by the extractors
        (
            "file.pptx",
            "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        ),
        (
            "file.docx",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ),
        (
            "file.xlsx",
            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        ),
        ("file.svg", "image/svg+xml"),
        ("file.avi", "video/x-msvideo"),
        ("file.csv", "text/csv"),
    ],
)
def test_document_invalid_mime_type(
    mock_project, mock_attachment_factory, filename, mime_type
):
    """Building a Document with an unsupported MIME type raises ValueError naming it."""
    # The MIME type is escaped because `match` treats its argument as a regex
    # and these types contain metacharacters such as "+" and ".".
    expected_message = f"MIME type is not supported: {re.escape(mime_type)}"

    # FileInfo is built inside the raises-block because it is not visible here
    # whether FileInfo or Document performs the MIME validation.
    with pytest.raises(ValueError, match=expected_message):
        Document(
            name="Test Document",
            description="Test description",
            kind=Kind.DOCUMENT,
            original_file=FileInfo(
                filename=filename,
                size=100,
                mime_type=mime_type,
                attachment=mock_attachment_factory(),
            ),
            parent=mock_project,
        )
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
@pytest.mark.parametrize(
    ("mime_type", "expected_kind"),
    [
        ("application/pdf", Kind.DOCUMENT),
        ("text/plain", Kind.DOCUMENT),
        ("text/markdown", Kind.DOCUMENT),
        ("text/html", Kind.DOCUMENT),
        ("image/png", Kind.IMAGE),
        ("image/jpeg", Kind.IMAGE),
        ("video/mp4", Kind.VIDEO),
        ("video/quicktime", Kind.VIDEO),
        ("audio/mpeg", Kind.AUDIO),
        ("audio/wav", Kind.AUDIO),
        ("audio/ogg", Kind.AUDIO),
    ],
)
def test_get_kind_from_mime_type(mime_type, expected_kind):
    """Each supported MIME type maps to the expected Kind."""
    actual_kind = get_kind_from_mime_type(mime_type)
    assert actual_kind == expected_kind
|