kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (133) hide show
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +494 -0
  38. kiln_ai/adapters/ml_model_list.py +876 -18
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
  41. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  42. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  43. kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
  44. kiln_ai/adapters/ollama_tools.py +69 -12
  45. kiln_ai/adapters/provider_tools.py +190 -46
  46. kiln_ai/adapters/rag/deduplication.py +49 -0
  47. kiln_ai/adapters/rag/progress.py +252 -0
  48. kiln_ai/adapters/rag/rag_runners.py +844 -0
  49. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  50. kiln_ai/adapters/rag/test_progress.py +785 -0
  51. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  52. kiln_ai/adapters/remote_config.py +80 -8
  53. kiln_ai/adapters/test_adapter_registry.py +579 -86
  54. kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
  55. kiln_ai/adapters/test_ml_model_list.py +202 -0
  56. kiln_ai/adapters/test_ollama_tools.py +340 -1
  57. kiln_ai/adapters/test_prompt_builders.py +1 -1
  58. kiln_ai/adapters/test_provider_tools.py +199 -8
  59. kiln_ai/adapters/test_remote_config.py +551 -56
  60. kiln_ai/adapters/vector_store/__init__.py +1 -0
  61. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  62. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  63. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  64. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  65. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  66. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  67. kiln_ai/datamodel/__init__.py +16 -13
  68. kiln_ai/datamodel/basemodel.py +201 -4
  69. kiln_ai/datamodel/chunk.py +158 -0
  70. kiln_ai/datamodel/datamodel_enums.py +27 -0
  71. kiln_ai/datamodel/embedding.py +64 -0
  72. kiln_ai/datamodel/external_tool_server.py +206 -54
  73. kiln_ai/datamodel/extraction.py +317 -0
  74. kiln_ai/datamodel/project.py +33 -1
  75. kiln_ai/datamodel/rag.py +79 -0
  76. kiln_ai/datamodel/task.py +5 -0
  77. kiln_ai/datamodel/task_output.py +41 -11
  78. kiln_ai/datamodel/test_attachment.py +649 -0
  79. kiln_ai/datamodel/test_basemodel.py +270 -14
  80. kiln_ai/datamodel/test_chunk_models.py +317 -0
  81. kiln_ai/datamodel/test_dataset_split.py +1 -1
  82. kiln_ai/datamodel/test_datasource.py +50 -0
  83. kiln_ai/datamodel/test_embedding_models.py +448 -0
  84. kiln_ai/datamodel/test_eval_model.py +6 -6
  85. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  86. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  87. kiln_ai/datamodel/test_extraction_model.py +501 -0
  88. kiln_ai/datamodel/test_rag.py +641 -0
  89. kiln_ai/datamodel/test_task.py +35 -1
  90. kiln_ai/datamodel/test_tool_id.py +187 -1
  91. kiln_ai/datamodel/test_vector_store.py +320 -0
  92. kiln_ai/datamodel/tool_id.py +58 -0
  93. kiln_ai/datamodel/vector_store.py +141 -0
  94. kiln_ai/tools/base_tool.py +12 -3
  95. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  96. kiln_ai/tools/kiln_task_tool.py +158 -0
  97. kiln_ai/tools/mcp_server_tool.py +2 -2
  98. kiln_ai/tools/mcp_session_manager.py +51 -22
  99. kiln_ai/tools/rag_tools.py +164 -0
  100. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  101. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  102. kiln_ai/tools/test_mcp_session_manager.py +187 -227
  103. kiln_ai/tools/test_rag_tools.py +929 -0
  104. kiln_ai/tools/test_tool_registry.py +290 -7
  105. kiln_ai/tools/tool_registry.py +69 -16
  106. kiln_ai/utils/__init__.py +3 -0
  107. kiln_ai/utils/async_job_runner.py +62 -17
  108. kiln_ai/utils/config.py +2 -2
  109. kiln_ai/utils/env.py +15 -0
  110. kiln_ai/utils/filesystem.py +14 -0
  111. kiln_ai/utils/filesystem_cache.py +60 -0
  112. kiln_ai/utils/litellm.py +94 -0
  113. kiln_ai/utils/lock.py +100 -0
  114. kiln_ai/utils/mime_type.py +38 -0
  115. kiln_ai/utils/open_ai_types.py +19 -2
  116. kiln_ai/utils/pdf_utils.py +59 -0
  117. kiln_ai/utils/test_async_job_runner.py +151 -35
  118. kiln_ai/utils/test_env.py +142 -0
  119. kiln_ai/utils/test_filesystem_cache.py +316 -0
  120. kiln_ai/utils/test_litellm.py +206 -0
  121. kiln_ai/utils/test_lock.py +185 -0
  122. kiln_ai/utils/test_mime_type.py +66 -0
  123. kiln_ai/utils/test_open_ai_types.py +88 -12
  124. kiln_ai/utils/test_pdf_utils.py +86 -0
  125. kiln_ai/utils/test_uuid.py +111 -0
  126. kiln_ai/utils/test_validation.py +524 -0
  127. kiln_ai/utils/uuid.py +9 -0
  128. kiln_ai/utils/validation.py +90 -0
  129. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
  130. kiln_ai-0.22.0.dist-info/RECORD +213 -0
  131. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  132. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  133. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,206 @@
