kiln-ai 0.20.1__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (117)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  38. kiln_ai/adapters/ml_model_list.py +382 -4
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +7 -69
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +1 -1
  41. kiln_ai/adapters/model_adapters/test_structured_output.py +3 -1
  42. kiln_ai/adapters/ollama_tools.py +69 -12
  43. kiln_ai/adapters/provider_tools.py +190 -46
  44. kiln_ai/adapters/rag/deduplication.py +49 -0
  45. kiln_ai/adapters/rag/progress.py +252 -0
  46. kiln_ai/adapters/rag/rag_runners.py +844 -0
  47. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  48. kiln_ai/adapters/rag/test_progress.py +785 -0
  49. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  50. kiln_ai/adapters/remote_config.py +80 -8
  51. kiln_ai/adapters/test_adapter_registry.py +579 -86
  52. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  53. kiln_ai/adapters/test_ml_model_list.py +212 -0
  54. kiln_ai/adapters/test_ollama_tools.py +340 -1
  55. kiln_ai/adapters/test_prompt_builders.py +1 -1
  56. kiln_ai/adapters/test_provider_tools.py +199 -8
  57. kiln_ai/adapters/test_remote_config.py +551 -56
  58. kiln_ai/adapters/vector_store/__init__.py +1 -0
  59. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  60. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  61. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  62. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  63. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  64. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  65. kiln_ai/datamodel/__init__.py +16 -13
  66. kiln_ai/datamodel/basemodel.py +170 -1
  67. kiln_ai/datamodel/chunk.py +158 -0
  68. kiln_ai/datamodel/datamodel_enums.py +27 -0
  69. kiln_ai/datamodel/embedding.py +64 -0
  70. kiln_ai/datamodel/extraction.py +303 -0
  71. kiln_ai/datamodel/project.py +33 -1
  72. kiln_ai/datamodel/rag.py +79 -0
  73. kiln_ai/datamodel/test_attachment.py +649 -0
  74. kiln_ai/datamodel/test_basemodel.py +1 -1
  75. kiln_ai/datamodel/test_chunk_models.py +317 -0
  76. kiln_ai/datamodel/test_dataset_split.py +1 -1
  77. kiln_ai/datamodel/test_embedding_models.py +448 -0
  78. kiln_ai/datamodel/test_eval_model.py +6 -6
  79. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  80. kiln_ai/datamodel/test_extraction_model.py +470 -0
  81. kiln_ai/datamodel/test_rag.py +641 -0
  82. kiln_ai/datamodel/test_tool_id.py +81 -0
  83. kiln_ai/datamodel/test_vector_store.py +320 -0
  84. kiln_ai/datamodel/tool_id.py +22 -0
  85. kiln_ai/datamodel/vector_store.py +141 -0
  86. kiln_ai/tools/mcp_session_manager.py +4 -1
  87. kiln_ai/tools/rag_tools.py +157 -0
  88. kiln_ai/tools/test_mcp_session_manager.py +1 -1
  89. kiln_ai/tools/test_rag_tools.py +848 -0
  90. kiln_ai/tools/test_tool_registry.py +91 -2
  91. kiln_ai/tools/tool_registry.py +21 -0
  92. kiln_ai/utils/__init__.py +3 -0
  93. kiln_ai/utils/async_job_runner.py +62 -17
  94. kiln_ai/utils/config.py +2 -2
  95. kiln_ai/utils/env.py +15 -0
  96. kiln_ai/utils/filesystem.py +14 -0
  97. kiln_ai/utils/filesystem_cache.py +60 -0
  98. kiln_ai/utils/litellm.py +94 -0
  99. kiln_ai/utils/lock.py +100 -0
  100. kiln_ai/utils/mime_type.py +38 -0
  101. kiln_ai/utils/pdf_utils.py +38 -0
  102. kiln_ai/utils/test_async_job_runner.py +151 -35
  103. kiln_ai/utils/test_env.py +142 -0
  104. kiln_ai/utils/test_filesystem_cache.py +316 -0
  105. kiln_ai/utils/test_litellm.py +206 -0
  106. kiln_ai/utils/test_lock.py +185 -0
  107. kiln_ai/utils/test_mime_type.py +66 -0
  108. kiln_ai/utils/test_pdf_utils.py +73 -0
  109. kiln_ai/utils/test_uuid.py +111 -0
  110. kiln_ai/utils/test_validation.py +524 -0
  111. kiln_ai/utils/uuid.py +9 -0
  112. kiln_ai/utils/validation.py +90 -0
  113. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +7 -1
  114. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  115. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  116. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  117. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,199 @@
