kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import TYPE_CHECKING, List, Union
|
|
4
|
+
|
|
5
|
+
import anyio
|
|
6
|
+
from pydantic import (
|
|
7
|
+
BaseModel,
|
|
8
|
+
Field,
|
|
9
|
+
SerializationInfo,
|
|
10
|
+
ValidationInfo,
|
|
11
|
+
field_serializer,
|
|
12
|
+
field_validator,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from kiln_ai.datamodel.basemodel import (
|
|
16
|
+
ID_TYPE,
|
|
17
|
+
FilenameString,
|
|
18
|
+
KilnAttachmentModel,
|
|
19
|
+
KilnParentedModel,
|
|
20
|
+
KilnParentModel,
|
|
21
|
+
)
|
|
22
|
+
from kiln_ai.datamodel.embedding import ChunkEmbeddings
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from kiln_ai.datamodel.extraction import Extraction
|
|
28
|
+
from kiln_ai.datamodel.project import Project
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def validate_fixed_window_chunker_properties(
|
|
32
|
+
properties: dict[str, str | int | float | bool],
|
|
33
|
+
) -> dict[str, str | int | float | bool]:
|
|
34
|
+
"""Validate the properties for the fixed window chunker and set defaults if needed."""
|
|
35
|
+
chunk_overlap = properties.get("chunk_overlap")
|
|
36
|
+
if chunk_overlap is None:
|
|
37
|
+
raise ValueError("Chunk overlap is required.")
|
|
38
|
+
|
|
39
|
+
chunk_size = properties.get("chunk_size")
|
|
40
|
+
if chunk_size is None:
|
|
41
|
+
raise ValueError("Chunk size is required.")
|
|
42
|
+
|
|
43
|
+
if not isinstance(chunk_overlap, int):
|
|
44
|
+
raise ValueError("Chunk overlap must be an integer.")
|
|
45
|
+
if chunk_overlap < 0:
|
|
46
|
+
raise ValueError("Chunk overlap must be greater than or equal to 0.")
|
|
47
|
+
|
|
48
|
+
if not isinstance(chunk_size, int):
|
|
49
|
+
raise ValueError("Chunk size must be an integer.")
|
|
50
|
+
if chunk_size <= 0:
|
|
51
|
+
raise ValueError("Chunk size must be greater than 0.")
|
|
52
|
+
|
|
53
|
+
if chunk_overlap >= chunk_size:
|
|
54
|
+
raise ValueError("Chunk overlap must be less than chunk size.")
|
|
55
|
+
|
|
56
|
+
return properties
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ChunkerType(str, Enum):
    """Chunking strategies a ``ChunkerConfig`` can use (stored as its ``chunker_type``)."""

    # Fixed window chunking. ChunkerConfig properties for this type are checked by
    # validate_fixed_window_chunker_properties (integer chunk_size / chunk_overlap).
    FIXED_WINDOW = "fixed_window"
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class ChunkerConfig(KilnParentedModel):
    """Configuration describing how documents are split into chunks.

    ``properties`` holds chunker_type-specific settings. For
    ``ChunkerType.FIXED_WINDOW`` they are validated by
    ``validate_fixed_window_chunker_properties``.
    """

    name: FilenameString = Field(
        description="A name to identify the chunker config.",
    )
    description: str | None = Field(
        default=None, description="The description of the chunker config"
    )
    chunker_type: ChunkerType = Field(
        description="This is used to determine the type of chunker to use.",
    )
    properties: dict[str, str | int | float | bool] = Field(
        description="Properties to be used to execute the chunker config. This is chunker_type specific and should serialize to a json dict.",
    )

    # Workaround to return typed parent without importing Project
    def parent_project(self) -> Union["Project", None]:
        """Return the parent model if it is a Project, otherwise None."""
        if self.parent is None or self.parent.__class__.__name__ != "Project":
            return None
        return self.parent  # type: ignore

    @field_validator("properties")
    @classmethod
    def validate_properties(
        cls, properties: dict[str, str | int | float | bool], info: ValidationInfo
    ) -> dict[str, str | int | float | bool]:
        """Apply chunker_type-specific validation to ``properties``.

        ``chunker_type`` is declared before ``properties``, so pydantic has
        already placed it in ``info.data``. If it failed its own validation it
        is absent, and no type-specific checks run.
        """
        if info.data.get("chunker_type") == ChunkerType.FIXED_WINDOW:
            # do not trigger revalidation of properties
            return validate_fixed_window_chunker_properties(properties)
        return properties

    def chunk_size(self) -> int | None:
        """Return the ``chunk_size`` property, or None if it is not set.

        Raises:
            ValueError: If the property is set but is not an integer.
                ``bool`` is rejected even though it subclasses ``int``.
        """
        value = self.properties.get("chunk_size")
        if value is None:
            return None
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValueError("Chunk size must be an integer.")
        return value

    def chunk_overlap(self) -> int | None:
        """Return the ``chunk_overlap`` property, or None if it is not set.

        Raises:
            ValueError: If the property is set but is not an integer.
                ``bool`` is rejected even though it subclasses ``int``.
        """
        value = self.properties.get("chunk_overlap")
        if value is None:
            return None
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValueError("Chunk overlap must be an integer.")
        return value
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class Chunk(BaseModel):
    """One chunk of a chunked document; its text is stored in an attachment."""

    content: KilnAttachmentModel = Field(
        description="The content of the chunk, stored as an attachment."
    )

    @field_serializer("content")
    def serialize_content(
        self, content: KilnAttachmentModel, info: SerializationInfo
    ) -> dict:
        """Serialize the attachment with ``filename_prefix="content"`` in its context.

        A shallow copy of the serialization context is passed down. Mutating
        ``info.context`` in place would leak ``filename_prefix`` into every other
        field/model serialized later in the same ``model_dump`` call.
        """
        context = {**(info.context or {}), "filename_prefix": "content"}
        return content.model_dump(mode="json", context=context)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class ChunkedDocument(
    KilnParentedModel, KilnParentModel, parent_of={"chunk_embeddings": ChunkEmbeddings}
):
    """A document split into chunks using a chunker config.

    ``ChunkEmbeddings`` children are registered through ``parent_of``.
    """

    chunker_config_id: ID_TYPE = Field(
        description="The ID of the chunker config used to chunk the document.",
    )
    chunks: List[Chunk] = Field(description="The chunks of the document.")

    def parent_extraction(self) -> Union["Extraction", None]:
        """Return the parent model when it is an Extraction, otherwise None."""
        is_extraction = (
            self.parent is not None
            and self.parent.__class__.__name__ == "Extraction"
        )
        if not is_extraction:
            return None
        return self.parent  # type: ignore

    def chunk_embeddings(self, readonly: bool = False) -> list[ChunkEmbeddings]:
        """Typed wrapper around the ``parent_of`` child accessor."""
        return super().chunk_embeddings(readonly=readonly)  # type: ignore

    async def load_chunks_text(self) -> list[str]:
        """Read the text of every chunk from its attachment, in chunk order.

        Each attachment path is resolved against ``self.path.parent`` and read
        as UTF-8.

        Raises:
            ValueError: If this document has no ``path``, or if reading any
                chunk fails (the underlying error is chained as the cause).
        """
        if not self.path:
            raise ValueError(
                "Failed to resolve the path of chunk content attachment because the chunk does not have a path."
            )

        base_dir = self.path.parent
        texts: list[str] = []
        for chunk in self.chunks:
            attachment_path = chunk.content.resolve_path(base_dir)
            try:
                text = await anyio.Path(attachment_path).read_text(encoding="utf-8")
            except Exception as err:
                raise ValueError(
                    f"Failed to read chunk content for {attachment_path}: {err}"
                ) from err
            texts.append(text)
        return texts
|
|
@@ -100,3 +100,30 @@ class ModelProviderName(str, Enum):
|
|
|
100
100
|
siliconflow_cn = "siliconflow_cn"
|
|
101
101
|
cerebras = "cerebras"
|
|
102
102
|
docker_model_runner = "docker_model_runner"
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class KilnMimeType(str, Enum):
    """
    Enumeration of supported mime types.

    Each member's value is the MIME type string. ``JPEG`` has the same value as
    ``JPG``, so Enum makes ``JPEG`` an alias of ``JPG``; for example,
    ``KilnMimeType("image/jpeg")`` returns ``KilnMimeType.JPG``.
    """

    # documents
    PDF = "application/pdf"
    CSV = "text/csv"
    TXT = "text/plain"
    HTML = "text/html"
    MD = "text/markdown"

    # images
    PNG = "image/png"
    JPG = "image/jpeg"
    # Alias of JPG (identical value); presumably kept so either spelling can be used.
    JPEG = "image/jpeg"

    # audio
    MP3 = "audio/mpeg"
    WAV = "audio/wav"
    OGG = "audio/ogg"

    # video
    MP4 = "video/mp4"
    MOV = "video/quicktime"
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, List, Union
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field, model_validator
|
|
4
|
+
|
|
5
|
+
from kiln_ai.datamodel.basemodel import ID_TYPE, FilenameString, KilnParentedModel
|
|
6
|
+
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from kiln_ai.datamodel.chunk import ChunkedDocument
|
|
10
|
+
from kiln_ai.datamodel.project import Project
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class EmbeddingConfig(KilnParentedModel):
    """Configuration selecting the provider/model used to generate embeddings.

    ``properties`` carries extra settings. The only key validated here is
    ``dimensions``.
    """

    name: FilenameString = Field(
        description="A name to identify the embedding config.",
    )
    description: str | None = Field(
        default=None,
        description="A description for your reference, not shared with embedding models.",
    )
    model_provider_name: ModelProviderName = Field(
        description="The provider to use to generate embeddings.",
    )
    model_name: str = Field(
        description="The model to use to generate embeddings.",
    )
    properties: dict[str, str | int | float | bool] = Field(
        description="Properties to be used to execute the embedding config.",
    )

    # Workaround to return typed parent without importing Project
    def parent_project(self) -> Union["Project", None]:
        """Return the parent model if it is a Project, otherwise None."""
        if self.parent is None or self.parent.__class__.__name__ != "Project":
            return None
        return self.parent  # type: ignore

    @model_validator(mode="after")
    def validate_properties(self):
        """Validate known keys in ``properties``.

        ``dimensions``, when present, must be a positive ``int``. ``bool`` is
        rejected even though it subclasses ``int``, so ``True`` is not treated
        as 1.

        Raises:
            ValueError: If ``dimensions`` is present but invalid.
        """
        if "dimensions" in self.properties:
            dimensions = self.properties["dimensions"]
            if (
                isinstance(dimensions, bool)
                or not isinstance(dimensions, int)
                or dimensions <= 0
            ):
                raise ValueError("Dimensions must be a positive integer")

        return self
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class Embedding(BaseModel):
    """A single embedding vector produced by an embedding model."""

    vector: list[float] = Field(description="The vector of the embedding.")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ChunkEmbeddings(KilnParentedModel):
    """Embeddings for the chunks of a parent document, aligned by index."""

    embedding_config_id: ID_TYPE = Field(
        description="The ID of the embedding config used to generate the embeddings.",
    )
    embeddings: List[Embedding] = Field(
        description="The embeddings of the chunks. The embedding at index i corresponds to the chunk at index i in the parent chunked document."
    )

    def parent_chunked_document(self) -> Union["ChunkedDocument", None]:
        """Return the parent model when it is a ChunkedDocument, otherwise None."""
        is_chunked_document = (
            self.parent is not None
            and self.parent.__class__.__name__ == "ChunkedDocument"
        )
        if not is_chunked_document:
            return None
        return self.parent  # type: ignore
|
|
@@ -1,7 +1,10 @@
|
|
|
1
|
+
import re
|
|
1
2
|
from enum import Enum
|
|
2
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
|
+
from urllib.parse import urlparse
|
|
3
5
|
|
|
4
6
|
from pydantic import Field, PrivateAttr, model_validator
|
|
7
|
+
from typing_extensions import NotRequired, TypedDict
|
|
5
8
|
|
|
6
9
|
from kiln_ai.datamodel.basemodel import (
|
|
7
10
|
FilenameString,
|
|
@@ -9,6 +12,7 @@ from kiln_ai.datamodel.basemodel import (
|
|
|
9
12
|
)
|
|
10
13
|
from kiln_ai.utils.config import MCP_SECRETS_KEY, Config
|
|
11
14
|
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
15
|
+
from kiln_ai.utils.validation import tool_name_validator, validate_return_dict_prop
|
|
12
16
|
|
|
13
17
|
|
|
14
18
|
class ToolServerType(str, Enum):
|
|
@@ -18,6 +22,28 @@ class ToolServerType(str, Enum):
|
|
|
18
22
|
|
|
19
23
|
remote_mcp = "remote_mcp"
|
|
20
24
|
local_mcp = "local_mcp"
|
|
25
|
+
kiln_task = "kiln_task"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class LocalServerProperties(TypedDict, total=True):
    """Properties of an ExternalToolServer of type ``local_mcp``."""

    # Command that starts the local MCP server; required, must be a non-blank string.
    command: str
    # Optional command arguments; validated to be a list when present.
    args: NotRequired[list[str]]
    # Optional environment variables; names must start with a letter or underscore
    # and contain only ASCII letters, digits and underscores.
    env_vars: NotRequired[dict[str, str]]
    # Optional names of env_vars entries treated as secrets (non-empty strings).
    secret_env_var_keys: NotRequired[list[str]]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class RemoteServerProperties(TypedDict, total=True):
    """Properties of an ExternalToolServer of type ``remote_mcp``."""

    # Required http:// or https:// URL of the remote MCP server.
    server_url: str
    # Optional HTTP headers; names must be header tokens, and names/values
    # must be non-empty with no CR/LF.
    headers: NotRequired[dict[str, str]]
    # Optional names of headers entries treated as secrets (non-empty strings).
    secret_header_keys: NotRequired[list[str]]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class KilnTaskServerProperties(TypedDict, total=True):
    """Properties of an ExternalToolServer of type ``kiln_task``.

    NOTE(review): presumably this exposes a Kiln task as a tool; confirm
    against kiln_task_tool.
    """

    # ID of the Kiln task to run.
    task_id: str
    # ID of the run config to use for the task.
    run_config_id: str
    # Tool name; checked with tool_name_validator.
    name: str
    # Tool description; at most 128 characters.
    description: str
    # Archive flag; must be a bool.
    is_archived: bool
|
|
21
47
|
|
|
22
48
|
|
|
23
49
|
class ExternalToolServer(KilnParentedModel):
|
|
@@ -36,8 +62,10 @@ class ExternalToolServer(KilnParentedModel):
|
|
|
36
62
|
default=None,
|
|
37
63
|
description="A description of the external tool for you and your team. Will not be used in prompts/training/validation.",
|
|
38
64
|
)
|
|
39
|
-
|
|
40
|
-
|
|
65
|
+
|
|
66
|
+
properties: (
|
|
67
|
+
LocalServerProperties | RemoteServerProperties | KilnTaskServerProperties
|
|
68
|
+
) = Field(
|
|
41
69
|
description="Configuration properties specific to the tool type.",
|
|
42
70
|
)
|
|
43
71
|
|
|
@@ -80,6 +108,9 @@ class ExternalToolServer(KilnParentedModel):
|
|
|
80
108
|
# Remove from env_vars immediately so they are not saved to file
|
|
81
109
|
del env_vars[key_name]
|
|
82
110
|
|
|
111
|
+
case ToolServerType.kiln_task:
|
|
112
|
+
pass
|
|
113
|
+
|
|
83
114
|
case _:
|
|
84
115
|
raise_exhaustive_enum_error(self.type)
|
|
85
116
|
|
|
@@ -93,76 +124,195 @@ class ExternalToolServer(KilnParentedModel):
|
|
|
93
124
|
if name == "properties":
|
|
94
125
|
self._process_secrets_from_properties()
|
|
95
126
|
|
|
96
|
-
|
|
97
|
-
|
|
127
|
+
# Validation Helpers
|
|
128
|
+
|
|
129
|
+
@classmethod
def check_server_url(cls, server_url: str) -> None:
    """Ensure ``server_url`` is an absolute http(s) URL.

    Raises:
        ValueError: If the value is not a string, begins with whitespace, has
            no network location, or uses a scheme other than http/https.
    """
    if not isinstance(server_url, str):
        raise ValueError("Server URL must be a string")

    # Leading whitespace is rejected outright rather than silently trimmed.
    if server_url[:1].isspace():
        raise ValueError("Server URL must not have leading whitespace")

    parts = urlparse(server_url)
    if not parts.netloc:
        raise ValueError("Server URL is not a valid URL")
    if parts.scheme not in ("http", "https"):
        raise ValueError("Server URL must start with http:// or https://")
|
|
144
|
+
|
|
145
|
+
@classmethod
def check_headers(cls, headers: dict) -> None:
    """Validate HTTP header names and values for a remote MCP server.

    Names and values must be non-empty strings. Names may contain only
    HTTP token characters. CR/LF is rejected in both names and values to
    prevent header injection.

    Raises:
        ValueError: If ``headers`` is not a dict or any entry is invalid.
    """
    if not isinstance(headers, dict):
        raise ValueError("headers must be a dictionary")

    # Compiled once instead of on every loop iteration.
    token_re = re.compile(r"^[!#$%&'*+.^_`|~0-9A-Za-z-]+$")
    crlf_re = re.compile(r"\r|\n")

    for key, value in headers.items():
        if not key:
            raise ValueError("Header name is required")
        if not value:
            raise ValueError("Header value is required")
        # Non-string entries would otherwise make the regex calls below raise
        # TypeError instead of a ValueError.
        if not isinstance(key, str):
            raise ValueError("Header name must be a string")
        if not isinstance(value, str):
            raise ValueError("Header value must be a string")

        # Reject invalid header names and CR/LF in names/values.
        # "$" also matches just before a trailing "\n", so the explicit CR/LF
        # check is still needed for names.
        if not token_re.match(key):
            raise ValueError(f'Invalid header name: "{key}"')
        if crlf_re.search(key) or crlf_re.search(value):
            raise ValueError(
                "Header names/values must not contain invalid characters"
            )
|
|
165
|
+
|
|
166
|
+
@classmethod
def check_secret_keys(
    cls, secret_keys: list, key_type: str, tool_type: str
) -> None:
    """Ensure a list of secret key names is well-formed.

    This check is shared by ``secret_header_keys`` and ``secret_env_var_keys``.

    Args:
        secret_keys: The value to check.
        key_type: Property name used in error messages.
        tool_type: Server type used in error messages.

    Raises:
        ValueError: If ``secret_keys`` is not a list, holds a non-string, or
            holds an empty string.
    """
    if not isinstance(secret_keys, list):
        raise ValueError(
            f"{key_type} must be a list for external tools of type '{tool_type}'"
        )
    if any(not isinstance(item, str) for item in secret_keys):
        raise ValueError(f"{key_type} must contain only strings")
    if any(not item for item in secret_keys):
        raise ValueError("Secret key is required")
|
|
179
|
+
|
|
180
|
+
@classmethod
def check_env_vars(cls, env_vars: dict) -> None:
    """Validate environment variable names for a local MCP server.

    Each name must be a non-empty string that starts with an ASCII letter or
    underscore and contains only ASCII letters, digits and underscores. This
    follows the usual portable shell/POSIX naming rules, with lowercase
    letters also accepted. Values are not checked here.

    Raises:
        ValueError: If ``env_vars`` is not a dict or any name is invalid.
    """
    if not isinstance(env_vars, dict):
        raise ValueError("environment variables must be a dictionary")

    for key in env_vars:
        # The isinstance check prevents a TypeError from key[0] on non-string keys.
        if (
            not key
            or not isinstance(key, str)
            or not (key[0].isascii() and (key[0].isalpha() or key[0] == "_"))
        ):
            raise ValueError(
                f"Invalid environment variable key: {key}. Must start with a letter or underscore."
            )

        if not all(c.isascii() and (c.isalnum() or c == "_") for c in key):
            raise ValueError(
                f"Invalid environment variable key: {key}. Can only contain letters, digits, and underscores."
            )
|
|
202
|
+
|
|
203
|
+
@classmethod
def type_from_data(cls, data: dict) -> ToolServerType:
    """Read and parse the tool server ``type`` from raw validator input.

    Raises:
        ValueError: If ``type`` is missing or is not a ToolServerType value.
    """
    raw_type = data.get("type")
    if raw_type is None:
        raise ValueError("type is required")
    try:
        return ToolServerType(raw_type)
    except ValueError:
        valid_types = ", ".join(member.value for member in ToolServerType)
        # Drop the enum's own error from the chain; this message is clearer.
        raise ValueError(f"type must be one of: {valid_types}") from None
|
|
214
|
+
|
|
215
|
+
@model_validator(mode="before")
|
|
216
|
+
def validate_required_fields(cls, data: dict) -> dict:
|
|
98
217
|
"""Validate that each tool type has the required configuration."""
|
|
99
|
-
|
|
218
|
+
server_type = ExternalToolServer.type_from_data(data)
|
|
219
|
+
properties = data.get("properties", {})
|
|
220
|
+
|
|
221
|
+
match server_type:
|
|
100
222
|
case ToolServerType.remote_mcp:
|
|
101
|
-
server_url =
|
|
102
|
-
if
|
|
103
|
-
raise ValueError(
|
|
104
|
-
"server_url must be a string for external tools of type 'remote_mcp'"
|
|
105
|
-
)
|
|
106
|
-
if not server_url:
|
|
223
|
+
server_url = properties.get("server_url", None)
|
|
224
|
+
if server_url is None:
|
|
107
225
|
raise ValueError(
|
|
108
|
-
"
|
|
226
|
+
"Server URL is required to connect to a remote MCP server"
|
|
109
227
|
)
|
|
228
|
+
ExternalToolServer.check_server_url(server_url)
|
|
110
229
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
230
|
+
case ToolServerType.local_mcp:
|
|
231
|
+
command = properties.get("command", None)
|
|
232
|
+
if command is None:
|
|
233
|
+
raise ValueError("command is required to start a local MCP server")
|
|
234
|
+
if not isinstance(command, str):
|
|
115
235
|
raise ValueError(
|
|
116
|
-
"
|
|
236
|
+
"command must be a string to start a local MCP server"
|
|
117
237
|
)
|
|
238
|
+
# Reject empty/whitespace-only command strings
|
|
239
|
+
if command.strip() == "":
|
|
240
|
+
raise ValueError("command must be a non-empty string")
|
|
118
241
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
if not isinstance(secret_header_keys, list):
|
|
242
|
+
args = properties.get("args", None)
|
|
243
|
+
if args is not None:
|
|
244
|
+
if not isinstance(args, list):
|
|
123
245
|
raise ValueError(
|
|
124
|
-
"
|
|
246
|
+
"arguments must be a list to start a local MCP server"
|
|
125
247
|
)
|
|
126
|
-
if not all(isinstance(k, str) for k in secret_header_keys):
|
|
127
|
-
raise ValueError("secret_header_keys must contain only strings")
|
|
128
248
|
|
|
129
|
-
case ToolServerType.
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
249
|
+
case ToolServerType.kiln_task:
|
|
250
|
+
tool_name_validator(properties.get("name", ""))
|
|
251
|
+
err_msg_prefix = "Kiln task server properties:"
|
|
252
|
+
validate_return_dict_prop(
|
|
253
|
+
properties, "description", str, err_msg_prefix
|
|
254
|
+
)
|
|
255
|
+
description = properties.get("description", "")
|
|
256
|
+
if len(description) > 128:
|
|
257
|
+
raise ValueError("description must be 128 characters or less")
|
|
258
|
+
validate_return_dict_prop(
|
|
259
|
+
properties, "is_archived", bool, err_msg_prefix
|
|
260
|
+
)
|
|
261
|
+
validate_return_dict_prop(properties, "task_id", str, err_msg_prefix)
|
|
262
|
+
validate_return_dict_prop(
|
|
263
|
+
properties, "run_config_id", str, err_msg_prefix
|
|
264
|
+
)
|
|
137
265
|
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
)
|
|
266
|
+
case _:
|
|
267
|
+
# Type checking will catch missing cases
|
|
268
|
+
raise_exhaustive_enum_error(server_type)
|
|
269
|
+
return data
|
|
143
270
|
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
271
|
+
@model_validator(mode="before")
|
|
272
|
+
def validate_headers_and_env_vars(cls, data: dict) -> dict:
|
|
273
|
+
"""
|
|
274
|
+
Validate secrets, these needs to be validated before model initlization because secrets will be processed and stripped
|
|
275
|
+
"""
|
|
276
|
+
type = ExternalToolServer.type_from_data(data)
|
|
277
|
+
|
|
278
|
+
properties = data.get("properties", {})
|
|
279
|
+
if properties is None:
|
|
280
|
+
raise ValueError("properties is required")
|
|
281
|
+
|
|
282
|
+
match type:
|
|
283
|
+
case ToolServerType.remote_mcp:
|
|
284
|
+
# Validate headers
|
|
285
|
+
headers = properties.get("headers", None)
|
|
286
|
+
if headers is not None:
|
|
287
|
+
ExternalToolServer.check_headers(headers)
|
|
288
|
+
|
|
289
|
+
# Secret header keys are optional, validate if they are set
|
|
290
|
+
secret_header_keys = properties.get("secret_header_keys", None)
|
|
291
|
+
if secret_header_keys is not None:
|
|
292
|
+
ExternalToolServer.check_secret_keys(
|
|
293
|
+
secret_header_keys, "secret_header_keys", "remote_mcp"
|
|
148
294
|
)
|
|
149
295
|
|
|
150
|
-
|
|
296
|
+
case ToolServerType.local_mcp:
|
|
297
|
+
# Validate secret environment variable keys
|
|
298
|
+
env_vars = properties.get("env_vars", {})
|
|
299
|
+
if env_vars is not None:
|
|
300
|
+
ExternalToolServer.check_env_vars(env_vars)
|
|
301
|
+
|
|
151
302
|
# Secret env var keys are optional, but if they are set, they must be a list of strings
|
|
303
|
+
secret_env_var_keys = properties.get("secret_env_var_keys", None)
|
|
152
304
|
if secret_env_var_keys is not None:
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
"secret_env_var_keys must contain only strings"
|
|
160
|
-
)
|
|
305
|
+
ExternalToolServer.check_secret_keys(
|
|
306
|
+
secret_env_var_keys, "secret_env_var_keys", "local_mcp"
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
case ToolServerType.kiln_task:
|
|
310
|
+
pass
|
|
161
311
|
|
|
162
312
|
case _:
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
return
|
|
313
|
+
raise_exhaustive_enum_error(type)
|
|
314
|
+
|
|
315
|
+
return data
|
|
166
316
|
|
|
167
317
|
def get_secret_keys(self) -> list[str]:
|
|
168
318
|
"""
|
|
@@ -176,6 +326,8 @@ class ExternalToolServer(KilnParentedModel):
|
|
|
176
326
|
return self.properties.get("secret_header_keys", [])
|
|
177
327
|
case ToolServerType.local_mcp:
|
|
178
328
|
return self.properties.get("secret_env_var_keys", [])
|
|
329
|
+
case ToolServerType.kiln_task:
|
|
330
|
+
return []
|
|
179
331
|
case _:
|
|
180
332
|
raise_exhaustive_enum_error(self.type)
|
|
181
333
|
|