1
+ import tempfile
2
+ import uuid
3
+ from pathlib import Path
4
+
5
+ import pytest
6
+
7
+ from kiln_ai.datamodel.basemodel import KilnAttachmentModel
8
+ from kiln_ai.datamodel.chunk import Chunk, ChunkedDocument, ChunkerConfig, ChunkerType
9
+ from kiln_ai.datamodel.embedding import ChunkEmbeddings, Embedding, EmbeddingConfig
10
+ from kiln_ai.datamodel.extraction import (
11
+ Document,
12
+ Extraction,
13
+ ExtractionSource,
14
+ FileInfo,
15
+ Kind,
16
+ )
17
+ from kiln_ai.datamodel.project import Project
18
+
19
+
20
@pytest.fixture
def mock_project(tmp_path):
    """Return a Project saved to disk inside a uniquely named folder of tmp_path."""
    root_dir = tmp_path / str(uuid.uuid4())
    root_dir.mkdir()

    saved_project = Project(
        name="Test Project",
        description="Test description",
        path=root_dir / "project.kiln",
    )
    saved_project.save_to_file()
    return saved_project
31
+
32
+
33
class TestIntegration:
    """Integration tests for the chunk module."""

    def test_full_workflow(self):
        """Test a complete in-memory workflow with config, chunks and chunked document."""
        # Create chunker properties
        properties = {"chunk_size": 256, "chunk_overlap": 10}

        # Create chunker config
        config = ChunkerConfig(
            name="test-chunker",
            description="A test chunker configuration",
            chunker_type=ChunkerType.FIXED_WINDOW,
            properties=properties,
        )

        # Create a temporary file for the attachment. Everything that uses the
        # attachment stays inside the `with` so the file exists throughout.
        with tempfile.NamedTemporaryFile(delete=True) as tmp_file:
            tmp_file.write(b"test content")
            # Flush so the bytes are actually on disk for anything that reads
            # the file through its path rather than through this handle.
            tmp_file.flush()
            tmp_path = Path(tmp_file.name)

            # Create attachment
            attachment = KilnAttachmentModel.from_file(tmp_path)

            # Create chunks
            chunk1 = Chunk(content=attachment)
            chunk2 = Chunk(content=attachment)

            # Create chunk document
            doc = ChunkedDocument(
                chunks=[chunk1, chunk2],
                chunker_config_id=config.id,
            )

            # Verify the complete structure
            assert config.name == "test-chunker"
            assert config.chunker_type == ChunkerType.FIXED_WINDOW
            assert config.chunk_size() == 256
            assert config.chunk_overlap() == 10
            assert len(doc.chunks) == 2
            assert doc.chunks[0].content == attachment
            assert doc.chunks[1].content == attachment

    def test_serialization(self, mock_project):
        """Test that models can be serialized and deserialized."""
        properties = {"chunk_size": 512, "chunk_overlap": 20}
        config = ChunkerConfig(
            name="serialization-test",
            chunker_type=ChunkerType.FIXED_WINDOW,
            properties=properties,
            parent=mock_project,
        )

        # Save to file
        config.save_to_file()

        # Load from file
        config_restored = ChunkerConfig.load_from_file(config.path)

        assert config_restored.name == config.name
        assert config_restored.chunker_type == config.chunker_type
        assert config_restored.chunk_size() == config.chunk_size()
        assert config_restored.chunk_overlap() == config.chunk_overlap()
        assert config_restored.parent_project().id == mock_project.id

    def test_enum_serialization(self):
        """Test that ChunkerType enum serializes correctly."""
        config = ChunkerConfig(
            name="enum-test",
            chunker_type=ChunkerType.FIXED_WINDOW,
            properties={"chunk_size": 512, "chunk_overlap": 20},
        )

        config_dict = config.model_dump()
        assert config_dict["chunker_type"] == "fixed_window"

        config_restored = ChunkerConfig.model_validate(config_dict)
        assert config_restored.chunker_type == ChunkerType.FIXED_WINDOW

    def test_relationships(self, mock_project):
        """Test that relationships are properly validated."""

        # Create a config
        config = ChunkerConfig(
            name="test-chunker",
            chunker_type=ChunkerType.FIXED_WINDOW,
            properties={"chunk_size": 512, "chunk_overlap": 20},
            parent=mock_project,
        )
        config.save_to_file()

        # Dummy file we will use as attachment. The rest of the test runs
        # inside the `with` so the file exists whenever an attachment built
        # from it is created or saved.
        with tempfile.NamedTemporaryFile(delete=True) as tmp_file:
            tmp_file.write(b"test content")
            # Flush so the bytes are on disk before the path is handed to
            # the attachments below.
            tmp_file.flush()
            tmp_path = Path(tmp_file.name)

            # Create a document
            document = Document(
                name="test-document",
                description="Test document",
                parent=mock_project,
                original_file=FileInfo(
                    filename="test.txt",
                    size=100,
                    mime_type="text/plain",
                    attachment=KilnAttachmentModel.from_file(tmp_path),
                ),
                kind=Kind.DOCUMENT,
            )
            document.save_to_file()

            # Create an extraction
            extraction = Extraction(
                source=ExtractionSource.PROCESSED,
                extractor_config_id=config.id,
                output=KilnAttachmentModel.from_file(tmp_path),
                parent=document,
            )
            extraction.save_to_file()

            # Create some chunks (the same Chunk instance repeated three times)
            chunks = [Chunk(content=KilnAttachmentModel.from_file(tmp_path))] * 3

            chunked_document = ChunkedDocument(
                parent=extraction,
                chunks=chunks,
                chunker_config_id=config.id,
            )
            chunked_document.save_to_file()

            assert len(chunked_document.chunks) == 3

            # Check that the document chunked is associated with the correct extraction
            assert chunked_document.parent_extraction().id == extraction.id

            for chunked_document_found in extraction.chunked_documents():
                assert chunked_document.id == chunked_document_found.id

            assert len(extraction.chunked_documents()) == 1

            # the chunks should have a filename prefixed with content_
            for chunk in chunked_document.chunks:
                filename = chunk.content.path.name
                assert filename.startswith("content_")

            # create an embedding config
            embedding_config = EmbeddingConfig(
                name="test-embedding-config",
                description="Test embedding config",
                parent=mock_project,
                model_name="openai_text_embedding_3_small",
                model_provider_name="openai",
                properties={},
            )
            embedding_config.save_to_file()

            # create chunk embeddings
            chunk_embeddings = ChunkEmbeddings(
                parent=chunked_document,
                embedding_config_id=embedding_config.id,
                embeddings=[Embedding(vector=[1.0] * 1536) for _ in range(3)],
            )
            chunk_embeddings.save_to_file()

            retrieved_chunk_embeddings = chunked_document.chunk_embeddings()
            assert isinstance(retrieved_chunk_embeddings, list)
            assert len(retrieved_chunk_embeddings) == 1

            # check project has the embedding config and the chunker config
            assert (
                mock_project.embedding_configs(readonly=True)[0].id
                == embedding_config.id
            )
            assert mock_project.chunker_configs(readonly=True)[0].id == config.id