1
+ from typing import Callable
2
+ from unittest.mock import MagicMock, patch
3
+
4
+ import pytest
5
+
6
+ from kiln_ai.adapters.vector_store.vector_store_registry import (
7
+ vector_store_adapter_for_config,
8
+ )
9
+ from kiln_ai.datamodel.datamodel_enums import ModelProviderName
10
+ from kiln_ai.datamodel.embedding import EmbeddingConfig
11
+ from kiln_ai.datamodel.rag import RagConfig
12
+ from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
13
+
14
+
15
@pytest.fixture(autouse=True)
def patch_settings_dir(tmp_path):
    """Make ``Config.settings_dir`` return the per-test temporary directory.

    Applied automatically to every test in this module; the patch is removed
    again when the test finishes.
    """
    settings_dir_patch = patch(
        "kiln_ai.utils.config.Config.settings_dir", return_value=tmp_path
    )
    settings_dir_patch.start()
    try:
        yield
    finally:
        settings_dir_patch.stop()
19
+
20
+
21
@pytest.fixture
def create_rag_config_factory() -> Callable[
    [VectorStoreConfig, EmbeddingConfig], RagConfig
]:
    """Provide a factory that builds a ``RagConfig`` pointing at the given configs.

    The returned callable takes a vector store config and an embedding config
    and copies their IDs into the new ``RagConfig``; every other field is a
    fixed test value.
    """

    def create_rag_config(
        vector_store_config: VectorStoreConfig, embedding_config: EmbeddingConfig
    ) -> RagConfig:
        rag_fields = {
            "name": "test_rag",
            "tool_name": "test_rag_tool",
            "tool_description": "A test RAG tool for registry testing",
            "extractor_config_id": "test_extractor",
            "chunker_config_id": "test_chunker",
            "embedding_config_id": embedding_config.id,
            "vector_store_config_id": vector_store_config.id,
        }
        return RagConfig(**rag_fields)

    return create_rag_config
39
+
40
+
41
@pytest.fixture
def embedding_config():
    """Provide an OpenAI ``text-embedding-ada-002`` embedding config with empty properties."""
    config = EmbeddingConfig(
        name="test_embedding",
        model_name="text-embedding-ada-002",
        model_provider_name=ModelProviderName.openai,
        properties={},
    )
    return config
50
+
51
+
52
def _build_lancedb_vector_store_config(
    store_type: VectorStoreType, **extra_properties
) -> VectorStoreConfig:
    """Build a LanceDB ``VectorStoreConfig`` with the properties shared by all tests.

    Args:
        store_type: The LanceDB store type the config should declare.
        **extra_properties: Additional store-type specific properties
            (for example ``nprobes``) appended after the shared ones.

    Returns:
        A config named ``"test_config"`` whose ``id`` is ``"test_config_id"``.
    """
    config = VectorStoreConfig(
        name="test_config",
        store_type=store_type,
        properties={
            "similarity_top_k": 10,
            "overfetch_factor": 20,
            "vector_column_name": "vector",
            "text_key": "text",
            "doc_id_key": "doc_id",
            **extra_properties,
        },
    )
    # Set an ID for the config since vector_store_adapter_for_config rejects
    # configs without one.
    config.id = "test_config_id"
    return config


@pytest.fixture
def lancedb_fts_vector_store_config():
    """Create a LanceDB full-text-search vector store config for testing."""
    return _build_lancedb_vector_store_config(VectorStoreType.LANCE_DB_FTS)


@pytest.fixture
def lancedb_knn_vector_store_config():
    """Create a LanceDB vector-search vector store config (with ``nprobes``) for testing."""
    return _build_lancedb_vector_store_config(
        VectorStoreType.LANCE_DB_VECTOR, nprobes=10
    )


