kiln-ai 0.19.0__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (70)
  1. kiln_ai/adapters/__init__.py +2 -2
  2. kiln_ai/adapters/adapter_registry.py +19 -1
  3. kiln_ai/adapters/chat/chat_formatter.py +8 -12
  4. kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
  5. kiln_ai/adapters/docker_model_runner_tools.py +119 -0
  6. kiln_ai/adapters/eval/base_eval.py +2 -2
  7. kiln_ai/adapters/eval/eval_runner.py +3 -1
  8. kiln_ai/adapters/eval/g_eval.py +2 -2
  9. kiln_ai/adapters/eval/test_base_eval.py +1 -1
  10. kiln_ai/adapters/eval/test_g_eval.py +3 -4
  11. kiln_ai/adapters/fine_tune/__init__.py +1 -1
  12. kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
  14. kiln_ai/adapters/ml_model_list.py +380 -34
  15. kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
  16. kiln_ai/adapters/model_adapters/litellm_adapter.py +383 -79
  17. kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
  18. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +406 -1
  19. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
  20. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
  21. kiln_ai/adapters/model_adapters/test_structured_output.py +110 -4
  22. kiln_ai/adapters/parsers/__init__.py +1 -1
  23. kiln_ai/adapters/provider_tools.py +15 -1
  24. kiln_ai/adapters/repair/test_repair_task.py +12 -9
  25. kiln_ai/adapters/run_output.py +3 -0
  26. kiln_ai/adapters/test_adapter_registry.py +80 -1
  27. kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
  28. kiln_ai/adapters/test_ml_model_list.py +39 -1
  29. kiln_ai/adapters/test_prompt_adaptors.py +13 -6
  30. kiln_ai/adapters/test_provider_tools.py +55 -0
  31. kiln_ai/adapters/test_remote_config.py +98 -0
  32. kiln_ai/datamodel/__init__.py +23 -21
  33. kiln_ai/datamodel/datamodel_enums.py +1 -0
  34. kiln_ai/datamodel/eval.py +1 -1
  35. kiln_ai/datamodel/external_tool_server.py +298 -0
  36. kiln_ai/datamodel/json_schema.py +25 -10
  37. kiln_ai/datamodel/project.py +8 -1
  38. kiln_ai/datamodel/registry.py +0 -15
  39. kiln_ai/datamodel/run_config.py +62 -0
  40. kiln_ai/datamodel/task.py +2 -77
  41. kiln_ai/datamodel/task_output.py +6 -1
  42. kiln_ai/datamodel/task_run.py +41 -0
  43. kiln_ai/datamodel/test_basemodel.py +3 -3
  44. kiln_ai/datamodel/test_example_models.py +175 -0
  45. kiln_ai/datamodel/test_external_tool_server.py +691 -0
  46. kiln_ai/datamodel/test_registry.py +8 -3
  47. kiln_ai/datamodel/test_task.py +15 -47
  48. kiln_ai/datamodel/test_tool_id.py +239 -0
  49. kiln_ai/datamodel/tool_id.py +83 -0
  50. kiln_ai/tools/__init__.py +8 -0
  51. kiln_ai/tools/base_tool.py +82 -0
  52. kiln_ai/tools/built_in_tools/__init__.py +13 -0
  53. kiln_ai/tools/built_in_tools/math_tools.py +124 -0
  54. kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
  55. kiln_ai/tools/mcp_server_tool.py +95 -0
  56. kiln_ai/tools/mcp_session_manager.py +243 -0
  57. kiln_ai/tools/test_base_tools.py +199 -0
  58. kiln_ai/tools/test_mcp_server_tool.py +457 -0
  59. kiln_ai/tools/test_mcp_session_manager.py +1585 -0
  60. kiln_ai/tools/test_tool_registry.py +473 -0
  61. kiln_ai/tools/tool_registry.py +64 -0
  62. kiln_ai/utils/config.py +22 -0
  63. kiln_ai/utils/open_ai_types.py +94 -0
  64. kiln_ai/utils/project_utils.py +17 -0
  65. kiln_ai/utils/test_config.py +138 -1
  66. kiln_ai/utils/test_open_ai_types.py +131 -0
  67. {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/METADATA +6 -5
  68. {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/RECORD +70 -47
  69. {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/WHEEL +0 -0
  70. {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/datamodel/external_tool_server.py ADDED
@@ -0,0 +1,298 @@
+ from enum import Enum
+ from typing import Any, Dict
+
+ from pydantic import Field, PrivateAttr, model_validator
+
+ from kiln_ai.datamodel.basemodel import (
+     FilenameString,
+     KilnParentedModel,
+ )
+ from kiln_ai.utils.config import MCP_SECRETS_KEY, Config
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+
+ class ToolServerType(str, Enum):
+     """
+     Enumeration of supported external tool server types.
+     """
+
+     remote_mcp = "remote_mcp"
+     local_mcp = "local_mcp"
+
+
+ class ExternalToolServer(KilnParentedModel):
+     """
+     Configuration for communicating with a external MCP (Model Context Protocol) Server for LLM tool calls. External tool servers can be remote or local.
+
+     This model stores the necessary configuration to connect to and authenticate with
+     external MCP servers that provide tools for LLM interactions.
+     """
+
+     name: FilenameString = Field(description="The name of the external tool.")
+     type: ToolServerType = Field(
+         description="The type of external tool server. Remote tools are hosted on a remote server",
+     )
+     description: str | None = Field(
+         default=None,
+         description="A description of the external tool for you and your team. Will not be used in prompts/training/validation.",
+     )
+     properties: Dict[str, Any] = Field(
+         default={},
+         description="Configuration properties specific to the tool type.",
+     )
+
+     # Private variable to store unsaved secrets
+     _unsaved_secrets: dict[str, str] = PrivateAttr(default_factory=dict)
+
+     def model_post_init(self, __context: Any) -> None:
+         # Process secrets after initialization (pydantic v2 hook)
+         self._process_secrets_from_properties()
+
+     def _process_secrets_from_properties(self) -> None:
+         """
+         Extract secrets from properties and move them to _unsaved_secrets.
+         This removes secrets from the properties dict so they aren't saved to file.
+         Clears existing _unsaved_secrets first to handle property updates correctly.
+         """
+         # Clear existing unsaved secrets since we're reprocessing
+         self._unsaved_secrets.clear()
+
+         secret_keys = self.get_secret_keys()
+
+         if not secret_keys:
+             return
+
+         # Extract secret values from properties based on server type
+         match self.type:
+             case ToolServerType.remote_mcp:
+                 headers = self.properties.get("headers", {})
+                 for key_name in secret_keys:
+                     if key_name in headers:
+                         self._unsaved_secrets[key_name] = headers[key_name]
+                         # Remove from headers immediately so they are not saved to file
+                         del headers[key_name]
+
+             case ToolServerType.local_mcp:
+                 env_vars = self.properties.get("env_vars", {})
+                 for key_name in secret_keys:
+                     if key_name in env_vars:
+                         self._unsaved_secrets[key_name] = env_vars[key_name]
+                         # Remove from env_vars immediately so they are not saved to file
+                         del env_vars[key_name]
+
+             case _:
+                 raise_exhaustive_enum_error(self.type)
+
+     def __setattr__(self, name: str, value: Any) -> None:
+         """
+         Override __setattr__ to process secrets whenever properties are updated.
+         """
+         super().__setattr__(name, value)
+
+         # Process secrets whenever properties are updated
+         if name == "properties":
+             self._process_secrets_from_properties()
+
+     @model_validator(mode="after")
+     def validate_required_fields(self) -> "ExternalToolServer":
+         """Validate that each tool type has the required configuration."""
+         match self.type:
+             case ToolServerType.remote_mcp:
+                 server_url = self.properties.get("server_url", None)
+                 if not isinstance(server_url, str):
+                     raise ValueError(
+                         "server_url must be a string for external tools of type 'remote_mcp'"
+                     )
+                 if not server_url:
+                     raise ValueError(
+                         "server_url is required to connect to a remote MCP server"
+                     )
+
+                 headers = self.properties.get("headers", None)
+                 if headers is None:
+                     raise ValueError("headers must be set when type is 'remote_mcp'")
+                 if not isinstance(headers, dict):
+                     raise ValueError(
+                         "headers must be a dictionary for external tools of type 'remote_mcp'"
+                     )
+
+                 secret_header_keys = self.properties.get("secret_header_keys", None)
+                 # Secret header keys are optional, but if they are set, they must be a list of strings
+                 if secret_header_keys is not None:
+                     if not isinstance(secret_header_keys, list):
+                         raise ValueError(
+                             "secret_header_keys must be a list for external tools of type 'remote_mcp'"
+                         )
+                     if not all(isinstance(k, str) for k in secret_header_keys):
+                         raise ValueError("secret_header_keys must contain only strings")
+
+             case ToolServerType.local_mcp:
+                 command = self.properties.get("command", None)
+                 if not isinstance(command, str):
+                     raise ValueError(
+                         "command must be a string to start a local MCP server"
+                     )
+                 if not command.strip():
+                     raise ValueError("command is required to start a local MCP server")
+
+                 args = self.properties.get("args", None)
+                 if not isinstance(args, list):
+                     raise ValueError(
+                         "arguments must be a list to start a local MCP server"
+                     )
+
+                 env_vars = self.properties.get("env_vars", {})
+                 if not isinstance(env_vars, dict):
+                     raise ValueError(
+                         "environment variables must be a dictionary for external tools of type 'local_mcp'"
+                     )
+
+                 secret_env_var_keys = self.properties.get("secret_env_var_keys", None)
+                 # Secret env var keys are optional, but if they are set, they must be a list of strings
+                 if secret_env_var_keys is not None:
+                     if not isinstance(secret_env_var_keys, list):
+                         raise ValueError(
+                             "secret_env_var_keys must be a list for external tools of type 'local_mcp'"
+                         )
+                     if not all(isinstance(k, str) for k in secret_env_var_keys):
+                         raise ValueError(
+                             "secret_env_var_keys must contain only strings"
+                         )
+
+             case _:
+                 # Type checking will catch missing cases
+                 raise_exhaustive_enum_error(self.type)
+         return self
+
+     def get_secret_keys(self) -> list[str]:
+         """
+         Get the list of secret key names based on server type.
+
+         Returns:
+             List of secret key names (header names for remote, env var names for local)
+         """
+         match self.type:
+             case ToolServerType.remote_mcp:
+                 return self.properties.get("secret_header_keys", [])
+             case ToolServerType.local_mcp:
+                 return self.properties.get("secret_env_var_keys", [])
+             case _:
+                 raise_exhaustive_enum_error(self.type)
+
+     def retrieve_secrets(self) -> tuple[dict[str, str], list[str]]:
+         """
+         Retrieve secrets from configuration system or in-memory storage.
+         Automatically determines which secret keys to retrieve based on the server type.
+         Config secrets take precedence over unsaved secrets.
+
+         Returns:
+             Tuple of (secrets_dict, missing_secrets_list) where:
+             - secrets_dict: Dictionary mapping key names to their secret values
+             - missing_secrets_list: List of secret key names that are missing values
+         """
+         secrets = {}
+         missing_secrets = []
+         secret_keys = self.get_secret_keys()
+
+         if secret_keys and len(secret_keys) > 0:
+             config = Config.shared()
+             mcp_secrets = config.get_value(MCP_SECRETS_KEY)
+
+             for key_name in secret_keys:
+                 secret_value = None
+
+                 # First check config secrets (persistent storage), key is mcp_server_id::key_name
+                 secret_key = self._config_secret_key(key_name)
+                 secret_value = mcp_secrets.get(secret_key) if mcp_secrets else None
+
+                 # Fall back to unsaved secrets (in-memory storage)
+                 if (
+                     not secret_value
+                     and hasattr(self, "_unsaved_secrets")
+                     and key_name in self._unsaved_secrets
+                 ):
+                     secret_value = self._unsaved_secrets[key_name]
+
+                 if secret_value:
+                     secrets[key_name] = secret_value
+                 else:
+                     missing_secrets.append(key_name)
+
+         return secrets, missing_secrets
+
+     def _save_secrets(self) -> None:
+         """
+         Save unsaved secrets to the configuration system.
+         """
+         secret_keys = self.get_secret_keys()
+
+         # No secrets to save
+         if not secret_keys:
+             return
+
+         if self.id is None:
+             raise ValueError("Server ID cannot be None when saving secrets")
+
+         # Check if secrets are already saved
+         if not hasattr(self, "_unsaved_secrets") or not self._unsaved_secrets:
+             return
+
+         config = Config.shared()
+         mcp_secrets: dict[str, str] = config.get_value(MCP_SECRETS_KEY) or {}
+
+         # Store secrets with the pattern: mcp_server_id::key_name
+         for key_name, secret_value in self._unsaved_secrets.items():
+             secret_key = self._config_secret_key(key_name)
+             mcp_secrets[secret_key] = secret_value
+
+         config.update_settings({MCP_SECRETS_KEY: mcp_secrets})
+
+         # Clear unsaved secrets after saving
+         self._unsaved_secrets.clear()
+
+     def delete_secrets(self) -> None:
+         """
+         Delete all secrets for this tool server from the configuration system.
+         """
+         secret_keys = self.get_secret_keys()
+
+         config = Config.shared()
+         mcp_secrets = config.get_value(MCP_SECRETS_KEY) or dict[str, str]()
+
+         # Remove secrets with the pattern: mcp_server_id::key_name
+         for key_name in secret_keys:
+             secret_key = self._config_secret_key(key_name)
+             if secret_key in mcp_secrets:
+                 del mcp_secrets[secret_key]
+
+         # Always call update_settings to maintain consistency with the old behavior
+         config.update_settings({MCP_SECRETS_KEY: mcp_secrets})
+
+     def save_to_file(self) -> None:
+         """
+         Override save_to_file to automatically save any unsaved secrets before saving to file.
+
+         This ensures that secrets are always saved when the object is saved,
+         preventing the issue where secrets could be lost if save_to_file is called
+         without explicitly saving secrets first.
+         """
+         # Save any unsaved secrets first
+         if hasattr(self, "_unsaved_secrets") and self._unsaved_secrets:
+             self._save_secrets()
+
+         # Call the parent save_to_file method
+         super().save_to_file()
+
+     # Internal helpers
+
+     def _config_secret_key(self, key_name: str) -> str:
+         """
+         Generate the secret key pattern for storing/retrieving secrets.
+
+         Args:
+             key_name: The name of the secret key
+
+         Returns:
+             The formatted secret key: "{server_id}::{key_name}"
+         """
+         return f"{self.id}::{key_name}"
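The net effect of the new model is that secret header/env-var values never land in the project file: they are stripped from `properties` at construction time, held in memory, and written to Kiln's config store (keyed as `{server_id}::{key_name}`) when the server is saved. A rough usage sketch follows, with purely illustrative values and assuming a normal Kiln config environment; it is not taken from the package itself.

from kiln_ai.datamodel.external_tool_server import ExternalToolServer, ToolServerType

# Hypothetical remote MCP server; "Authorization" is declared secret, so its value
# is removed from properties["headers"] during model_post_init and kept in memory only.
server = ExternalToolServer(
    name="example_server",
    type=ToolServerType.remote_mcp,
    properties={
        "server_url": "https://example.com/mcp",  # illustrative URL
        "headers": {"Authorization": "Bearer sk-example"},
        "secret_header_keys": ["Authorization"],
    },
)

assert "Authorization" not in server.properties["headers"]  # not persisted to disk
secrets, missing = server.retrieve_secrets()  # unsaved in-memory secrets are the fallback
print(secrets, missing)  # ({'Authorization': 'Bearer sk-example'}, [])
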

kiln_ai/datamodel/json_schema.py CHANGED
@@ -84,25 +84,40 @@ def schema_from_json_str(v: str) -> Dict:
      """
      try:
          parsed = json.loads(v)
-         jsonschema.Draft202012Validator.check_schema(parsed)
          if not isinstance(parsed, dict):
              raise ValueError(f"JSON schema must be a dict, not {type(parsed)}")
-         # Top level arrays are valid JSON schemas, but we don't want to allow them here as they often cause issues
-         if (
-             "type" not in parsed
-             or parsed["type"] != "object"
-             or "properties" not in parsed
-         ):
-             raise ValueError(f"JSON schema must be an object with properties: {v}")
+
+         validate_schema_dict(parsed)
          return parsed
-     except jsonschema.exceptions.SchemaError as e:
-         raise ValueError(f"Invalid JSON schema: {v} \n{e}")
      except json.JSONDecodeError as e:
          raise ValueError(f"Invalid JSON: {v}\n {e}")
      except Exception as e:
          raise ValueError(f"Unexpected error parsing JSON schema: {v}\n {e}")


+ def validate_schema_dict(v: Dict):
+     """Parse and validate a JSON schema dictionary.
+
+     Args:
+         v: Dictionary containing a JSON schema definition
+
+     Returns:
+         Dict containing the parsed JSON schema
+
+     Raises:
+         ValueError: If the input is not a valid JSON schema object with required properties
+     """
+     try:
+         jsonschema.Draft202012Validator.check_schema(v)
+         # Top level arrays are valid JSON schemas, but we don't want to allow them here as they often cause issues
+         if "type" not in v or v["type"] != "object" or "properties" not in v:
+             raise ValueError(f"JSON schema must be an object with properties: {v}")
+     except jsonschema.exceptions.SchemaError as e:
+         raise ValueError(f"Invalid JSON schema: {v} \n{e}")
+     except Exception as e:
+         raise ValueError(f"Unexpected error validating dict JSON schema: {v}\n {e}")
+
+
  def string_to_json_key(s: str) -> str:
      """Convert a string to a valid JSON key."""
      return re.sub(r"[^a-z0-9_]", "", s.strip().lower().replace(" ", "_"))
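The refactor splits the dict-level checks out of schema_from_json_str into validate_schema_dict, so callers that already hold a parsed dict can validate it without a round-trip through JSON. A small sketch of the resulting behavior (the schemas below are illustrative):

from kiln_ai.datamodel.json_schema import schema_from_json_str, validate_schema_dict

# Passes: a top-level object schema with properties.
validate_schema_dict({"type": "object", "properties": {"name": {"type": "string"}}})

# Raises ValueError: top-level arrays are legal JSON Schema but rejected here on purpose.
try:
    validate_schema_dict({"type": "array", "items": {"type": "string"}})
except ValueError as e:
    print(e)

# schema_from_json_str now delegates to validate_schema_dict after json.loads.
parsed = schema_from_json_str('{"type": "object", "properties": {}}')
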

kiln_ai/datamodel/project.py CHANGED
@@ -1,10 +1,14 @@
  from pydantic import Field

  from kiln_ai.datamodel.basemodel import FilenameString, KilnParentModel
+ from kiln_ai.datamodel.external_tool_server import ExternalToolServer
  from kiln_ai.datamodel.task import Task


- class Project(KilnParentModel, parent_of={"tasks": Task}):
+ class Project(
+     KilnParentModel,
+     parent_of={"tasks": Task, "external_tool_servers": ExternalToolServer},
+ ):
      """
      A collection of related tasks.

@@ -21,3 +25,6 @@ class Project(KilnParentModel, parent_of={"tasks": Task}):
      # Needed for typechecking. We should fix this in KilnParentModel
      def tasks(self) -> list[Task]:
          return super().tasks() # type: ignore
+
+     def external_tool_servers(self, readonly: bool = False) -> list[ExternalToolServer]:
+         return super().external_tool_servers(readonly=readonly) # type: ignore

kiln_ai/datamodel/registry.py CHANGED
@@ -14,18 +14,3 @@ def all_projects() -> list[Project]:
              # deleted files are possible continue with the rest
              continue
      return projects
-
-
- def project_from_id(project_id: str) -> Project | None:
-     project_paths = Config.shared().projects
-     if project_paths is not None:
-         for project_path in project_paths:
-             try:
-                 project = Project.load_from_file(project_path)
-                 if project.id == project_id:
-                     return project
-             except Exception:
-                 # deleted files are possible continue with the rest
-                 continue
-
-     return None

kiln_ai/datamodel/run_config.py ADDED
@@ -0,0 +1,62 @@
+ from typing import List
+
+ from pydantic import BaseModel, Field, model_validator
+ from typing_extensions import Self
+
+ from kiln_ai.datamodel.datamodel_enums import (
+     ModelProviderName,
+     StructuredOutputMode,
+ )
+ from kiln_ai.datamodel.prompt_id import PromptId
+ from kiln_ai.datamodel.tool_id import ToolId
+
+
+ class ToolsRunConfig(BaseModel):
+     """
+     A config describing which tools are available to a task.
+     """
+
+     tools: List[ToolId] = Field(
+         description="The IDs of the tools available to the task."
+     )
+
+
+ class RunConfigProperties(BaseModel):
+     """
+     A configuration for running a task.
+
+     This includes everything needed to run a task, except the input and task ID. Running the same RunConfig with the same input should make identical calls to the model (output may vary as models are non-deterministic).
+     """
+
+     model_name: str = Field(description="The model to use for this run config.")
+     model_provider_name: ModelProviderName = Field(
+         description="The provider to use for this run config."
+     )
+     prompt_id: PromptId = Field(
+         description="The prompt to use for this run config. Defaults to building a simple prompt from the task if not provided.",
+     )
+     top_p: float = Field(
+         default=1.0,
+         description="The top-p value to use for this run config. Defaults to 1.0.",
+     )
+     temperature: float = Field(
+         default=1.0,
+         description="The temperature to use for this run config. Defaults to 1.0.",
+     )
+     structured_output_mode: StructuredOutputMode = Field(
+         description="The structured output mode to use for this run config.",
+     )
+     tools_config: ToolsRunConfig | None = Field(
+         default=None,
+         description="The tools config to use for this run config, defining which tools are available to the model.",
+     )
+
+     @model_validator(mode="after")
+     def validate_required_fields(self) -> Self:
+         if not (0 <= self.top_p <= 1):
+             raise ValueError("top_p must be between 0 and 1")
+
+         elif self.temperature < 0 or self.temperature > 2:
+             raise ValueError("temperature must be between 0 and 2")
+
+         return self
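RunConfigProperties now lives in its own module and picks up an optional ToolsRunConfig. A minimal construction sketch; the model, provider, prompt, and structured-output values below are illustrative assumptions and must be valid ModelProviderName / PromptId / StructuredOutputMode values in your install:

from kiln_ai.datamodel.run_config import RunConfigProperties, ToolsRunConfig

run_config = RunConfigProperties(
    model_name="gpt_4o_mini",                # illustrative model name
    model_provider_name="openai",            # provider value used elsewhere in this diff
    prompt_id="simple_prompt_builder",       # prompt id used elsewhere in this diff
    structured_output_mode="json_schema",    # assumed StructuredOutputMode value
    temperature=0.7,                         # validator enforces 0 <= temperature <= 2
    top_p=0.9,                               # validator enforces 0 <= top_p <= 1
    tools_config=ToolsRunConfig(tools=[]),   # optional; lists ToolId values exposed to the model
)
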
kiln_ai/datamodel/task.py CHANGED
@@ -1,9 +1,7 @@
  from typing import TYPE_CHECKING, Dict, List, Union

  from pydantic import BaseModel, Field, ValidationInfo, model_validator
- from typing_extensions import Self

- from kiln_ai.datamodel import Finetune
  from kiln_ai.datamodel.basemodel import (
      ID_FIELD,
      ID_TYPE,
@@ -13,16 +11,16 @@ from kiln_ai.datamodel.basemodel import (
      KilnParentModel,
  )
  from kiln_ai.datamodel.datamodel_enums import (
-     ModelProviderName,
      Priority,
      StructuredOutputMode,
      TaskOutputRatingType,
  )
  from kiln_ai.datamodel.dataset_split import DatasetSplit
  from kiln_ai.datamodel.eval import Eval
+ from kiln_ai.datamodel.finetune import Finetune
  from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
  from kiln_ai.datamodel.prompt import BasePrompt, Prompt
- from kiln_ai.datamodel.prompt_id import PromptId
+ from kiln_ai.datamodel.run_config import RunConfigProperties
  from kiln_ai.datamodel.task_run import TaskRun

  if TYPE_CHECKING:
@@ -45,55 +43,6 @@ class TaskRequirement(BaseModel):
      type: TaskOutputRatingType = Field(default=TaskOutputRatingType.five_star)


- class RunConfigProperties(BaseModel):
-     """
-     A configuration for running a task.
-
-     This includes everything needed to run a task, except the input and task ID. Running the same RunConfig with the same input should make identical calls to the model (output may vary as models are non-deterministic).
-     """
-
-     model_name: str = Field(description="The model to use for this run config.")
-     model_provider_name: ModelProviderName = Field(
-         description="The provider to use for this run config."
-     )
-     prompt_id: PromptId = Field(
-         description="The prompt to use for this run config. Defaults to building a simple prompt from the task if not provided.",
-     )
-     top_p: float = Field(
-         default=1.0,
-         description="The top-p value to use for this run config. Defaults to 1.0.",
-     )
-     temperature: float = Field(
-         default=1.0,
-         description="The temperature to use for this run config. Defaults to 1.0.",
-     )
-     structured_output_mode: StructuredOutputMode = Field(
-         description="The structured output mode to use for this run config.",
-     )
-
-     @model_validator(mode="after")
-     def validate_required_fields(self) -> Self:
-         if not (0 <= self.top_p <= 1):
-             raise ValueError("top_p must be between 0 and 1")
-
-         elif self.temperature < 0 or self.temperature > 2:
-             raise ValueError("temperature must be between 0 and 2")
-
-         return self
-
-
- class RunConfig(RunConfigProperties):
-     """
-     A configuration for running a task.
-
-     This includes everything needed to run a task, except the input. Running the same RunConfig with the same input should make identical calls to the model (output may vary as models are non-deterministic).
-
-     For example: task, model, provider, prompt, etc.
-     """
-
-     task: "Task" = Field(description="The task to run.")
-
-
  class TaskRunConfig(KilnParentedModel):
      """
      A Kiln model for persisting a run config in a Kiln Project, nested under a task.
@@ -124,15 +73,6 @@ class TaskRunConfig(KilnParentedModel):
              return None
          return self.parent # type: ignore

-     def run_config(self) -> RunConfig:
-         parent_task = self.parent_task()
-         if parent_task is None:
-             raise ValueError("Run config must be parented to a task")
-         return run_config_from_run_config_properties(
-             task=parent_task,
-             run_config_properties=self.run_config_properties,
-         )
-
      # Previously we didn't store structured_output_mode in the run_config_properties. Updgrade old models when loading from file.
      @model_validator(mode="before")
      def upgrade_old_entries(cls, data: dict, info: ValidationInfo) -> dict:
@@ -155,21 +95,6 @@ class TaskRunConfig(KilnParentedModel):
          return data


- def run_config_from_run_config_properties(
-     task: "Task",
-     run_config_properties: RunConfigProperties,
- ) -> RunConfig:
-     return RunConfig(
-         task=task,
-         model_name=run_config_properties.model_name,
-         model_provider_name=run_config_properties.model_provider_name,
-         prompt_id=run_config_properties.prompt_id,
-         top_p=run_config_properties.top_p,
-         temperature=run_config_properties.temperature,
-         structured_output_mode=run_config_properties.structured_output_mode,
-     )
-
-
  class Task(
      KilnParentedModel,
      KilnParentModel,

kiln_ai/datamodel/task_output.py CHANGED
@@ -1,6 +1,6 @@
  import json
  from enum import Enum
- from typing import TYPE_CHECKING, Dict, List, Type, Union
+ from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union

  from pydantic import BaseModel, Field, ValidationInfo, model_validator
  from typing_extensions import Self
@@ -8,6 +8,7 @@ from typing_extensions import Self
  from kiln_ai.datamodel.basemodel import ID_TYPE, KilnBaseModel
  from kiln_ai.datamodel.datamodel_enums import TaskOutputRatingType
  from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
+ from kiln_ai.datamodel.run_config import RunConfigProperties
  from kiln_ai.datamodel.strict_mode import strict_mode
  from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

@@ -199,6 +200,10 @@ class DataSource(BaseModel):
          default={},
          description="Properties describing the data source. For synthetic things like model. For human, the human's name.",
      )
+     run_config: Optional[RunConfigProperties] = Field(
+         default=None,
+         description="The run config used to generate the data, if generated by a running a model in Kiln (only true for type=synthetic).",
+     )

      _data_source_properties = [
          DataSourceProperty(

kiln_ai/datamodel/task_run.py CHANGED
@@ -8,6 +8,7 @@ from kiln_ai.datamodel.basemodel import KilnParentedModel
  from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
  from kiln_ai.datamodel.strict_mode import strict_mode
  from kiln_ai.datamodel.task_output import DataSource, TaskOutput
+ from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam

  if TYPE_CHECKING:
      from kiln_ai.datamodel.task import Task
@@ -35,6 +36,42 @@ class Usage(BaseModel):
          ge=0,
      )

+     def __add__(self, other: "Usage") -> "Usage":
+         """Add two Usage objects together, handling None values gracefully.
+
+         None + None = None
+         None + value = value
+         value + None = value
+         value1 + value2 = value1 + value2
+         """
+         if not isinstance(other, Usage):
+             raise TypeError(f"Cannot add Usage with {type(other).__name__}")
+
+         def _add_optional_int(a: int | None, b: int | None) -> int | None:
+             if a is None and b is None:
+                 return None
+             if a is None:
+                 return b
+             if b is None:
+                 return a
+             return a + b
+
+         def _add_optional_float(a: float | None, b: float | None) -> float | None:
+             if a is None and b is None:
+                 return None
+             if a is None:
+                 return b
+             if b is None:
+                 return a
+             return a + b
+
+         return Usage(
+             input_tokens=_add_optional_int(self.input_tokens, other.input_tokens),
+             output_tokens=_add_optional_int(self.output_tokens, other.output_tokens),
+             total_tokens=_add_optional_int(self.total_tokens, other.total_tokens),
+             cost=_add_optional_float(self.cost, other.cost),
+         )
+

  class TaskRun(KilnParentedModel):
      """
@@ -72,6 +109,10 @@ class TaskRun(KilnParentedModel):
          default=None,
          description="Usage information for the task run. This includes the number of input tokens, output tokens, and total tokens used.",
      )
+     trace: list[ChatCompletionMessageParam] | None = Field(
+         default=None,
+         description="The trace of the task run in OpenAI format. This is the list of messages that were sent to/from the model.",
+     )

      def thinking_training_data(self) -> str | None:
          """

@@ -17,7 +17,7 @@ from kiln_ai.datamodel.basemodel import (
      string_to_valid_name,
  )
  from kiln_ai.datamodel.model_cache import ModelCache
- from kiln_ai.datamodel.task import RunConfig
+ from kiln_ai.datamodel.task import RunConfigProperties


  @pytest.fixture
@@ -552,8 +552,8 @@ def base_task():
  @pytest.fixture
  def adapter(base_task):
      return MockAdapter(
-         run_config=RunConfig(
-             task=base_task,
+         task=base_task,
+         run_config=RunConfigProperties(
              model_name="test_model",
              model_provider_name="openai",
              prompt_id="simple_prompt_builder",