@@ -0,0 +1,501 @@
1
+ import re
2
+ import uuid
3
+
4
+ import pytest
5
+
6
+ from kiln_ai.datamodel.basemodel import KilnAttachmentModel
7
+ from kiln_ai.datamodel.extraction import (
8
+ Document,
9
+ Extraction,
10
+ ExtractionSource,
11
+ ExtractorConfig,
12
+ ExtractorType,
13
+ FileInfo,
14
+ Kind,
15
+ OutputFormat,
16
+ get_kind_from_mime_type,
17
+ validate_prompt,
18
+ )
19
+ from kiln_ai.datamodel.project import Project
20
+
21
+
22
@pytest.fixture
def valid_extractor_config_data():
    """Keyword arguments for an ExtractorConfig with one prompt per kind."""
    prompts = {
        "prompt_document": "Transcribe the document.",
        "prompt_audio": "Transcribe the audio.",
        "prompt_video": "Transcribe the video.",
        "prompt_image": "Describe the image.",
    }
    return {
        "name": "Test Extractor Config",
        "description": "Test description",
        "extractor_type": ExtractorType.LITELLM,
        "model_provider_name": "gemini_api",
        "model_name": "gemini-2.0-flash",
        "properties": prompts,
    }
37
+
38
+
39
@pytest.fixture
def valid_extractor_config(valid_extractor_config_data):
    """ExtractorConfig built from the valid_extractor_config_data fixture."""
    config = ExtractorConfig(**valid_extractor_config_data)
    return config