@pytest.fixture
def lancedb_hybrid_vector_store_config():
    """Create a LanceDB hybrid vector store config (with ``nprobes``) for testing."""
    return _build_lancedb_vector_store_config(
        VectorStoreType.LANCE_DB_HYBRID, nprobes=10
    )
109
+
110
+
111
class TestVectorStoreAdapterForConfig:
    """Test the vector_store_adapter_for_config function."""

    # Every coroutine test carries the asyncio marker explicitly so the suite
    # behaves the same whether pytest-asyncio runs in "auto" or "strict" mode.

    @pytest.mark.asyncio
    async def test_vector_store_adapter_for_config_unsupported_type(
        self,
        create_rag_config_factory,
        embedding_config,
    ):
        """Test error handling for unsupported vector store types."""
        # Create a mock config with an invalid store type
        unsupported_config = MagicMock()
        unsupported_config.store_type = "INVALID_TYPE"
        unsupported_config.name = "unsupported"
        unsupported_config.id = "test_config_id"

        rag_config = create_rag_config_factory(
            MagicMock(spec=VectorStoreConfig, id="test_config_id"), embedding_config
        )
        with pytest.raises(ValueError, match="Unhandled enum value"):
            await vector_store_adapter_for_config(rag_config, unsupported_config)

    @pytest.mark.asyncio
    async def test_lancedb_fts_vector_store_adapter_for_config(
        self,
        lancedb_fts_vector_store_config,
        create_rag_config_factory,
        embedding_config,
    ):
        """An FTS config yields an adapter that holds that same config."""
        rag_config = create_rag_config_factory(
            lancedb_fts_vector_store_config, embedding_config
        )
        adapter = await vector_store_adapter_for_config(
            rag_config, lancedb_fts_vector_store_config
        )

        assert adapter.vector_store_config == lancedb_fts_vector_store_config
        assert adapter.vector_store_config.name == "test_config"
        assert adapter.vector_store_config.store_type == VectorStoreType.LANCE_DB_FTS

    @pytest.mark.asyncio
    async def test_lancedb_hybrid_vector_store_adapter_for_config(
        self,
        lancedb_hybrid_vector_store_config,
        create_rag_config_factory,
        embedding_config,
    ):
        """A hybrid config yields an adapter that holds that same config."""
        rag_config = create_rag_config_factory(
            lancedb_hybrid_vector_store_config, embedding_config
        )
        adapter = await vector_store_adapter_for_config(
            rag_config, lancedb_hybrid_vector_store_config
        )

        assert adapter.vector_store_config == lancedb_hybrid_vector_store_config
        assert adapter.vector_store_config.name == "test_config"
        assert adapter.vector_store_config.store_type == VectorStoreType.LANCE_DB_HYBRID

    @pytest.mark.asyncio
    async def test_lancedb_vector_vector_store_adapter_for_config(
        self,
        lancedb_knn_vector_store_config,
        create_rag_config_factory,
        embedding_config,
    ):
        """A vector-search config yields an adapter that holds that same config."""
        rag_config = create_rag_config_factory(
            lancedb_knn_vector_store_config, embedding_config
        )
        adapter = await vector_store_adapter_for_config(
            rag_config, lancedb_knn_vector_store_config
        )
        assert adapter.vector_store_config == lancedb_knn_vector_store_config
        assert adapter.vector_store_config.name == "test_config"
        assert adapter.vector_store_config.store_type == VectorStoreType.LANCE_DB_VECTOR

    @pytest.mark.asyncio
    async def test_vector_store_adapter_for_config_missing_id(
        self,
        create_rag_config_factory,
        lancedb_fts_vector_store_config,
        embedding_config,
    ):
        """A vector store config whose ID is None is rejected with a ValueError."""
        rag_config = create_rag_config_factory(
            lancedb_fts_vector_store_config, embedding_config
        )

        # Clear the ID only after the RagConfig has been built from it.
        lancedb_fts_vector_store_config.id = None

        with pytest.raises(ValueError, match="Vector store config ID is required"):
            await vector_store_adapter_for_config(
                rag_config, lancedb_fts_vector_store_config
            )
