kiln-ai 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/extractors/litellm_extractor.py +52 -32
- kiln_ai/adapters/extractors/test_litellm_extractor.py +169 -71
- kiln_ai/adapters/ml_embedding_model_list.py +330 -28
- kiln_ai/adapters/ml_model_list.py +503 -23
- kiln_ai/adapters/model_adapters/litellm_adapter.py +34 -7
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +78 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +6 -9
- kiln_ai/adapters/test_ml_embedding_model_list.py +89 -279
- kiln_ai/adapters/test_ml_model_list.py +0 -10
- kiln_ai/datamodel/basemodel.py +31 -3
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +14 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +3 -3
- kiln_ai/datamodel/test_basemodel.py +269 -13
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_model.py +31 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +106 -1
- kiln_ai/datamodel/tool_id.py +36 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +50 -24
- kiln_ai/tools/rag_tools.py +12 -5
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +186 -226
- kiln_ai/tools/test_rag_tools.py +86 -5
- kiln_ai/tools/test_tool_registry.py +199 -5
- kiln_ai/tools/tool_registry.py +49 -17
- kiln_ai/utils/filesystem.py +4 -4
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +21 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +14 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +3 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/RECORD +45 -43
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -1,7 +1,10 @@
|
|
|
1
|
+
import re
|
|
1
2
|
from enum import Enum
|
|
2
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
|
+
from urllib.parse import urlparse
|
|
3
5
|
|
|
4
6
|
from pydantic import Field, PrivateAttr, model_validator
|
|
7
|
+
from typing_extensions import NotRequired, TypedDict
|
|
5
8
|
|
|
6
9
|
from kiln_ai.datamodel.basemodel import (
|
|
7
10
|
FilenameString,
|
|
@@ -9,6 +12,7 @@ from kiln_ai.datamodel.basemodel import (
|
|
|
9
12
|
)
|
|
10
13
|
from kiln_ai.utils.config import MCP_SECRETS_KEY, Config
|
|
11
14
|
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
15
|
+
from kiln_ai.utils.validation import tool_name_validator, validate_return_dict_prop
|
|
12
16
|
|
|
13
17
|
|
|
14
18
|
class ToolServerType(str, Enum):
|
|
@@ -18,6 +22,28 @@ class ToolServerType(str, Enum):
|
|
|
18
22
|
|
|
19
23
|
remote_mcp = "remote_mcp"
|
|
20
24
|
local_mcp = "local_mcp"
|
|
25
|
+
kiln_task = "kiln_task"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class LocalServerProperties(TypedDict, total=True):
|
|
29
|
+
command: str
|
|
30
|
+
args: NotRequired[list[str]]
|
|
31
|
+
env_vars: NotRequired[dict[str, str]]
|
|
32
|
+
secret_env_var_keys: NotRequired[list[str]]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class RemoteServerProperties(TypedDict, total=True):
|
|
36
|
+
server_url: str
|
|
37
|
+
headers: NotRequired[dict[str, str]]
|
|
38
|
+
secret_header_keys: NotRequired[list[str]]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class KilnTaskServerProperties(TypedDict, total=True):
|
|
42
|
+
task_id: str
|
|
43
|
+
run_config_id: str
|
|
44
|
+
name: str
|
|
45
|
+
description: str
|
|
46
|
+
is_archived: bool
|
|
21
47
|
|
|
22
48
|
|
|
23
49
|
class ExternalToolServer(KilnParentedModel):
|
|
@@ -36,8 +62,10 @@ class ExternalToolServer(KilnParentedModel):
|
|
|
36
62
|
default=None,
|
|
37
63
|
description="A description of the external tool for you and your team. Will not be used in prompts/training/validation.",
|
|
38
64
|
)
|
|
39
|
-
|
|
40
|
-
|
|
65
|
+
|
|
66
|
+
properties: (
|
|
67
|
+
LocalServerProperties | RemoteServerProperties | KilnTaskServerProperties
|
|
68
|
+
) = Field(
|
|
41
69
|
description="Configuration properties specific to the tool type.",
|
|
42
70
|
)
|
|
43
71
|
|
|
@@ -80,6 +108,9 @@ class ExternalToolServer(KilnParentedModel):
|
|
|
80
108
|
# Remove from env_vars immediately so they are not saved to file
|
|
81
109
|
del env_vars[key_name]
|
|
82
110
|
|
|
111
|
+
case ToolServerType.kiln_task:
|
|
112
|
+
pass
|
|
113
|
+
|
|
83
114
|
case _:
|
|
84
115
|
raise_exhaustive_enum_error(self.type)
|
|
85
116
|
|
|
@@ -93,76 +124,195 @@ class ExternalToolServer(KilnParentedModel):
|
|
|
93
124
|
if name == "properties":
|
|
94
125
|
self._process_secrets_from_properties()
|
|
95
126
|
|
|
96
|
-
|
|
97
|
-
|
|
127
|
+
# Validation Helpers
|
|
128
|
+
|
|
129
|
+
@classmethod
|
|
130
|
+
def check_server_url(cls, server_url: str) -> None:
|
|
131
|
+
"""Validate Server URL"""
|
|
132
|
+
if not isinstance(server_url, str):
|
|
133
|
+
raise ValueError("Server URL must be a string")
|
|
134
|
+
|
|
135
|
+
# Check for leading whitespace in URL
|
|
136
|
+
if server_url != server_url.lstrip():
|
|
137
|
+
raise ValueError("Server URL must not have leading whitespace")
|
|
138
|
+
|
|
139
|
+
parsed_url = urlparse(server_url)
|
|
140
|
+
if not parsed_url.netloc:
|
|
141
|
+
raise ValueError("Server URL is not a valid URL")
|
|
142
|
+
if parsed_url.scheme not in ["http", "https"]:
|
|
143
|
+
raise ValueError("Server URL must start with http:// or https://")
|
|
144
|
+
|
|
145
|
+
@classmethod
|
|
146
|
+
def check_headers(cls, headers: dict) -> None:
|
|
147
|
+
"""Validate Headers"""
|
|
148
|
+
if not isinstance(headers, dict):
|
|
149
|
+
raise ValueError("headers must be a dictionary")
|
|
150
|
+
|
|
151
|
+
for key, value in headers.items():
|
|
152
|
+
if not key:
|
|
153
|
+
raise ValueError("Header name is required")
|
|
154
|
+
if not value:
|
|
155
|
+
raise ValueError("Header value is required")
|
|
156
|
+
|
|
157
|
+
# Reject invalid header names and CR/LF in names/values
|
|
158
|
+
token_re = re.compile(r"^[!#$%&'*+.^_`|~0-9A-Za-z-]+$")
|
|
159
|
+
if not token_re.match(key):
|
|
160
|
+
raise ValueError(f'Invalid header name: "{key}"')
|
|
161
|
+
if re.search(r"\r|\n", key) or re.search(r"\r|\n", value):
|
|
162
|
+
raise ValueError(
|
|
163
|
+
"Header names/values must not contain invalid characters"
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
@classmethod
|
|
167
|
+
def check_secret_keys(
|
|
168
|
+
cls, secret_keys: list, key_type: str, tool_type: str
|
|
169
|
+
) -> None:
|
|
170
|
+
"""Validate Secret Keys (generic method for both header and env var keys)"""
|
|
171
|
+
if not isinstance(secret_keys, list):
|
|
172
|
+
raise ValueError(
|
|
173
|
+
f"{key_type} must be a list for external tools of type '{tool_type}'"
|
|
174
|
+
)
|
|
175
|
+
if not all(isinstance(k, str) for k in secret_keys):
|
|
176
|
+
raise ValueError(f"{key_type} must contain only strings")
|
|
177
|
+
if not all(key for key in secret_keys):
|
|
178
|
+
raise ValueError("Secret key is required")
|
|
179
|
+
|
|
180
|
+
@classmethod
|
|
181
|
+
def check_env_vars(cls, env_vars: dict) -> None:
|
|
182
|
+
"""Validate Environment Variables"""
|
|
183
|
+
if not isinstance(env_vars, dict):
|
|
184
|
+
raise ValueError("environment variables must be a dictionary")
|
|
185
|
+
|
|
186
|
+
# Validate env_vars keys are in the correct format for Environment Variables
|
|
187
|
+
# According to POSIX specification, environment variable names must:
|
|
188
|
+
# - Start with a letter (a-z, A-Z) or underscore (_)
|
|
189
|
+
# - Contain only ASCII letters, digits, and underscores
|
|
190
|
+
for key, _ in env_vars.items():
|
|
191
|
+
if not key or not (
|
|
192
|
+
key[0].isascii() and (key[0].isalpha() or key[0] == "_")
|
|
193
|
+
):
|
|
194
|
+
raise ValueError(
|
|
195
|
+
f"Invalid environment variable key: {key}. Must start with a letter or underscore."
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
if not all(c.isascii() and (c.isalnum() or c == "_") for c in key):
|
|
199
|
+
raise ValueError(
|
|
200
|
+
f"Invalid environment variable key: {key}. Can only contain letters, digits, and underscores."
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
@classmethod
|
|
204
|
+
def type_from_data(cls, data: dict) -> ToolServerType:
|
|
205
|
+
"""Get the tool server type from the data for the validators"""
|
|
206
|
+
raw_type = data.get("type")
|
|
207
|
+
if raw_type is None:
|
|
208
|
+
raise ValueError("type is required")
|
|
209
|
+
try:
|
|
210
|
+
return ToolServerType(raw_type)
|
|
211
|
+
except ValueError:
|
|
212
|
+
valid_types = ", ".join(type.value for type in ToolServerType)
|
|
213
|
+
raise ValueError(f"type must be one of: {valid_types}")
|
|
214
|
+
|
|
215
|
+
@model_validator(mode="before")
|
|
216
|
+
def validate_required_fields(cls, data: dict) -> dict:
|
|
98
217
|
"""Validate that each tool type has the required configuration."""
|
|
99
|
-
|
|
218
|
+
server_type = ExternalToolServer.type_from_data(data)
|
|
219
|
+
properties = data.get("properties", {})
|
|
220
|
+
|
|
221
|
+
match server_type:
|
|
100
222
|
case ToolServerType.remote_mcp:
|
|
101
|
-
server_url =
|
|
102
|
-
if
|
|
103
|
-
raise ValueError(
|
|
104
|
-
"server_url must be a string for external tools of type 'remote_mcp'"
|
|
105
|
-
)
|
|
106
|
-
if not server_url:
|
|
223
|
+
server_url = properties.get("server_url", None)
|
|
224
|
+
if server_url is None:
|
|
107
225
|
raise ValueError(
|
|
108
|
-
"
|
|
226
|
+
"Server URL is required to connect to a remote MCP server"
|
|
109
227
|
)
|
|
228
|
+
ExternalToolServer.check_server_url(server_url)
|
|
110
229
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
230
|
+
case ToolServerType.local_mcp:
|
|
231
|
+
command = properties.get("command", None)
|
|
232
|
+
if command is None:
|
|
233
|
+
raise ValueError("command is required to start a local MCP server")
|
|
234
|
+
if not isinstance(command, str):
|
|
115
235
|
raise ValueError(
|
|
116
|
-
"
|
|
236
|
+
"command must be a string to start a local MCP server"
|
|
117
237
|
)
|
|
238
|
+
# Reject empty/whitespace-only command strings
|
|
239
|
+
if command.strip() == "":
|
|
240
|
+
raise ValueError("command must be a non-empty string")
|
|
118
241
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
if not isinstance(secret_header_keys, list):
|
|
242
|
+
args = properties.get("args", None)
|
|
243
|
+
if args is not None:
|
|
244
|
+
if not isinstance(args, list):
|
|
123
245
|
raise ValueError(
|
|
124
|
-
"
|
|
246
|
+
"arguments must be a list to start a local MCP server"
|
|
125
247
|
)
|
|
126
|
-
if not all(isinstance(k, str) for k in secret_header_keys):
|
|
127
|
-
raise ValueError("secret_header_keys must contain only strings")
|
|
128
248
|
|
|
129
|
-
case ToolServerType.
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
249
|
+
case ToolServerType.kiln_task:
|
|
250
|
+
tool_name_validator(properties.get("name", ""))
|
|
251
|
+
err_msg_prefix = "Kiln task server properties:"
|
|
252
|
+
validate_return_dict_prop(
|
|
253
|
+
properties, "description", str, err_msg_prefix
|
|
254
|
+
)
|
|
255
|
+
description = properties.get("description", "")
|
|
256
|
+
if len(description) > 128:
|
|
257
|
+
raise ValueError("description must be 128 characters or less")
|
|
258
|
+
validate_return_dict_prop(
|
|
259
|
+
properties, "is_archived", bool, err_msg_prefix
|
|
260
|
+
)
|
|
261
|
+
validate_return_dict_prop(properties, "task_id", str, err_msg_prefix)
|
|
262
|
+
validate_return_dict_prop(
|
|
263
|
+
properties, "run_config_id", str, err_msg_prefix
|
|
264
|
+
)
|
|
137
265
|
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
)
|
|
266
|
+
case _:
|
|
267
|
+
# Type checking will catch missing cases
|
|
268
|
+
raise_exhaustive_enum_error(server_type)
|
|
269
|
+
return data
|
|
143
270
|
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
271
|
+
@model_validator(mode="before")
|
|
272
|
+
def validate_headers_and_env_vars(cls, data: dict) -> dict:
|
|
273
|
+
"""
|
|
274
|
+
Validate secrets; these need to be validated before model initialization because secrets will be processed and stripped
|
|
275
|
+
"""
|
|
276
|
+
type = ExternalToolServer.type_from_data(data)
|
|
277
|
+
|
|
278
|
+
properties = data.get("properties", {})
|
|
279
|
+
if properties is None:
|
|
280
|
+
raise ValueError("properties is required")
|
|
281
|
+
|
|
282
|
+
match type:
|
|
283
|
+
case ToolServerType.remote_mcp:
|
|
284
|
+
# Validate headers
|
|
285
|
+
headers = properties.get("headers", None)
|
|
286
|
+
if headers is not None:
|
|
287
|
+
ExternalToolServer.check_headers(headers)
|
|
288
|
+
|
|
289
|
+
# Secret header keys are optional, validate if they are set
|
|
290
|
+
secret_header_keys = properties.get("secret_header_keys", None)
|
|
291
|
+
if secret_header_keys is not None:
|
|
292
|
+
ExternalToolServer.check_secret_keys(
|
|
293
|
+
secret_header_keys, "secret_header_keys", "remote_mcp"
|
|
148
294
|
)
|
|
149
295
|
|
|
150
|
-
|
|
296
|
+
case ToolServerType.local_mcp:
|
|
297
|
+
# Validate secret environment variable keys
|
|
298
|
+
env_vars = properties.get("env_vars", {})
|
|
299
|
+
if env_vars is not None:
|
|
300
|
+
ExternalToolServer.check_env_vars(env_vars)
|
|
301
|
+
|
|
151
302
|
# Secret env var keys are optional, but if they are set, they must be a list of strings
|
|
303
|
+
secret_env_var_keys = properties.get("secret_env_var_keys", None)
|
|
152
304
|
if secret_env_var_keys is not None:
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
"secret_env_var_keys must contain only strings"
|
|
160
|
-
)
|
|
305
|
+
ExternalToolServer.check_secret_keys(
|
|
306
|
+
secret_env_var_keys, "secret_env_var_keys", "local_mcp"
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
case ToolServerType.kiln_task:
|
|
310
|
+
pass
|
|
161
311
|
|
|
162
312
|
case _:
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
return
|
|
313
|
+
raise_exhaustive_enum_error(type)
|
|
314
|
+
|
|
315
|
+
return data
|
|
166
316
|
|
|
167
317
|
def get_secret_keys(self) -> list[str]:
|
|
168
318
|
"""
|
|
@@ -176,6 +326,8 @@ class ExternalToolServer(KilnParentedModel):
|
|
|
176
326
|
return self.properties.get("secret_header_keys", [])
|
|
177
327
|
case ToolServerType.local_mcp:
|
|
178
328
|
return self.properties.get("secret_env_var_keys", [])
|
|
329
|
+
case ToolServerType.kiln_task:
|
|
330
|
+
return []
|
|
179
331
|
case _:
|
|
180
332
|
raise_exhaustive_enum_error(self.type)
|
|
181
333
|
|
kiln_ai/datamodel/extraction.py
CHANGED
|
@@ -8,6 +8,7 @@ from pydantic import (
|
|
|
8
8
|
Field,
|
|
9
9
|
SerializationInfo,
|
|
10
10
|
ValidationInfo,
|
|
11
|
+
computed_field,
|
|
11
12
|
field_serializer,
|
|
12
13
|
field_validator,
|
|
13
14
|
model_validator,
|
|
@@ -259,10 +260,17 @@ class FileInfo(BaseModel):
|
|
|
259
260
|
class Document(
|
|
260
261
|
KilnParentedModel, KilnParentModel, parent_of={"extractions": Extraction}
|
|
261
262
|
):
|
|
263
|
+
# this field should not be changed after creation
|
|
262
264
|
name: FilenameString = Field(
|
|
263
265
|
description="A name to identify the document.",
|
|
264
266
|
)
|
|
265
267
|
|
|
268
|
+
# this field can be changed after creation
|
|
269
|
+
name_override: str | None = Field(
|
|
270
|
+
description="A friendly name to identify the document. This is used for display purposes and can be different from the name.",
|
|
271
|
+
default=None,
|
|
272
|
+
)
|
|
273
|
+
|
|
266
274
|
description: str = Field(description="A description for the file")
|
|
267
275
|
|
|
268
276
|
original_file: FileInfo = Field(description="The original file")
|
|
@@ -295,6 +303,12 @@ class Document(
|
|
|
295
303
|
def extractions(self, readonly: bool = False) -> list[Extraction]:
|
|
296
304
|
return super().extractions(readonly=readonly) # type: ignore
|
|
297
305
|
|
|
306
|
+
@computed_field
|
|
307
|
+
@property
|
|
308
|
+
def friendly_name(self) -> str:
|
|
309
|
+
# backward compatibility: old documents did not have name_override
|
|
310
|
+
return self.name_override or self.name
|
|
311
|
+
|
|
298
312
|
|
|
299
313
|
def get_kind_from_mime_type(mime_type: str) -> Kind | None:
|
|
300
314
|
for kind, mime_types in SUPPORTED_MIME_TYPES.items():
|
kiln_ai/datamodel/task.py
CHANGED
|
@@ -131,6 +131,11 @@ class Task(
|
|
|
131
131
|
description="Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting.",
|
|
132
132
|
)
|
|
133
133
|
|
|
134
|
+
default_run_config_id: ID_TYPE | None = Field(
|
|
135
|
+
default=None,
|
|
136
|
+
description="ID of the run config to use for this task by default. Must exist in saved run configs for this task.",
|
|
137
|
+
)
|
|
138
|
+
|
|
134
139
|
def output_schema(self) -> Dict | None:
|
|
135
140
|
if self.output_json_schema is None:
|
|
136
141
|
return None
|
kiln_ai/datamodel/task_output.py
CHANGED
|
@@ -171,6 +171,7 @@ class DataSourceType(str, Enum):
|
|
|
171
171
|
human = "human"
|
|
172
172
|
synthetic = "synthetic"
|
|
173
173
|
file_import = "file_import"
|
|
174
|
+
tool_call = "tool_call"
|
|
174
175
|
|
|
175
176
|
|
|
176
177
|
class DataSourceProperty(BaseModel):
|
|
@@ -189,16 +190,17 @@ class DataSourceProperty(BaseModel):
|
|
|
189
190
|
|
|
190
191
|
class DataSource(BaseModel):
|
|
191
192
|
"""
|
|
192
|
-
Represents the origin of data, either human or
|
|
193
|
+
Represents the origin of data, either human, synthetic, file import, or tool call, with associated properties.
|
|
193
194
|
|
|
194
|
-
Properties vary based on the source type - for synthetic sources this includes
|
|
195
|
-
model information, for human sources this includes creator information
|
|
195
|
+
Properties vary based on the source type - for synthetic/tool_call sources this includes
|
|
196
|
+
model information, for human sources this includes creator information, for file imports
|
|
197
|
+
this includes file information.
|
|
196
198
|
"""
|
|
197
199
|
|
|
198
200
|
type: DataSourceType
|
|
199
201
|
properties: Dict[str, str | int | float] = Field(
|
|
200
202
|
default={},
|
|
201
|
-
description="Properties describing the data source. For synthetic things like model. For human
|
|
203
|
+
description="Properties describing the data source. For synthetic things like model. For human: the human's name. For file_import: file information.",
|
|
202
204
|
)
|
|
203
205
|
run_config: Optional[RunConfigProperties] = Field(
|
|
204
206
|
default=None,
|
|
@@ -210,43 +212,71 @@ class DataSource(BaseModel):
|
|
|
210
212
|
name="created_by",
|
|
211
213
|
type=str,
|
|
212
214
|
required_for=[DataSourceType.human],
|
|
213
|
-
not_allowed_for=[
|
|
215
|
+
not_allowed_for=[
|
|
216
|
+
DataSourceType.synthetic,
|
|
217
|
+
DataSourceType.file_import,
|
|
218
|
+
DataSourceType.tool_call,
|
|
219
|
+
],
|
|
214
220
|
),
|
|
215
221
|
DataSourceProperty(
|
|
216
222
|
name="model_name",
|
|
217
223
|
type=str,
|
|
218
224
|
required_for=[DataSourceType.synthetic],
|
|
219
|
-
not_allowed_for=[
|
|
225
|
+
not_allowed_for=[
|
|
226
|
+
DataSourceType.human,
|
|
227
|
+
DataSourceType.file_import,
|
|
228
|
+
DataSourceType.tool_call,
|
|
229
|
+
],
|
|
220
230
|
),
|
|
221
231
|
DataSourceProperty(
|
|
222
232
|
name="model_provider",
|
|
223
233
|
type=str,
|
|
224
234
|
required_for=[DataSourceType.synthetic],
|
|
225
|
-
not_allowed_for=[
|
|
235
|
+
not_allowed_for=[
|
|
236
|
+
DataSourceType.human,
|
|
237
|
+
DataSourceType.file_import,
|
|
238
|
+
DataSourceType.tool_call,
|
|
239
|
+
],
|
|
226
240
|
),
|
|
227
241
|
DataSourceProperty(
|
|
228
242
|
name="adapter_name",
|
|
229
243
|
type=str,
|
|
230
244
|
required_for=[DataSourceType.synthetic],
|
|
231
|
-
not_allowed_for=[
|
|
245
|
+
not_allowed_for=[
|
|
246
|
+
DataSourceType.human,
|
|
247
|
+
DataSourceType.file_import,
|
|
248
|
+
DataSourceType.tool_call,
|
|
249
|
+
],
|
|
232
250
|
),
|
|
233
251
|
DataSourceProperty(
|
|
234
252
|
# Legacy field -- allow loading from old runs, but we shouldn't be setting it.
|
|
235
253
|
name="prompt_builder_name",
|
|
236
254
|
type=str,
|
|
237
|
-
not_allowed_for=[
|
|
255
|
+
not_allowed_for=[
|
|
256
|
+
DataSourceType.human,
|
|
257
|
+
DataSourceType.file_import,
|
|
258
|
+
DataSourceType.tool_call,
|
|
259
|
+
],
|
|
238
260
|
),
|
|
239
261
|
DataSourceProperty(
|
|
240
262
|
# The PromptId of the prompt. Can be a saved prompt, fine-tune, generator name, etc. See PromptId type for more details.
|
|
241
263
|
name="prompt_id",
|
|
242
264
|
type=str,
|
|
243
|
-
not_allowed_for=[
|
|
265
|
+
not_allowed_for=[
|
|
266
|
+
DataSourceType.human,
|
|
267
|
+
DataSourceType.file_import,
|
|
268
|
+
DataSourceType.tool_call,
|
|
269
|
+
],
|
|
244
270
|
),
|
|
245
271
|
DataSourceProperty(
|
|
246
272
|
name="file_name",
|
|
247
273
|
type=str,
|
|
248
274
|
required_for=[DataSourceType.file_import],
|
|
249
|
-
not_allowed_for=[
|
|
275
|
+
not_allowed_for=[
|
|
276
|
+
DataSourceType.human,
|
|
277
|
+
DataSourceType.synthetic,
|
|
278
|
+
DataSourceType.tool_call,
|
|
279
|
+
],
|
|
250
280
|
),
|
|
251
281
|
]
|
|
252
282
|
|
|
@@ -14,7 +14,7 @@ from kiln_ai.datamodel.basemodel import KilnAttachmentModel, KilnBaseModel
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class ModelWithAttachment(KilnBaseModel):
|
|
17
|
-
attachment: KilnAttachmentModel = Field(default=None)
|
|
17
|
+
attachment: KilnAttachmentModel | None = Field(default=None)
|
|
18
18
|
attachment_list: Optional[List[KilnAttachmentModel]] = Field(default=None)
|
|
19
19
|
attachment_dict: Optional[Dict[str, KilnAttachmentModel]] = Field(default=None)
|
|
20
20
|
|
|
@@ -516,7 +516,7 @@ class ModelWithAttachmentNameOverrideList(KilnBaseModel):
|
|
|
516
516
|
@field_serializer("attachment_list")
|
|
517
517
|
def serialize_attachment_list(
|
|
518
518
|
self, attachment_list: List[KilnAttachmentModel], info: SerializationInfo
|
|
519
|
-
) -> dict:
|
|
519
|
+
) -> List[dict]:
|
|
520
520
|
context = info.context or {}
|
|
521
521
|
context["filename_prefix"] = "attachment_override"
|
|
522
522
|
return [
|
|
@@ -555,7 +555,7 @@ def test_attachment_filename_override_list(test_base_kiln_file, mock_file_factor
|
|
|
555
555
|
|
|
556
556
|
|
|
557
557
|
class ModelWithAttachmentNoNameOverride(KilnBaseModel):
|
|
558
|
-
attachment: KilnAttachmentModel = Field(default=None)
|
|
558
|
+
attachment: KilnAttachmentModel | None = Field(default=None)
|
|
559
559
|
|
|
560
560
|
|
|
561
561
|
def test_attachment_filename_no_override(test_base_kiln_file, mock_file_factory):
|