42
+
43
+
44
def test_extractor_config_kind_coercion(valid_extractor_config):
    """Each prompt accessor returns the prompt stored in the config's properties."""
    accessors_and_expected = [
        (valid_extractor_config.prompt_document, "Transcribe the document."),
        (valid_extractor_config.prompt_audio, "Transcribe the audio."),
        (valid_extractor_config.prompt_video, "Transcribe the video."),
        (valid_extractor_config.prompt_image, "Describe the image."),
    ]
    for get_prompt, expected_prompt in accessors_and_expected:
        assert get_prompt() == expected_prompt
49
+
50
+
51
def test_extractor_config_description_empty(valid_extractor_config_data):
    """A None description is accepted when constructing an ExtractorConfig."""
    valid_extractor_config_data["description"] = None

    # construction must succeed without raising
    config = ExtractorConfig(**valid_extractor_config_data)

    assert config.description is None
56
+
57
+
58
def test_extractor_config_valid(valid_extractor_config):
    """A config built from valid data exposes the values it was created with."""
    config = valid_extractor_config

    assert config.name == "Test Extractor Config"
    assert config.description == "Test description"
    assert config.extractor_type == ExtractorType.LITELLM
    assert config.output_format == OutputFormat.MARKDOWN
    assert config.model_provider_name == "gemini_api"
    assert config.model_name == "gemini-2.0-flash"

    expected_prompts = [
        ("prompt_document", "Transcribe the document."),
        ("prompt_audio", "Transcribe the audio."),
        ("prompt_video", "Transcribe the video."),
        ("prompt_image", "Describe the image."),
    ]
    for prompt_key, prompt_text in expected_prompts:
        assert config.properties[prompt_key] == prompt_text