@@ -0,0 +1,33 @@
1
+ import logging
2
+
3
+ from kiln_ai.adapters.vector_store.base_vector_store_adapter import (
4
+ BaseVectorStoreAdapter,
5
+ )
6
+ from kiln_ai.adapters.vector_store.lancedb_adapter import LanceDBAdapter
7
+ from kiln_ai.datamodel.rag import RagConfig
8
+ from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
9
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ async def vector_store_adapter_for_config(
15
+ rag_config: RagConfig,
16
+ vector_store_config: VectorStoreConfig,
17
+ ) -> BaseVectorStoreAdapter:
18
+ vector_store_config_id = vector_store_config.id
19
+ if vector_store_config_id is None:
20
+ raise ValueError("Vector store config ID is required")
21
+
22
+ match vector_store_config.store_type:
23
+ case (
24
+ VectorStoreType.LANCE_DB_FTS
25
+ | VectorStoreType.LANCE_DB_HYBRID
26
+ | VectorStoreType.LANCE_DB_VECTOR
27
+ ):
28
+ return LanceDBAdapter(
29
+ rag_config,
30
+ vector_store_config,
31
+ )
32
+ case _:
33
+ raise_exhaustive_enum_error(vector_store_config.store_type)
@@ -11,21 +11,24 @@ User docs: https://docs.kiln.tech/developers/kiln-datamodel
11
11
 
12
12
  from __future__ import annotations
13
13
 
14
- from kiln_ai.datamodel import dataset_split, eval, strict_mode
14
+ from kiln_ai.datamodel import (
15
+ chunk,
16
+ dataset_split,
17
+ embedding,
18
+ eval,
19
+ extraction,
20
+ rag,
21
+ strict_mode,
22
+ )
15
23
  from kiln_ai.datamodel.datamodel_enums import (
16
24
  FineTuneStatusType,
17
25
  Priority,
18
26
  StructuredOutputMode,
19
27
  TaskOutputRatingType,
20
28
  )
21
- from kiln_ai.datamodel.dataset_split import (
22
- DatasetSplit,
23
- DatasetSplitDefinition,
24
- )
29
+ from kiln_ai.datamodel.dataset_split import DatasetSplit, DatasetSplitDefinition
25
30
  from kiln_ai.datamodel.external_tool_server import ExternalToolServer
26
- from kiln_ai.datamodel.finetune import (
27
- Finetune,
28
- )
31
+ from kiln_ai.datamodel.finetune import Finetune
29
32
  from kiln_ai.datamodel.project import Project
30
33
  from kiln_ai.datamodel.prompt import BasePrompt, Prompt
