kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +8 -2
- kiln_ai/adapters/adapter_registry.py +43 -208
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +9 -3
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +2 -4
- kiln_ai/adapters/eval/test_g_eval.py +4 -5
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +761 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +205 -47
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +657 -85
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +251 -1
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +254 -8
- kiln_ai/adapters/test_remote_config.py +651 -58
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +39 -34
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +28 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +40 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +4 -4
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +320 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +105 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +246 -0
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +562 -0
- kiln_ai/tools/tool_registry.py +85 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +24 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +131 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.19.0.dist-info/RECORD +0 -115
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import TYPE_CHECKING, List, Union
|
|
4
|
+
|
|
5
|
+
import anyio
|
|
6
|
+
from pydantic import (
|
|
7
|
+
BaseModel,
|
|
8
|
+
Field,
|
|
9
|
+
SerializationInfo,
|
|
10
|
+
ValidationInfo,
|
|
11
|
+
field_serializer,
|
|
12
|
+
field_validator,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from kiln_ai.datamodel.basemodel import (
|
|
16
|
+
ID_TYPE,
|
|
17
|
+
FilenameString,
|
|
18
|
+
KilnAttachmentModel,
|
|
19
|
+
KilnParentedModel,
|
|
20
|
+
KilnParentModel,
|
|
21
|
+
)
|
|
22
|
+
from kiln_ai.datamodel.embedding import ChunkEmbeddings
|
|
23
|
+
|
|
24
|
+
# Module-level logger for this module.
# NOTE(review): not referenced anywhere else in this module; kept for future use / API parity.
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from kiln_ai.datamodel.extraction import Extraction
|
|
28
|
+
from kiln_ai.datamodel.project import Project
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def validate_fixed_window_chunker_properties(
|
|
32
|
+
properties: dict[str, str | int | float | bool],
|
|
33
|
+
) -> dict[str, str | int | float | bool]:
|
|
34
|
+
"""Validate the properties for the fixed window chunker and set defaults if needed."""
|
|
35
|
+
chunk_overlap = properties.get("chunk_overlap")
|
|
36
|
+
if chunk_overlap is None:
|
|
37
|
+
raise ValueError("Chunk overlap is required.")
|
|
38
|
+
|
|
39
|
+
chunk_size = properties.get("chunk_size")
|
|
40
|
+
if chunk_size is None:
|
|
41
|
+
raise ValueError("Chunk size is required.")
|
|
42
|
+
|
|
43
|
+
if not isinstance(chunk_overlap, int):
|
|
44
|
+
raise ValueError("Chunk overlap must be an integer.")
|
|
45
|
+
if chunk_overlap < 0:
|
|
46
|
+
raise ValueError("Chunk overlap must be greater than or equal to 0.")
|
|
47
|
+
|
|
48
|
+
if not isinstance(chunk_size, int):
|
|
49
|
+
raise ValueError("Chunk size must be an integer.")
|
|
50
|
+
if chunk_size <= 0:
|
|
51
|
+
raise ValueError("Chunk size must be greater than 0.")
|
|
52
|
+
|
|
53
|
+
if chunk_overlap >= chunk_size:
|
|
54
|
+
raise ValueError("Chunk overlap must be less than chunk size.")
|
|
55
|
+
|
|
56
|
+
return properties
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ChunkerType(str, Enum):
    """Available chunking strategies.

    The member value is the string persisted in ``ChunkerConfig.chunker_type``.
    """

    # Split text into windows of a fixed size (see chunk_size / chunk_overlap).
    FIXED_WINDOW = "fixed_window"
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class ChunkerConfig(KilnParentedModel):
    """Configuration describing how documents are split into chunks.

    ``properties`` holds chunker-type-specific settings. For
    ``ChunkerType.FIXED_WINDOW`` they are checked by
    ``validate_fixed_window_chunker_properties`` during model validation.
    """

    name: FilenameString = Field(
        description="A name to identify the chunker config.",
    )
    description: str | None = Field(
        default=None, description="The description of the chunker config"
    )
    chunker_type: ChunkerType = Field(
        description="This is used to determine the type of chunker to use.",
    )
    properties: dict[str, str | int | float | bool] = Field(
        description="Properties to be used to execute the chunker config. This is chunker_type specific and should serialize to a json dict.",
    )

    # Workaround to return typed parent without importing Project
    def parent_project(self) -> Union["Project", None]:
        """Return the parent Project, or None if there is no parent or it is another type."""
        if self.parent is None or self.parent.__class__.__name__ != "Project":
            return None
        return self.parent  # type: ignore

    @field_validator("properties")
    @classmethod
    def validate_properties(
        cls, properties: dict[str, str | int | float | bool], info: ValidationInfo
    ) -> dict[str, str | int | float | bool]:
        """Apply chunker-type-specific validation to ``properties``.

        ``chunker_type`` is declared before ``properties``, so its validated value
        (when valid) is already present in ``info.data`` here.
        """
        if info.data.get("chunker_type") == ChunkerType.FIXED_WINDOW:
            # Return the validated dict directly rather than re-assigning, so
            # validation is not triggered again.
            return validate_fixed_window_chunker_properties(properties)
        return properties

    def _int_property(self, key: str, label: str) -> int | None:
        """Return ``properties[key]`` as an int, or None if absent.

        Raises:
            ValueError: If the value is present but not an integer.
        """
        value = self.properties.get(key)
        if value is None:
            return None
        # bool is a subclass of int; reject it so True/False are not read as 1/0.
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValueError(f"{label} must be an integer.")
        return value

    def chunk_size(self) -> int | None:
        """Return the configured chunk size, or None if not set.

        Raises:
            ValueError: If ``chunk_size`` is set but not an integer.
        """
        return self._int_property("chunk_size", "Chunk size")

    def chunk_overlap(self) -> int | None:
        """Return the configured chunk overlap, or None if not set.

        Raises:
            ValueError: If ``chunk_overlap`` is set but not an integer.
        """
        return self._int_property("chunk_overlap", "Chunk overlap")
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class Chunk(BaseModel):
    """One chunk of a chunked document; its text is stored as an attachment."""

    content: KilnAttachmentModel = Field(
        description="The content of the chunk, stored as an attachment."
    )

    @field_serializer("content")
    def serialize_content(
        self, content: KilnAttachmentModel, info: SerializationInfo
    ) -> dict:
        """Serialize the attachment with ``filename_prefix`` set to ``"content"``.

        The prefix is added to a copy of the serialization context. The
        caller-supplied context object is passed through by reference, so writing
        into it would leak ``filename_prefix`` back to the caller and to any
        later use of the same context.
        """
        context = {**(info.context or {}), "filename_prefix": "content"}
        return content.model_dump(mode="json", context=context)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class ChunkedDocument(
    KilnParentedModel, KilnParentModel, parent_of={"chunk_embeddings": ChunkEmbeddings}
):
    """The chunks produced from a document by one chunker config.

    The text of each chunk lives in an attachment stored relative to this model's
    file. Child ``ChunkEmbeddings`` hold vectors aligned with ``chunks`` by index.
    """

    chunker_config_id: ID_TYPE = Field(
        description="The ID of the chunker config used to chunk the document.",
    )
    chunks: list[Chunk] = Field(description="The chunks of the document.")

    def parent_extraction(self) -> Union["Extraction", None]:
        """Return the parent Extraction, or None if there is no parent or it is another type."""
        parent = self.parent
        if parent is not None and parent.__class__.__name__ == "Extraction":
            return parent  # type: ignore
        return None

    def chunk_embeddings(self, readonly: bool = False) -> list[ChunkEmbeddings]:
        """Typed wrapper around the generated ``chunk_embeddings`` child accessor."""
        return super().chunk_embeddings(readonly=readonly)  # type: ignore

    async def load_chunks_text(self) -> list[str]:
        """Read and return the UTF-8 text of every chunk, in chunk order.

        Raises:
            ValueError: If this document has no path, or if any chunk's attachment
                cannot be read.
        """
        document_path = self.path
        if not document_path:
            raise ValueError(
                "Failed to resolve the path of chunk content attachment because the chunk does not have a path."
            )

        base_dir = document_path.parent
        texts: list[str] = []
        for chunk in self.chunks:
            attachment_path = chunk.content.resolve_path(base_dir)
            try:
                text = await anyio.Path(attachment_path).read_text(encoding="utf-8")
            except Exception as e:
                raise ValueError(
                    f"Failed to read chunk content for {attachment_path}: {e}"
                ) from e
            texts.append(text)

        return texts
|
|
@@ -99,3 +99,31 @@ class ModelProviderName(str, Enum):
|
|
|
99
99
|
together_ai = "together_ai"
|
|
100
100
|
siliconflow_cn = "siliconflow_cn"
|
|
101
101
|
cerebras = "cerebras"
|
|
102
|
+
docker_model_runner = "docker_model_runner"
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class KilnMimeType(str, Enum):
    """MIME types supported by Kiln, grouped by media category.

    Note: ``JPEG`` shares its value with ``JPG``, so Python's Enum makes ``JPEG``
    an alias of ``JPG``. Member order matters here, because the first definition
    (``JPG``) is the canonical member.
    """

    # --- documents ---
    PDF = "application/pdf"
    CSV = "text/csv"
    TXT = "text/plain"
    HTML = "text/html"
    MD = "text/markdown"

    # --- images ---
    PNG = "image/png"
    JPG = "image/jpeg"
    JPEG = "image/jpeg"  # alias of JPG (same value)

    # --- audio ---
    MP3 = "audio/mpeg"
    WAV = "audio/wav"
    OGG = "audio/ogg"

    # --- video ---
    MP4 = "video/mp4"
    MOV = "video/quicktime"
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, List, Union
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field, model_validator
|
|
4
|
+
|
|
5
|
+
from kiln_ai.datamodel.basemodel import ID_TYPE, FilenameString, KilnParentedModel
|
|
6
|
+
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from kiln_ai.datamodel.chunk import ChunkedDocument
|
|
10
|
+
from kiln_ai.datamodel.project import Project
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class EmbeddingConfig(KilnParentedModel):
    """Configuration selecting the provider/model (and options) used to generate embeddings."""

    name: FilenameString = Field(
        description="A name to identify the embedding config.",
    )
    description: str | None = Field(
        default=None,
        description="A description for your reference, not shared with embedding models.",
    )
    model_provider_name: ModelProviderName = Field(
        description="The provider to use to generate embeddings.",
    )
    model_name: str = Field(
        description="The model to use to generate embeddings.",
    )
    properties: dict[str, str | int | float | bool] = Field(
        description="Properties to be used to execute the embedding config.",
    )

    # Workaround to return typed parent without importing Project
    def parent_project(self) -> Union["Project", None]:
        """Return the parent Project, or None if there is no parent or it is another type."""
        if self.parent is None or self.parent.__class__.__name__ != "Project":
            return None
        return self.parent  # type: ignore

    @model_validator(mode="after")
    def validate_properties(self):
        """Check that the optional ``dimensions`` property is a positive integer.

        Raises:
            ValueError: If ``dimensions`` is present but is not a positive int.
        """
        if "dimensions" in self.properties:
            dimensions = self.properties["dimensions"]
            # bool is a subclass of int; reject it explicitly so that True is not
            # accepted as a 1-dimensional embedding.
            if (
                isinstance(dimensions, bool)
                or not isinstance(dimensions, int)
                or dimensions <= 0
            ):
                raise ValueError("Dimensions must be a positive integer")

        return self
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class Embedding(BaseModel):
    """A single embedding: one vector of floats."""

    vector: list[float] = Field(description="The vector of the embedding.")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ChunkEmbeddings(KilnParentedModel):
    """Embeddings for the chunks of a parent chunked document, from one embedding config.

    ``embeddings[i]`` corresponds to chunk ``i`` of the parent document.
    """

    embedding_config_id: ID_TYPE = Field(
        description="The ID of the embedding config used to generate the embeddings.",
    )
    embeddings: list[Embedding] = Field(
        description="The embeddings of the chunks. The embedding at index i corresponds to the chunk at index i in the parent chunked document."
    )

    def parent_chunked_document(self) -> Union["ChunkedDocument", None]:
        """Return the parent ChunkedDocument, or None if there is no parent or it is another type."""
        parent = self.parent
        if parent is not None and parent.__class__.__name__ == "ChunkedDocument":
            return parent  # type: ignore
        return None
|
kiln_ai/datamodel/eval.py
CHANGED
|
@@ -252,7 +252,7 @@ class EvalConfig(KilnParentedModel, KilnParentModel, parent_of={"runs": EvalRun}
|
|
|
252
252
|
# This will raise a TypeError if the dict contains non-JSON-serializable objects
|
|
253
253
|
json.dumps(self.properties)
|
|
254
254
|
except TypeError as e:
|
|
255
|
-
raise ValueError(f"Properties must be JSON serializable: {
|
|
255
|
+
raise ValueError(f"Properties must be JSON serializable: {e!s}")
|
|
256
256
|
return self
|
|
257
257
|
|
|
258
258
|
|
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Any, Dict
|
|
3
|
+
|
|
4
|
+
from pydantic import Field, PrivateAttr, model_validator
|
|
5
|
+
|
|
6
|
+
from kiln_ai.datamodel.basemodel import (
|
|
7
|
+
FilenameString,
|
|
8
|
+
KilnParentedModel,
|
|
9
|
+
)
|
|
10
|
+
from kiln_ai.utils.config import MCP_SECRETS_KEY, Config
|
|
11
|
+
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ToolServerType(str, Enum):
    """Kinds of external tool server Kiln can talk to."""

    # MCP server reached over the network via server_url (+ optional headers).
    remote_mcp = "remote_mcp"
    # MCP server started locally from a command + args (+ optional env vars).
    local_mcp = "local_mcp"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ExternalToolServer(KilnParentedModel):
|
|
24
|
+
"""
|
|
25
|
+
Configuration for communicating with a external MCP (Model Context Protocol) Server for LLM tool calls. External tool servers can be remote or local.
|
|
26
|
+
|
|
27
|
+
This model stores the necessary configuration to connect to and authenticate with
|
|
28
|
+
external MCP servers that provide tools for LLM interactions.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
name: FilenameString = Field(description="The name of the external tool.")
|
|
32
|
+
type: ToolServerType = Field(
|
|
33
|
+
description="The type of external tool server. Remote tools are hosted on a remote server",
|
|
34
|
+
)
|
|
35
|
+
description: str | None = Field(
|
|
36
|
+
default=None,
|
|
37
|
+
description="A description of the external tool for you and your team. Will not be used in prompts/training/validation.",
|
|
38
|
+
)
|
|
39
|
+
properties: Dict[str, Any] = Field(
|
|
40
|
+
default={},
|
|
41
|
+
description="Configuration properties specific to the tool type.",
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
# Private variable to store unsaved secrets
|
|
45
|
+
_unsaved_secrets: dict[str, str] = PrivateAttr(default_factory=dict)
|
|
46
|
+
|
|
47
|
+
def model_post_init(self, __context: Any) -> None:
|
|
48
|
+
# Process secrets after initialization (pydantic v2 hook)
|
|
49
|
+
self._process_secrets_from_properties()
|
|
50
|
+
|
|
51
|
+
def _process_secrets_from_properties(self) -> None:
|
|
52
|
+
"""
|
|
53
|
+
Extract secrets from properties and move them to _unsaved_secrets.
|
|
54
|
+
This removes secrets from the properties dict so they aren't saved to file.
|
|
55
|
+
Clears existing _unsaved_secrets first to handle property updates correctly.
|
|
56
|
+
"""
|
|
57
|
+
# Clear existing unsaved secrets since we're reprocessing
|
|
58
|
+
self._unsaved_secrets.clear()
|
|
59
|
+
|
|
60
|
+
secret_keys = self.get_secret_keys()
|
|
61
|
+
|
|
62
|
+
if not secret_keys:
|
|
63
|
+
return
|
|
64
|
+
|
|
65
|
+
# Extract secret values from properties based on server type
|
|
66
|
+
match self.type:
|
|
67
|
+
case ToolServerType.remote_mcp:
|
|
68
|
+
headers = self.properties.get("headers", {})
|
|
69
|
+
for key_name in secret_keys:
|
|
70
|
+
if key_name in headers:
|
|
71
|
+
self._unsaved_secrets[key_name] = headers[key_name]
|
|
72
|
+
# Remove from headers immediately so they are not saved to file
|
|
73
|
+
del headers[key_name]
|
|
74
|
+
|
|
75
|
+
case ToolServerType.local_mcp:
|
|
76
|
+
env_vars = self.properties.get("env_vars", {})
|
|
77
|
+
for key_name in secret_keys:
|
|
78
|
+
if key_name in env_vars:
|
|
79
|
+
self._unsaved_secrets[key_name] = env_vars[key_name]
|
|
80
|
+
# Remove from env_vars immediately so they are not saved to file
|
|
81
|
+
del env_vars[key_name]
|
|
82
|
+
|
|
83
|
+
case _:
|
|
84
|
+
raise_exhaustive_enum_error(self.type)
|
|
85
|
+
|
|
86
|
+
def __setattr__(self, name: str, value: Any) -> None:
|
|
87
|
+
"""
|
|
88
|
+
Override __setattr__ to process secrets whenever properties are updated.
|
|
89
|
+
"""
|
|
90
|
+
super().__setattr__(name, value)
|
|
91
|
+
|
|
92
|
+
# Process secrets whenever properties are updated
|
|
93
|
+
if name == "properties":
|
|
94
|
+
self._process_secrets_from_properties()
|
|
95
|
+
|
|
96
|
+
@model_validator(mode="after")
|
|
97
|
+
def validate_required_fields(self) -> "ExternalToolServer":
|
|
98
|
+
"""Validate that each tool type has the required configuration."""
|
|
99
|
+
match self.type:
|
|
100
|
+
case ToolServerType.remote_mcp:
|
|
101
|
+
server_url = self.properties.get("server_url", None)
|
|
102
|
+
if not isinstance(server_url, str):
|
|
103
|
+
raise ValueError(
|
|
104
|
+
"server_url must be a string for external tools of type 'remote_mcp'"
|
|
105
|
+
)
|
|
106
|
+
if not server_url:
|
|
107
|
+
raise ValueError(
|
|
108
|
+
"server_url is required to connect to a remote MCP server"
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
headers = self.properties.get("headers", None)
|
|
112
|
+
if headers is None:
|
|
113
|
+
raise ValueError("headers must be set when type is 'remote_mcp'")
|
|
114
|
+
if not isinstance(headers, dict):
|
|
115
|
+
raise ValueError(
|
|
116
|
+
"headers must be a dictionary for external tools of type 'remote_mcp'"
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
secret_header_keys = self.properties.get("secret_header_keys", None)
|
|
120
|
+
# Secret header keys are optional, but if they are set, they must be a list of strings
|
|
121
|
+
if secret_header_keys is not None:
|
|
122
|
+
if not isinstance(secret_header_keys, list):
|
|
123
|
+
raise ValueError(
|
|
124
|
+
"secret_header_keys must be a list for external tools of type 'remote_mcp'"
|
|
125
|
+
)
|
|
126
|
+
if not all(isinstance(k, str) for k in secret_header_keys):
|
|
127
|
+
raise ValueError("secret_header_keys must contain only strings")
|
|
128
|
+
|
|
129
|
+
case ToolServerType.local_mcp:
|
|
130
|
+
command = self.properties.get("command", None)
|
|
131
|
+
if not isinstance(command, str):
|
|
132
|
+
raise ValueError(
|
|
133
|
+
"command must be a string to start a local MCP server"
|
|
134
|
+
)
|
|
135
|
+
if not command.strip():
|
|
136
|
+
raise ValueError("command is required to start a local MCP server")
|
|
137
|
+
|
|
138
|
+
args = self.properties.get("args", None)
|
|
139
|
+
if not isinstance(args, list):
|
|
140
|
+
raise ValueError(
|
|
141
|
+
"arguments must be a list to start a local MCP server"
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
env_vars = self.properties.get("env_vars", {})
|
|
145
|
+
if not isinstance(env_vars, dict):
|
|
146
|
+
raise ValueError(
|
|
147
|
+
"environment variables must be a dictionary for external tools of type 'local_mcp'"
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
secret_env_var_keys = self.properties.get("secret_env_var_keys", None)
|
|
151
|
+
# Secret env var keys are optional, but if they are set, they must be a list of strings
|
|
152
|
+
if secret_env_var_keys is not None:
|
|
153
|
+
if not isinstance(secret_env_var_keys, list):
|
|
154
|
+
raise ValueError(
|
|
155
|
+
"secret_env_var_keys must be a list for external tools of type 'local_mcp'"
|
|
156
|
+
)
|
|
157
|
+
if not all(isinstance(k, str) for k in secret_env_var_keys):
|
|
158
|
+
raise ValueError(
|
|
159
|
+
"secret_env_var_keys must contain only strings"
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
case _:
|
|
163
|
+
# Type checking will catch missing cases
|
|
164
|
+
raise_exhaustive_enum_error(self.type)
|
|
165
|
+
return self
|
|
166
|
+
|
|
167
|
+
def get_secret_keys(self) -> list[str]:
|
|
168
|
+
"""
|
|
169
|
+
Get the list of secret key names based on server type.
|
|
170
|
+
|
|
171
|
+
Returns:
|
|
172
|
+
List of secret key names (header names for remote, env var names for local)
|
|
173
|
+
"""
|
|
174
|
+
match self.type:
|
|
175
|
+
case ToolServerType.remote_mcp:
|
|
176
|
+
return self.properties.get("secret_header_keys", [])
|
|
177
|
+
case ToolServerType.local_mcp:
|
|
178
|
+
return self.properties.get("secret_env_var_keys", [])
|
|
179
|
+
case _:
|
|
180
|
+
raise_exhaustive_enum_error(self.type)
|
|
181
|
+
|
|
182
|
+
def retrieve_secrets(self) -> tuple[dict[str, str], list[str]]:
|
|
183
|
+
"""
|
|
184
|
+
Retrieve secrets from configuration system or in-memory storage.
|
|
185
|
+
Automatically determines which secret keys to retrieve based on the server type.
|
|
186
|
+
Config secrets take precedence over unsaved secrets.
|
|
187
|
+
|
|
188
|
+
Returns:
|
|
189
|
+
Tuple of (secrets_dict, missing_secrets_list) where:
|
|
190
|
+
- secrets_dict: Dictionary mapping key names to their secret values
|
|
191
|
+
- missing_secrets_list: List of secret key names that are missing values
|
|
192
|
+
"""
|
|
193
|
+
secrets = {}
|
|
194
|
+
missing_secrets = []
|
|
195
|
+
secret_keys = self.get_secret_keys()
|
|
196
|
+
|
|
197
|
+
if secret_keys and len(secret_keys) > 0:
|
|
198
|
+
config = Config.shared()
|
|
199
|
+
mcp_secrets = config.get_value(MCP_SECRETS_KEY)
|
|
200
|
+
|
|
201
|
+
for key_name in secret_keys:
|
|
202
|
+
secret_value = None
|
|
203
|
+
|
|
204
|
+
# First check config secrets (persistent storage), key is mcp_server_id::key_name
|
|
205
|
+
secret_key = self._config_secret_key(key_name)
|
|
206
|
+
secret_value = mcp_secrets.get(secret_key) if mcp_secrets else None
|
|
207
|
+
|
|
208
|
+
# Fall back to unsaved secrets (in-memory storage)
|
|
209
|
+
if (
|
|
210
|
+
not secret_value
|
|
211
|
+
and hasattr(self, "_unsaved_secrets")
|
|
212
|
+
and key_name in self._unsaved_secrets
|
|
213
|
+
):
|
|
214
|
+
secret_value = self._unsaved_secrets[key_name]
|
|
215
|
+
|
|
216
|
+
if secret_value:
|
|
217
|
+
secrets[key_name] = secret_value
|
|
218
|
+
else:
|
|
219
|
+
missing_secrets.append(key_name)
|
|
220
|
+
|
|
221
|
+
return secrets, missing_secrets
|
|
222
|
+
|
|
223
|
+
def _save_secrets(self) -> None:
|
|
224
|
+
"""
|
|
225
|
+
Save unsaved secrets to the configuration system.
|
|
226
|
+
"""
|
|
227
|
+
secret_keys = self.get_secret_keys()
|
|
228
|
+
|
|
229
|
+
# No secrets to save
|
|
230
|
+
if not secret_keys:
|
|
231
|
+
return
|
|
232
|
+
|
|
233
|
+
if self.id is None:
|
|
234
|
+
raise ValueError("Server ID cannot be None when saving secrets")
|
|
235
|
+
|
|
236
|
+
# Check if secrets are already saved
|
|
237
|
+
if not hasattr(self, "_unsaved_secrets") or not self._unsaved_secrets:
|
|
238
|
+
return
|
|
239
|
+
|
|
240
|
+
config = Config.shared()
|
|
241
|
+
mcp_secrets: dict[str, str] = config.get_value(MCP_SECRETS_KEY) or {}
|
|
242
|
+
|
|
243
|
+
# Store secrets with the pattern: mcp_server_id::key_name
|
|
244
|
+
for key_name, secret_value in self._unsaved_secrets.items():
|
|
245
|
+
secret_key = self._config_secret_key(key_name)
|
|
246
|
+
mcp_secrets[secret_key] = secret_value
|
|
247
|
+
|
|
248
|
+
config.update_settings({MCP_SECRETS_KEY: mcp_secrets})
|
|
249
|
+
|
|
250
|
+
# Clear unsaved secrets after saving
|
|
251
|
+
self._unsaved_secrets.clear()
|
|
252
|
+
|
|
253
|
+
def delete_secrets(self) -> None:
|
|
254
|
+
"""
|
|
255
|
+
Delete all secrets for this tool server from the configuration system.
|
|
256
|
+
"""
|
|
257
|
+
secret_keys = self.get_secret_keys()
|
|
258
|
+
|
|
259
|
+
config = Config.shared()
|
|
260
|
+
mcp_secrets = config.get_value(MCP_SECRETS_KEY) or dict[str, str]()
|
|
261
|
+
|
|
262
|
+
# Remove secrets with the pattern: mcp_server_id::key_name
|
|
263
|
+
for key_name in secret_keys:
|
|
264
|
+
secret_key = self._config_secret_key(key_name)
|
|
265
|
+
if secret_key in mcp_secrets:
|
|
266
|
+
del mcp_secrets[secret_key]
|
|
267
|
+
|
|
268
|
+
# Always call update_settings to maintain consistency with the old behavior
|
|
269
|
+
config.update_settings({MCP_SECRETS_KEY: mcp_secrets})
|
|
270
|
+
|
|
271
|
+
def save_to_file(self) -> None:
|
|
272
|
+
"""
|
|
273
|
+
Override save_to_file to automatically save any unsaved secrets before saving to file.
|
|
274
|
+
|
|
275
|
+
This ensures that secrets are always saved when the object is saved,
|
|
276
|
+
preventing the issue where secrets could be lost if save_to_file is called
|
|
277
|
+
without explicitly saving secrets first.
|
|
278
|
+
"""
|
|
279
|
+
# Save any unsaved secrets first
|
|
280
|
+
if hasattr(self, "_unsaved_secrets") and self._unsaved_secrets:
|
|
281
|
+
self._save_secrets()
|
|
282
|
+
|
|
283
|
+
# Call the parent save_to_file method
|
|
284
|
+
super().save_to_file()
|
|
285
|
+
|
|
286
|
+
# Internal helpers
|
|
287
|
+
|
|
288
|
+
def _config_secret_key(self, key_name: str) -> str:
|
|
289
|
+
"""
|
|
290
|
+
Generate the secret key pattern for storing/retrieving secrets.
|
|
291
|
+
|
|
292
|
+
Args:
|
|
293
|
+
key_name: The name of the secret key
|
|
294
|
+
|
|
295
|
+
Returns:
|
|
296
|
+
The formatted secret key: "{server_id}::{key_name}"
|
|
297
|
+
"""
|
|
298
|
+
return f"{self.id}::{key_name}"
|