72
+
73
+
74
def test_extractor_config_missing_model_name(valid_extractor_config):
    """Assigning None to model_name or model_provider_name raises ValueError."""
    for required_field in ("model_name", "model_provider_name"):
        with pytest.raises(ValueError):
            setattr(valid_extractor_config, required_field, None)
80
+
81
+
82
def test_extractor_config_missing_prompts(valid_extractor_config):
    """Assigning empty properties (no prompts at all) raises ValueError."""
    # should raise an error - the prompts are required properties
    with pytest.raises(ValueError):
        valid_extractor_config.properties = {}
86
+
87
+
88
def test_extractor_config_invalid_json(
    valid_extractor_config, valid_extractor_config_data
):
    """Properties holding an arbitrary class instance fail validation."""

    class InvalidClass:
        pass

    valid_prompts = valid_extractor_config_data["properties"]
    bad_properties = {
        "prompt_document": valid_prompts["prompt_document"],
        "prompt_audio": valid_prompts["prompt_audio"],
        "prompt_video": valid_prompts["prompt_video"],
        "prompt_image": valid_prompts["prompt_image"],
        "invalid_key": InvalidClass(),
    }

    with pytest.raises(ValueError, match="validation errors for ExtractorConfig"):
        valid_extractor_config.properties = bad_properties
104
+
105
+
106
def test_extractor_config_invalid_prompt(valid_extractor_config):
    """A non-string prompt_document is rejected with a descriptive error."""
    properties_with_int_prompt = {
        "prompt_document": 123,
        "prompt_audio": "Transcribe the audio.",
        "prompt_video": "Transcribe the video.",
        "prompt_image": "Describe the image.",
    }

    with pytest.raises(ValueError, match="prompt_document must be a string"):
        valid_extractor_config.properties = properties_with_int_prompt
114
+
115
+
116
def test_extractor_config_missing_single_prompt(valid_extractor_config):
    """Properties lacking just the image prompt raise ValueError."""
    # every prompt except prompt_image
    properties_without_image = {
        "prompt_document": "Transcribe the document.",
        "prompt_audio": "Transcribe the audio.",
        "prompt_video": "Transcribe the video.",
    }

    with pytest.raises(ValueError):
        valid_extractor_config.properties = properties_without_image
124
+
125
+
126
def test_extractor_config_invalid_config_type(valid_extractor_config):
    """Assigning a string that is not a known extractor type raises ValueError."""
    unknown_type = "invalid_type"

    with pytest.raises(ValueError):
        valid_extractor_config.extractor_type = unknown_type
130
+
131
+
132
@pytest.mark.parametrize(
    "passthrough_mimetypes",
    [
        [OutputFormat.TEXT],
        [OutputFormat.MARKDOWN],
        [OutputFormat.TEXT, OutputFormat.MARKDOWN],
    ],
)
def test_valid_passthrough_mimetypes(
    valid_extractor_config_data, passthrough_mimetypes
):
    """Lists of OutputFormat members are accepted and stored unchanged."""
    config_kwargs = {
        **valid_extractor_config_data,
        "passthrough_mimetypes": passthrough_mimetypes,
    }

    built_config = ExtractorConfig(**config_kwargs)

    assert built_config.passthrough_mimetypes == passthrough_mimetypes
147
+
148
+
149
@pytest.mark.parametrize(
    "passthrough_mimetypes",
    [
        ["invalid_format"],
        ["another_invalid"],
        [OutputFormat.TEXT, "invalid_format"],
    ],
)
def test_invalid_passthrough_mimetypes(
    valid_extractor_config_data, passthrough_mimetypes
):
    """Any entry that is not a valid output format makes construction fail."""
    config_kwargs = {
        **valid_extractor_config_data,
        "passthrough_mimetypes": passthrough_mimetypes,
    }

    with pytest.raises(ValueError):
        ExtractorConfig(**config_kwargs)