31
34
  from kiln_ai.datamodel.prompt_id import (
@@ -42,10 +45,7 @@ from kiln_ai.datamodel.task_output import (
42
45
  TaskOutput,
43
46
  TaskOutputRating,
44
47
  )
45
- from kiln_ai.datamodel.task_run import (
46
- TaskRun,
47
- Usage,
48
- )
48
+ from kiln_ai.datamodel.task_run import TaskRun, Usage
49
49
 
50
50
  __all__ = [
51
51
  "BasePrompt",
@@ -69,11 +69,14 @@ __all__ = [
69
69
  "TaskOutputRating",
70
70
  "TaskOutputRatingType",
71
71
  "TaskRequirement",
72
- "TaskRequirement",
73
72
  "TaskRun",
74
73
  "Usage",
74
+ "chunk",
75
75
  "dataset_split",
76
+ "embedding",
76
77
  "eval",
78
+ "extraction",
77
79
  "prompt_generator_values",
80
+ "rag",
78
81
  "strict_mode",
79
82
  ]
@@ -2,6 +2,7 @@ import json
2
2
  import os
3
3
  import re
4
4
  import shutil
5
+ import tempfile
5
6
  import unicodedata
6
7
  import uuid
7
8
  from abc import ABCMeta
@@ -15,9 +16,11 @@ from pydantic import (
15
16
  BeforeValidator,
16
17
  ConfigDict,
17
18
  Field,
19
+ SerializationInfo,
18
20
  ValidationError,
19
21
  ValidationInfo,
20
22
  computed_field,
23
+ model_serializer,
21
24
  model_validator,
22
25
  )
23
26
  from pydantic_core import ErrorDetails
@@ -26,6 +29,7 @@ from typing_extensions import Annotated, Self
26
29
  from kiln_ai.datamodel.model_cache import ModelCache
27
30
  from kiln_ai.utils.config import Config
28
31
  from kiln_ai.utils.formatting import snake_case
32
+ from kiln_ai.utils.mime_type import guess_extension
29
33
 
30
34
  # ID is a 12 digit random integer string.
31
35
  # Should be unique per item, at least inside the context of a parent/child relationship.
@@ -85,6 +89,161 @@ def string_to_valid_name(name: str) -> str:
85
89
  return valid_name.strip("_").strip()
86
90
 
87
91
 
92
class KilnAttachmentModel(BaseModel):
    """A file attached to a model.

    Exactly one of ``path`` (relative, already stored next to the parent model)
    or ``input_path`` (absolute, not yet copied) is set. When the attachment is
    serialized with ``save_attachments`` in the context, the input file is
    copied into the destination folder and ``path`` replaces ``input_path``.
    """

    path: Path | None = Field(
        default=None,
        description="The path to the attachment relative to the parent model's path",
    )

    input_path: Path | None = Field(
        default=None,
        description="The input absolute path to the attachment. The file will be copied to its permanent location when the model is saved.",
    )

    @model_validator(mode="after")
    def check_file_exists(self, info: ValidationInfo) -> Self:
        """Validate that exactly one path is set and that it points at a file.

        Raises:
            ValueError: If neither or both paths are set, if ``input_path`` is
                not an absolute path to an existing file, or if ``path`` is
                absolute or does not point at an existing file.
        """
        context = info.context or {}

        if self.path is None and self.input_path is None:
            raise ValueError("Path or input path is not set")
        if self.path is not None and self.input_path is not None:
            raise ValueError("Path and input path cannot both be set")

        # when loading from file, we only know the relative path so we cannot check if it exists
        # without knowing the parent path
        if context.get("loading_from_file", False):
            if isinstance(self.path, str):
                self.path = Path(self.path)
            self.input_path = None
            return self

        # when creating a new attachment, the path is not set yet (it is set when the model is saved)
        # so we only expect the absolute path to be set
        if self.input_path is not None:
            if isinstance(self.input_path, str):
                self.input_path = Path(self.input_path)
            if not self.input_path.is_absolute():
                raise ValueError(f"Input path is not absolute: {self.input_path}")
            if not os.path.exists(self.input_path):
                raise ValueError(f"Input path does not exist: {self.input_path}")
            if not os.path.isfile(self.input_path):
                raise ValueError(f"Input path is not a file: {self.input_path}")

            # this normalizes the path and resolves symlinks
            self.input_path = self.input_path.resolve()

            return self

        if self.path is not None:
            if isinstance(self.path, str):
                self.path = Path(self.path)
            if self.path.is_absolute():
                raise ValueError(
                    f"Path is absolute but should be relative: {self.path}"
                )
            if not os.path.exists(self.path):
                raise ValueError(f"Path does not exist: {self.path}")
            if not os.path.isfile(self.path):
                raise ValueError(f"Path is not a file: {self.path}")

        return self

    @model_serializer
    def serialize(self, info: SerializationInfo) -> dict[str, Path] | None:
        """Serialize to ``{"path": ...}``, copying the file first when saving.

        Context keys read: ``save_attachments`` (bool), ``dest_path`` (Path to
        an existing directory, required when saving) and ``filename_prefix``
        (optional prefix for the copied file's name).

        Raises:
            ValueError: If no path is available, or if ``dest_path`` is missing,
                not a ``Path``, or not a directory when saving.
        """
        # when the attachment is optional on the model, we get None here
        if self is None:
            return None

        context = info.context or {}

        # serialization may also be called by other parts of the system, the callers should
        # explicitly set save_attachments to True if they want to save attachments
        save_attachments: bool = context.get("save_attachments", False)
        if not save_attachments:
            path_val = self.path if self.path is not None else self.input_path
            if path_val is None:
                raise ValueError("Attachment has no path")
            return {"path": path_val}

        dest_path: Path | None = context.get("dest_path", None)
        if not dest_path or not isinstance(dest_path, Path):
            raise ValueError(
                f"dest_path must be a valid Path object when saving attachments, got: {dest_path}"
            )
        if not dest_path.is_dir():
            raise ValueError("dest_path must be a directory when saving attachments")

        # the attachment is already in the parent folder, so we don't need to copy it
        # if the path is already relative, we consider it has been copied already
        if self.path is not None:
            return {"path": self.path}

        # copy file and update the path to be relative to the dest_path
        new_path = self.copy_file_to(dest_path, context.get("filename_prefix", None))

        self.path = new_path.relative_to(dest_path)
        self.input_path = None

        return {"path": self.path}

    def copy_file_to(
        self, dest_folder: Path, filename_prefix: str | None = None
    ) -> Path:
        """Copy the input file into ``dest_folder`` under a generated name.

        The name is 12 digits taken from a random UUID plus the input file's
        suffix, optionally preceded by ``filename_prefix`` and an underscore.

        Returns:
            The path of the copied file inside ``dest_folder``.

        Raises:
            ValueError: If the attachment has no ``input_path``.
        """
        if self.input_path is None:
            raise ValueError("Attachment has no input path to copy")

        filename = f"{str(uuid.uuid4().int)[:12]}{self.input_path.suffix}"
        if filename_prefix:
            filename = f"{filename_prefix}_{filename}"
        target_path = dest_folder / filename
        shutil.copy(self.input_path, target_path)
        return target_path

    @classmethod
    def from_data(cls, data: str | bytes, mime_type: str) -> Self:
        """Create an attachment from str or byte data, in a temp file. The attachment is persisted to
        its permanent location when the model is saved.

        String data is written UTF-8 encoded. The temp file's suffix comes from
        ``guess_extension(mime_type)``, falling back to ``".unknown"``.
        """
        extension = guess_extension(mime_type) or ".unknown"
        payload = data.encode("utf-8") if isinstance(data, str) else data
        # The context manager closes the handle even if the write fails;
        # delete=False keeps the file on disk for the later copy.
        with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as temp_file:
            temp_file.write(payload)
        return cls(input_path=Path(temp_file.name))

    @classmethod
    def from_file(cls, path: Path | str) -> Self:
        """Create an attachment from a file path. The attachment is persisted to
        its permanent location when the model is saved.
        """
        if isinstance(path, str):
            path = Path(path)
        return cls(input_path=path)

    def resolve_path(self, parent_path: Path | None = None) -> Path:
        """
        Resolve the path of the attachment relative to the parent path. The attachment does not know
        its parent, so we need to call this to get the full path.
        Args:
            parent_path (Path): The absolute path to the parent folder. Must be provided if the model is saved to disk.
        Returns:
            Path: The resolved path of the attachment
        Raises:
            ValueError: If no path is set, or if the relative path must be resolved
                and parent_path is missing or not absolute.
        """
        # an attachment that has not been saved yet still points at its input file
        if self.input_path is not None:
            return self.input_path
        if self.path is None:
            raise ValueError("Attachment path is not set")
        if parent_path is None:
            raise ValueError("Parent path is not set")
        if not parent_path.is_absolute():
            raise ValueError(
                f"Failed to resolve attachment path for {self.path} because parent path is not absolute: {parent_path}"
            )
        return (parent_path / self.path).resolve()
245
+
246
+
88
247
  # Usage:
89
248
  # class MyModel(KilnBaseModel):
90
249
  # name: FilenameString = Field(description="The name of the model.")
@@ -223,7 +382,17 @@ class KilnBaseModel(BaseModel):
223
382
  f"id: {getattr(self, 'id', None)}, path: {path}"
224
383
  )
225
384
  path.parent.mkdir(parents=True, exist_ok=True)
226
- json_data = self.model_dump_json(indent=2, exclude={"path"})
385
+
386
+ json_data = self.model_dump_json(
387
+ indent=2,
388
+ exclude={"path"},
389
+ # dest_path is used by the attachment serializer to save attachments to the correct location
390
+ # and update the paths to be relative to path.parent
391
+ context={
392
+ "save_attachments": True,
393
+ "dest_path": path.parent,
394
+ },
395
+ )
227
396
  with open(path, "w", encoding="utf-8") as file:
228
397
  file.write(json_data)
229
398
  # save the path so even if something like name changes, the file doesn't move
@@ -0,0 +1,158 @@
1
+ import logging
2
+ from enum import Enum
3
+ from typing import TYPE_CHECKING, List, Union
4
+
5
+ import anyio
6
+ from pydantic import (
7
+ BaseModel,
8
+ Field,
9
+ SerializationInfo,
10
+ ValidationInfo,
11
+ field_serializer,
12
+ field_validator,
13
+ )
14
+
15
+ from kiln_ai.datamodel.basemodel import (
16
+ ID_TYPE,
17
+ FilenameString,
18
+ KilnAttachmentModel,
19
+ KilnParentedModel,
20
+ KilnParentModel,
21
+ )
22
+ from kiln_ai.datamodel.embedding import ChunkEmbeddings
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ if TYPE_CHECKING:
27
+ from kiln_ai.datamodel.extraction import Extraction
28
+ from kiln_ai.datamodel.project import Project
29
+
30
+
31
+ def validate_fixed_window_chunker_properties(
32
+ properties: dict[str, str | int | float | bool],
33
+ ) -> dict[str, str | int | float | bool]:
34
+ """Validate the properties for the fixed window chunker and set defaults if needed."""
35
+ chunk_overlap = properties.get("chunk_overlap")
36
+ if chunk_overlap is None:
37
+ raise ValueError("Chunk overlap is required.")
38
+
39
+ chunk_size = properties.get("chunk_size")
40
+ if chunk_size is None:
41
+ raise ValueError("Chunk size is required.")
42
+
43
+ if not isinstance(chunk_overlap, int):
44
+ raise ValueError("Chunk overlap must be an integer.")
45
+ if chunk_overlap < 0:
46
+ raise ValueError("Chunk overlap must be greater than or equal to 0.")
47
+
48
+ if not isinstance(chunk_size, int):
49
+ raise ValueError("Chunk size must be an integer.")
50
+ if chunk_size <= 0:
51
+ raise ValueError("Chunk size must be greater than 0.")
52
+
53
+ if chunk_overlap >= chunk_size:
54
+ raise ValueError("Chunk overlap must be less than chunk size.")
55
+
56
+ return properties
57
+
58
+
59
class ChunkerType(str, Enum):
    """String identifiers for the kinds of chunker a chunker config can select."""

    FIXED_WINDOW = "fixed_window"
61
+
62
+
63
class ChunkerConfig(KilnParentedModel):
    """A named chunker configuration: a chunker type plus its type-specific properties."""

    name: FilenameString = Field(
        description="A name to identify the chunker config.",
    )
    description: str | None = Field(
        default=None, description="The description of the chunker config"
    )
    chunker_type: ChunkerType = Field(
        description="This is used to determine the type of chunker to use.",
    )
    properties: dict[str, str | int | float | bool] = Field(
        description="Properties to be used to execute the chunker config. This is chunker_type specific and should serialize to a json dict.",
    )

    # Workaround to return typed parent without importing Project
    def parent_project(self) -> Union["Project", None]:
        """Return the parent if its class is named ``Project``, otherwise None."""
        parent = self.parent
        if parent is not None and parent.__class__.__name__ == "Project":
            return parent  # type: ignore
        return None

    @field_validator("properties")
    @classmethod
    def validate_properties(
        cls, properties: dict[str, str | int | float | bool], info: ValidationInfo
    ) -> dict[str, str | int | float | bool]:
        """Run the fixed-window property checks when the chunker type calls for them."""
        if info.data.get("chunker_type") != ChunkerType.FIXED_WINDOW:
            return properties
        # do not trigger revalidation of properties
        return validate_fixed_window_chunker_properties(properties)

    def chunk_size(self) -> int | None:
        """Return the ``chunk_size`` property, or None when it is absent.

        Raises:
            ValueError: If the stored value is not an integer.
        """
        size = self.properties.get("chunk_size")
        if size is None:
            return None
        if not isinstance(size, int):
            raise ValueError("Chunk size must be an integer.")
        return size

    def chunk_overlap(self) -> int | None:
        """Return the ``chunk_overlap`` property, or None when it is absent.

        Raises:
            ValueError: If the stored value is not an integer.
        """
        overlap = self.properties.get("chunk_overlap")
        if overlap is None:
            return None
        if not isinstance(overlap, int):
            raise ValueError("Chunk overlap must be an integer.")
        return overlap
106
+
107
+
108
class Chunk(BaseModel):
    """A single chunk of a document; its text lives in an attachment file."""

    content: KilnAttachmentModel = Field(
        description="The content of the chunk, stored as an attachment."
    )

    @field_serializer("content")
    def serialize_content(
        self, content: KilnAttachmentModel, info: SerializationInfo
    ) -> dict:
        """Serialize the content attachment with ``filename_prefix`` set to ``"content"``.

        The caller's context is forwarded to the attachment's own serializer
        with the prefix added.
        """
        # Work on a copy: the context dict is shared by the whole dump, so
        # writing into it would leak the prefix to serializers that run later.
        context = {**(info.context or {}), "filename_prefix": "content"}
        return content.model_dump(mode="json", context=context)
120
+
121
+
122
class ChunkedDocument(
    KilnParentedModel, KilnParentModel, parent_of={"chunk_embeddings": ChunkEmbeddings}
):
    """The chunks produced for a document by one chunker config; parent of chunk embeddings."""

    chunker_config_id: ID_TYPE = Field(
        description="The ID of the chunker config used to chunk the document.",
    )
    chunks: List[Chunk] = Field(description="The chunks of the document.")

    def parent_extraction(self) -> Union["Extraction", None]:
        """Return the parent if its class is named ``Extraction``, otherwise None."""
        parent = self.parent
        if parent is not None and parent.__class__.__name__ == "Extraction":
            return parent  # type: ignore
        return None

    def chunk_embeddings(self, readonly: bool = False) -> list[ChunkEmbeddings]:
        """Typed pass-through to the parent class's ``chunk_embeddings`` accessor."""
        return super().chunk_embeddings(readonly=readonly)  # type: ignore

    async def load_chunks_text(self) -> list[str]:
        """Read every chunk's content attachment and return the texts in chunk order.

        Raises:
            ValueError: If this model has no path, or if a chunk's content
                file cannot be read as UTF-8 text.
        """
        if not self.path:
            raise ValueError(
                "Failed to resolve the path of chunk content attachment because the chunk does not have a path."
            )

        # attachment paths are resolved against the folder holding this model's file
        base_dir = self.path.parent
        texts: list[str] = []
        for chunk in self.chunks:
            attachment_path = chunk.content.resolve_path(base_dir)

            try:
                text = await anyio.Path(attachment_path).read_text(encoding="utf-8")
            except Exception as e:
                raise ValueError(
                    f"Failed to read chunk content for {attachment_path}: {e}"
                ) from e
            texts.append(text)

        return texts
@@ -100,3 +100,30 @@ class ModelProviderName(str, Enum):
100
100
  siliconflow_cn = "siliconflow_cn"
101
101
  cerebras = "cerebras"
102
102
  docker_model_runner = "docker_model_runner"
103
+
104
+
105
class KilnMimeType(str, Enum):
    """
    The mime types supported, grouped by media kind.

    JPG and JPEG share the value "image/jpeg", so JPEG is an alias of JPG.
    """

    # --- documents ---
    PDF = "application/pdf"
    CSV = "text/csv"
    TXT = "text/plain"
    HTML = "text/html"
    MD = "text/markdown"

    # --- images ---
    PNG = "image/png"
    JPG = "image/jpeg"
    JPEG = "image/jpeg"

    # --- audio ---
    MP3 = "audio/mpeg"
    WAV = "audio/wav"
    OGG = "audio/ogg"

    # --- video ---
    MP4 = "video/mp4"
    MOV = "video/quicktime"