164
+
165
+
166
def test_validate_prompt_valid():
    """validate_prompt accepts a non-empty string without raising."""
    prompt_text = "Transcribe the document."
    field_name = "prompt_document"

    # passing means no exception escapes this call
    validate_prompt(prompt_text, field_name)
169
+
170
+
171
@pytest.mark.parametrize(
    ("value", "expected_error"),
    [
        (123, "prompt_document must be a string"),
        ("", "prompt_document cannot be empty"),
    ],
)
def test_validate_prompt_errors(value, expected_error):
    """validate_prompt rejects non-string and empty values with a specific message."""
    field_name = "prompt_document"

    with pytest.raises(ValueError, match=expected_error):
        validate_prompt(value, field_name)
181
+
182
+
183
@pytest.fixture
def mock_project(tmp_path):
    """Create and save a Project in its own randomly named directory under tmp_path."""
    project_dir = tmp_path / str(uuid.uuid4())
    project_dir.mkdir()

    project_file = project_dir / "project.kiln"
    new_project = Project(
        name="Test Project",
        description="Test description",
        path=project_file,
    )
    new_project.save_to_file()
    return new_project
194
+
195
+
196
@pytest.fixture
def mock_extractor_config_factory(mock_project):
    """Return a callable that creates and saves a uniquely named ExtractorConfig."""

    def _create_mock_extractor_config():
        prompts = {
            "prompt_document": "Transcribe the document.",
            "prompt_audio": "Transcribe the audio.",
            "prompt_video": "Transcribe the video.",
            "prompt_image": "Describe the image.",
        }
        new_config = ExtractorConfig(
            name=f"Test Extractor Config {uuid.uuid4()!s}",
            description="Test description",
            model_provider_name="gemini_api",
            model_name="gemini-2.0-flash",
            extractor_type=ExtractorType.LITELLM,
            properties=prompts,
            parent=mock_project,
        )
        new_config.save_to_file()
        return new_config

    return _create_mock_extractor_config
218
+
219
+
220
@pytest.fixture
def mock_attachment_factory(tmp_path):
    """Return a callable that writes a small text file and wraps it as an attachment."""

    def _create_mock_attachment():
        file_path = tmp_path / f"test_{uuid.uuid4()!s}.txt"
        file_path.write_text("test")
        return KilnAttachmentModel.from_file(file_path)

    return _create_mock_attachment
229
+
230
+
231
@pytest.fixture
def mock_document_factory(mock_project, mock_attachment_factory):
    """Return a callable that creates and saves a uniquely named text Document."""

    def _create_mock_document():
        doc_name = f"Test Document {uuid.uuid4()!s}"
        file_info = FileInfo(
            filename=f"test_{doc_name}.txt",
            size=100,
            mime_type="text/plain",
            attachment=mock_attachment_factory(),
        )
        new_document = Document(
            name=doc_name,
            description=f"Test description {uuid.uuid4()!s}",
            kind=Kind.DOCUMENT,
            original_file=file_info,
            parent=mock_project,
        )
        new_document.save_to_file()
        return new_document

    return _create_mock_document
251
+
252
+
253
def test_relationships(
    mock_project,
    mock_extractor_config_factory,
    mock_document_factory,
    mock_attachment_factory,
):
    """Check project / extractor config / document / extraction relationships.

    The objects created through the factories are kept separately from the
    ones loaded back through the project, so the loaded ids can be checked
    against what was actually created.
    """
    # a fresh project has no extractor configs
    initial_extractor_configs = mock_project.extractor_configs()
    assert len(initial_extractor_configs) == 0

    # create extractor configs
    created_extractor_configs = [mock_extractor_config_factory() for _ in range(3)]

    # check can get extractor configs from project
    extractor_configs = mock_project.extractor_configs()
    assert len(extractor_configs) == 3
    assert {c.id for c in extractor_configs} == {
        c.id for c in created_extractor_configs
    }

    # check extractor configs are associated with the correct project
    for extractor_config in extractor_configs:
        assert extractor_config.parent_project().id == mock_project.id

    # a fresh project has no documents
    documents = mock_project.documents()
    assert len(documents) == 0

    created_documents = [mock_document_factory() for _ in range(5)]

    # check can get documents from project
    documents = mock_project.documents()
    assert len(documents) == 5
    # compare as sets: the listing order is not assumed to match creation order
    assert {d.id for d in documents} == {d.id for d in created_documents}

    # check documents are associated with the correct project
    for document in documents:
        assert document.parent_project().id == mock_project.id

    # create extractions for the first 3 documents
    for document in [documents[0], documents[1], documents[2]]:
        for extractor_config in extractor_configs:
            extraction = Extraction(
                source=ExtractionSource.PROCESSED,
                extractor_config_id=extractor_config.id,
                output=mock_attachment_factory(),
                parent=document,
            )
            extraction.save_to_file()

    # check extractions are associated with the correct document
    for document in [documents[0], documents[1], documents[2]]:
        assert len(document.extractions()) == 3
        for extraction in document.extractions():
            assert extraction.parent_document().id == document.id

    # check no extractions for the last 2 documents
    for document in [documents[3], documents[4]]:
        assert len(document.extractions()) == 0

    # check can retrieve a document by id
    document_0 = Document.from_id_and_parent_path(documents[0].id, mock_project.path)
    assert document_0 is not None
    assert document_0.parent_project().id == mock_project.id
    assert document_0.id == documents[0].id

    # check can retrieve extractions for a document
    document_0_extractions = document_0.extractions()
    assert document_0_extractions is not None
    assert len(document_0_extractions) == 3
    for extraction in document_0_extractions:
        assert extraction.parent_document().id == document_0.id

    # check can retrieve all documents
    all_documents = Document.all_children_of_parent_path(mock_project.path)

    # all three ways of listing the documents agree
    assert (
        [d.id for d in all_documents]
        == [d.id for d in mock_project.documents()]
        == [d.id for d in documents]
    )

    # check all documents are retrieved
    for document_retrieved, document_original in zip(all_documents, documents):
        assert document_retrieved.parent_project().id == mock_project.id
        assert document_retrieved.id == document_original.id
340
+
341
+
342
@pytest.fixture
def valid_document(mock_document_factory):
    """A single saved Document produced by mock_document_factory."""
    document = mock_document_factory()
    return document
345
+
346
+
347
@pytest.mark.parametrize(
    ("tags", "expected_tags"),
    [
        (["test", "document"], ["test", "document"]),
        (["test", "document", "new"], ["test", "document", "new"]),
        ([], []),
    ],
)
def test_document_tags(valid_document, tags, expected_tags):
    """Valid tag lists, including the empty list, can be assigned and read back."""
    document = valid_document
    document.tags = tags
    assert document.tags == expected_tags
358
+
359
+
360
def test_invalid_tags(valid_document):
    """Empty-string tags and tags containing spaces are both rejected."""
    empty_tag_error = "Tags cannot be empty strings"
    space_tag_error = r"Tags cannot contain spaces. Try underscores."

    with pytest.raises(ValueError, match=empty_tag_error):
        valid_document.tags = ["test", ""]

    with pytest.raises(ValueError, match=space_tag_error):
        valid_document.tags = ["test", "document new"]
367
+
368
+
369
@pytest.mark.parametrize(
    ("filename", "mime_type"),
    [
        ("file.pdf", "application/pdf"),
        ("file.txt", "text/plain"),
        ("file.md", "text/markdown"),
        ("file.html", "text/html"),
        ("file.png", "image/png"),
        ("file.jpeg", "image/jpeg"),
        ("file.mp4", "video/mp4"),
        ("file.mov", "video/quicktime"),
        ("file.wav", "audio/wav"),
        ("file.mp3", "audio/mpeg"),
        ("file.ogg", "audio/ogg"),
    ],
)
def test_document_valid_mime_type(
    mock_project, mock_attachment_factory, filename, mime_type
):
    """A Document can be built for each listed MIME type and keeps that MIME type."""
    file_info = FileInfo(
        filename=filename,
        size=100,
        mime_type=mime_type,
        attachment=mock_attachment_factory(),
    )
    built_document = Document(
        name="Test Document",
        description="Test description",
        kind=Kind.DOCUMENT,
        original_file=file_info,
        parent=mock_project,
    )
    assert built_document.original_file.mime_type == mime_type
401
+
402
+
403
@pytest.mark.parametrize(
    ("filename", "mime_type"),
    [
        # these are a handful of mime types not currently supported by the extractors
        (
            "file.pptx",
            "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        ),
        (
            "file.docx",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ),
        (
            "file.xlsx",
            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        ),
        ("file.svg", "image/svg+xml"),
        ("file.avi", "video/x-msvideo"),
        ("file.csv", "text/csv"),
    ],
)
def test_document_invalid_mime_type(
    mock_project, mock_attachment_factory, filename, mime_type
):
    """Building a Document with one of the listed MIME types raises ValueError."""
    # the MIME type is escaped because it can contain regex metacharacters
    expected_message = f"MIME type is not supported: {re.escape(mime_type)}"

    with pytest.raises(ValueError, match=expected_message):
        Document(
            name="Test Document",
            description="Test description",
            kind=Kind.DOCUMENT,
            original_file=FileInfo(
                filename=filename,
                size=100,
                mime_type=mime_type,
                attachment=mock_attachment_factory(),
            ),
            parent=mock_project,
        )
451
+
452
+
453
@pytest.mark.parametrize(
    ("mime_type", "expected_kind"),
    [
        ("application/pdf", Kind.DOCUMENT),
        ("text/plain", Kind.DOCUMENT),
        ("text/markdown", Kind.DOCUMENT),
        ("text/html", Kind.DOCUMENT),
        ("image/png", Kind.IMAGE),
        ("image/jpeg", Kind.IMAGE),
        ("video/mp4", Kind.VIDEO),
        ("video/quicktime", Kind.VIDEO),
        ("audio/mpeg", Kind.AUDIO),
        ("audio/wav", Kind.AUDIO),
        ("audio/ogg", Kind.AUDIO),
    ],
)
def test_get_kind_from_mime_type(mime_type, expected_kind):
    """get_kind_from_mime_type maps each listed MIME type to its expected Kind."""
    resolved_kind = get_kind_from_mime_type(mime_type)
    assert resolved_kind == expected_kind
471
+
472
+
473
def test_document_friendly_name(mock_project, mock_attachment_factory):
    """friendly_name is the name until name_override is set, and survives a reload."""
    doc_name = f"Test Document {uuid.uuid4()!s}"
    file_info = FileInfo(
        filename=f"test_{doc_name}.txt",
        size=100,
        mime_type="text/plain",
        attachment=mock_attachment_factory(),
    )
    doc = Document(
        name=doc_name,
        description=f"Test description {uuid.uuid4()!s}",
        kind=Kind.DOCUMENT,
        original_file=file_info,
        parent=mock_project,
    )
    doc.save_to_file()

    # backward compatibility: old documents did not have name_override
    assert doc.name_override is None
    assert doc.friendly_name == doc_name

    # new documents have name_override
    override = "Test Document Override"
    doc.name_override = override
    assert doc.friendly_name == "Test Document Override"

    doc.save_to_file()

    # the override is still in effect after loading the document back by id
    reloaded = Document.from_id_and_parent_path(str(doc.id), mock_project.path)
    assert reloaded is not None
    assert reloaded.friendly_name == "Test Document Override"