kiln-ai 0.19.0__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of kiln-ai might be problematic.

Files changed (70)
  1. kiln_ai/adapters/__init__.py +2 -2
  2. kiln_ai/adapters/adapter_registry.py +19 -1
  3. kiln_ai/adapters/chat/chat_formatter.py +8 -12
  4. kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
  5. kiln_ai/adapters/docker_model_runner_tools.py +119 -0
  6. kiln_ai/adapters/eval/base_eval.py +2 -2
  7. kiln_ai/adapters/eval/eval_runner.py +3 -1
  8. kiln_ai/adapters/eval/g_eval.py +2 -2
  9. kiln_ai/adapters/eval/test_base_eval.py +1 -1
  10. kiln_ai/adapters/eval/test_g_eval.py +3 -4
  11. kiln_ai/adapters/fine_tune/__init__.py +1 -1
  12. kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
  14. kiln_ai/adapters/ml_model_list.py +380 -34
  15. kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
  16. kiln_ai/adapters/model_adapters/litellm_adapter.py +383 -79
  17. kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
  18. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +406 -1
  19. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
  20. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
  21. kiln_ai/adapters/model_adapters/test_structured_output.py +110 -4
  22. kiln_ai/adapters/parsers/__init__.py +1 -1
  23. kiln_ai/adapters/provider_tools.py +15 -1
  24. kiln_ai/adapters/repair/test_repair_task.py +12 -9
  25. kiln_ai/adapters/run_output.py +3 -0
  26. kiln_ai/adapters/test_adapter_registry.py +80 -1
  27. kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
  28. kiln_ai/adapters/test_ml_model_list.py +39 -1
  29. kiln_ai/adapters/test_prompt_adaptors.py +13 -6
  30. kiln_ai/adapters/test_provider_tools.py +55 -0
  31. kiln_ai/adapters/test_remote_config.py +98 -0
  32. kiln_ai/datamodel/__init__.py +23 -21
  33. kiln_ai/datamodel/datamodel_enums.py +1 -0
  34. kiln_ai/datamodel/eval.py +1 -1
  35. kiln_ai/datamodel/external_tool_server.py +298 -0
  36. kiln_ai/datamodel/json_schema.py +25 -10
  37. kiln_ai/datamodel/project.py +8 -1
  38. kiln_ai/datamodel/registry.py +0 -15
  39. kiln_ai/datamodel/run_config.py +62 -0
  40. kiln_ai/datamodel/task.py +2 -77
  41. kiln_ai/datamodel/task_output.py +6 -1
  42. kiln_ai/datamodel/task_run.py +41 -0
  43. kiln_ai/datamodel/test_basemodel.py +3 -3
  44. kiln_ai/datamodel/test_example_models.py +175 -0
  45. kiln_ai/datamodel/test_external_tool_server.py +691 -0
  46. kiln_ai/datamodel/test_registry.py +8 -3
  47. kiln_ai/datamodel/test_task.py +15 -47
  48. kiln_ai/datamodel/test_tool_id.py +239 -0
  49. kiln_ai/datamodel/tool_id.py +83 -0
  50. kiln_ai/tools/__init__.py +8 -0
  51. kiln_ai/tools/base_tool.py +82 -0
  52. kiln_ai/tools/built_in_tools/__init__.py +13 -0
  53. kiln_ai/tools/built_in_tools/math_tools.py +124 -0
  54. kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
  55. kiln_ai/tools/mcp_server_tool.py +95 -0
  56. kiln_ai/tools/mcp_session_manager.py +243 -0
  57. kiln_ai/tools/test_base_tools.py +199 -0
  58. kiln_ai/tools/test_mcp_server_tool.py +457 -0
  59. kiln_ai/tools/test_mcp_session_manager.py +1585 -0
  60. kiln_ai/tools/test_tool_registry.py +473 -0
  61. kiln_ai/tools/tool_registry.py +64 -0
  62. kiln_ai/utils/config.py +22 -0
  63. kiln_ai/utils/open_ai_types.py +94 -0
  64. kiln_ai/utils/project_utils.py +17 -0
  65. kiln_ai/utils/test_config.py +138 -1
  66. kiln_ai/utils/test_open_ai_types.py +131 -0
  67. {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/METADATA +6 -5
  68. {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/RECORD +70 -47
  69. {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/WHEEL +0 -0
  70. {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/test_registry.py
@@ -3,14 +3,19 @@ from unittest.mock import Mock, patch
 import pytest

 from kiln_ai.datamodel import Project
-from kiln_ai.datamodel.registry import all_projects, project_from_id
+from kiln_ai.datamodel.registry import all_projects
+from kiln_ai.utils.project_utils import project_from_id


 @pytest.fixture
 def mock_config():
-    with patch("kiln_ai.datamodel.registry.Config") as mock:
+    with (
+        patch("kiln_ai.datamodel.registry.Config") as mock_registry,
+        patch("kiln_ai.utils.project_utils.Config") as mock_utils,
+    ):
         config_instance = Mock()
-        mock.shared.return_value = config_instance
+        mock_registry.shared.return_value = config_instance
+        mock_utils.shared.return_value = config_instance
         yield config_instance

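The hunk above moves project_from_id out of kiln_ai.datamodel.registry and into kiln_ai.utils.project_utils. A minimal sketch of a caller migration, assuming only what this hunk shows (the function keeps its name and takes a project ID):

```python
# 0.19.0 import path:
# from kiln_ai.datamodel.registry import all_projects, project_from_id

# 0.20.1 import paths: project_from_id now lives in utils.project_utils.
from kiln_ai.datamodel.registry import all_projects
from kiln_ai.utils.project_utils import project_from_id

# Hypothetical usage; the argument and return types are assumed from the names.
project = project_from_id("my-project-id")
print(project, len(all_projects()))
```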
kiln_ai/datamodel/test_task.py
@@ -3,22 +3,18 @@ from pydantic import ValidationError

 from kiln_ai.datamodel.datamodel_enums import StructuredOutputMode, TaskOutputRatingType
 from kiln_ai.datamodel.prompt_id import PromptGenerators
-from kiln_ai.datamodel.task import RunConfig, RunConfigProperties, Task, TaskRunConfig
+from kiln_ai.datamodel.task import RunConfigProperties, Task, TaskRunConfig
 from kiln_ai.datamodel.task_output import normalize_rating


 def test_runconfig_valid_creation():
-    task = Task(id="task1", name="Test Task", instruction="Do something")
-
-    config = RunConfig(
-        task=task,
+    config = RunConfigProperties(
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE,
         structured_output_mode="json_schema",
     )

-    assert config.task == task
     assert config.model_name == "gpt-4"
     assert config.model_provider_name == "openai"
     assert config.prompt_id == PromptGenerators.SIMPLE  # Check default value
@@ -26,13 +22,12 @@ def test_runconfig_valid_creation():

 def test_runconfig_missing_required_fields():
     with pytest.raises(ValidationError) as exc_info:
-        RunConfig()
+        RunConfigProperties()  # type: ignore

     errors = exc_info.value.errors()
     assert (
-        len(errors) == 5
+        len(errors) == 4
     )  # task, model_name, model_provider_name, and prompt_id are required
-    assert any(error["loc"][0] == "task" for error in errors)
     assert any(error["loc"][0] == "model_name" for error in errors)
     assert any(error["loc"][0] == "model_provider_name" for error in errors)
     assert any(error["loc"][0] == "prompt_id" for error in errors)
@@ -40,10 +35,7 @@ def test_runconfig_missing_required_fields():


 def test_runconfig_custom_prompt_id():
-    task = Task(id="task1", name="Test Task", instruction="Do something")
-
-    config = RunConfig(
-        task=task,
+    config = RunConfigProperties(
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE_CHAIN_OF_THOUGHT,
@@ -100,30 +92,18 @@ def test_task_run_config_missing_required_fields(sample_task):
     with pytest.raises(ValidationError) as exc_info:
         TaskRunConfig(
             run_config_properties=RunConfigProperties(
-                task=sample_task, model_name="gpt-4", model_provider_name="openai"
-            ),
+                model_name="gpt-4", model_provider_name="openai"
+            ),  # type: ignore
             parent=sample_task,
-        )
+        )  # type: ignore
     assert "Field required" in str(exc_info.value)

     # Test missing run_config
     with pytest.raises(ValidationError) as exc_info:
-        TaskRunConfig(name="Test Config", parent=sample_task)
+        TaskRunConfig(name="Test Config", parent=sample_task)  # type: ignore
     assert "Field required" in str(exc_info.value)


-def test_task_run_config_missing_task_in_run_config(sample_task):
-    with pytest.raises(
-        ValidationError, match="Input should be a valid dictionary or instance of Task"
-    ):
-        # Create a run config without a task
-        RunConfig(
-            model_name="gpt-4",
-            model_provider_name="openai",
-            task=None,  # type: ignore
-        )
-
-
 @pytest.mark.parametrize(
     "rating_type,rating,expected",
     [
@@ -165,10 +145,8 @@ def test_normalize_rating_errors(rating_type, rating):

 def test_run_config_defaults():
     """RunConfig should require top_p, temperature, and structured_output_mode to be set."""
-    task = Task(id="task1", name="Test Task", instruction="Do something")

-    config = RunConfig(
-        task=task,
+    config = RunConfigProperties(
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE,
@@ -180,11 +158,9 @@ def test_run_config_defaults():

 def test_run_config_valid_ranges():
     """RunConfig should accept valid ranges for top_p and temperature."""
-    task = Task(id="task1", name="Test Task", instruction="Do something")

     # Test valid values
-    config = RunConfig(
-        task=task,
+    config = RunConfigProperties(
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE,
@@ -201,10 +177,8 @@ def test_run_config_valid_ranges():
 @pytest.mark.parametrize("top_p", [0.0, 0.5, 1.0])
 def test_run_config_valid_top_p(top_p):
     """Test that RunConfig accepts valid top_p values (0-1)."""
-    task = Task(id="task1", name="Test Task", instruction="Do something")

-    config = RunConfig(
-        task=task,
+    config = RunConfigProperties(
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE,
@@ -219,11 +193,9 @@ def test_run_config_valid_top_p(top_p):
 @pytest.mark.parametrize("top_p", [-0.1, 1.1, 2.0])
 def test_run_config_invalid_top_p(top_p):
     """Test that RunConfig rejects invalid top_p values."""
-    task = Task(id="task1", name="Test Task", instruction="Do something")

     with pytest.raises(ValueError, match="top_p must be between 0 and 1"):
-        RunConfig(
-            task=task,
+        RunConfigProperties(
             model_name="gpt-4",
             model_provider_name="openai",
             prompt_id=PromptGenerators.SIMPLE,
@@ -236,10 +208,8 @@ def test_run_config_invalid_top_p(top_p):
 @pytest.mark.parametrize("temperature", [0.0, 1.0, 2.0])
 def test_run_config_valid_temperature(temperature):
     """Test that RunConfig accepts valid temperature values (0-2)."""
-    task = Task(id="task1", name="Test Task", instruction="Do something")

-    config = RunConfig(
-        task=task,
+    config = RunConfigProperties(
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE,
@@ -254,11 +224,9 @@ def test_run_config_valid_temperature(temperature):
 @pytest.mark.parametrize("temperature", [-0.1, 2.1, 3.0])
 def test_run_config_invalid_temperature(temperature):
     """Test that RunConfig rejects invalid temperature values."""
-    task = Task(id="task1", name="Test Task", instruction="Do something")

     with pytest.raises(ValueError, match="temperature must be between 0 and 2"):
-        RunConfig(
-            task=task,
+        RunConfigProperties(
             model_name="gpt-4",
             model_provider_name="openai",
             prompt_id=PromptGenerators.SIMPLE,
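All of the test_task.py hunks above track one API change: RunConfig, which required a Task, is replaced by a task-free RunConfigProperties. A minimal sketch of the new construction, using only fields exercised by these tests (temperature and top_p are optional only insofar as the tests construct without them; their valid ranges are enforced as shown above):

```python
from kiln_ai.datamodel.prompt_id import PromptGenerators
from kiln_ai.datamodel.task import RunConfigProperties

# No `task` argument anymore; the config carries only run settings.
config = RunConfigProperties(
    model_name="gpt-4",
    model_provider_name="openai",
    prompt_id=PromptGenerators.SIMPLE,
    structured_output_mode="json_schema",
    temperature=1.0,  # must be between 0 and 2
    top_p=1.0,  # must be between 0 and 1
)
```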
kiln_ai/datamodel/test_tool_id.py
@@ -0,0 +1,239 @@
+import pytest
+from pydantic import BaseModel, ValidationError
+
+from kiln_ai.datamodel.tool_id import (
+    MCP_LOCAL_TOOL_ID_PREFIX,
+    MCP_REMOTE_TOOL_ID_PREFIX,
+    KilnBuiltInToolId,
+    ToolId,
+    _check_tool_id,
+    mcp_server_and_tool_name_from_id,
+)
+
+
+class TestKilnBuiltInToolId:
+    """Test the KilnBuiltInToolId enum."""
+
+    def test_enum_values(self):
+        """Test that enum has expected values."""
+        assert KilnBuiltInToolId.ADD_NUMBERS == "kiln_tool::add_numbers"
+        assert KilnBuiltInToolId.SUBTRACT_NUMBERS == "kiln_tool::subtract_numbers"
+        assert KilnBuiltInToolId.MULTIPLY_NUMBERS == "kiln_tool::multiply_numbers"
+        assert KilnBuiltInToolId.DIVIDE_NUMBERS == "kiln_tool::divide_numbers"
+        for enum_value in KilnBuiltInToolId.__members__.values():
+            assert _check_tool_id(enum_value) == enum_value
+
+    def test_enum_membership(self):
+        """Test enum membership checks."""
+        assert "kiln_tool::add_numbers" in KilnBuiltInToolId.__members__.values()
+        assert "invalid_tool" not in KilnBuiltInToolId.__members__.values()
+
+
+class TestCheckToolId:
+    """Test the _check_tool_id validation function."""
+
+    def test_valid_builtin_tools(self):
+        """Test validation of valid built-in tools."""
+        for tool_id in KilnBuiltInToolId:
+            result = _check_tool_id(tool_id.value)
+            assert result == tool_id.value
+
+    def test_valid_mcp_remote_tools(self):
+        """Test validation of valid MCP remote tools."""
+        valid_ids = [
+            "mcp::remote::server1::tool1",
+            "mcp::remote::my_server::my_tool",
+            "mcp::remote::test::function_name",
+        ]
+        for tool_id in valid_ids:
+            result = _check_tool_id(tool_id)
+            assert result == tool_id
+
+    def test_valid_mcp_local_tools(self):
+        """Test validation of valid MCP local tools."""
+        valid_ids = [
+            "mcp::local::server1::tool1",
+            "mcp::local::my_server::my_tool",
+            "mcp::local::test::function_name",
+        ]
+        for tool_id in valid_ids:
+            result = _check_tool_id(tool_id)
+            assert result == tool_id
+
+    def test_invalid_empty_or_none(self):
+        """Test validation fails for empty or None values."""
+        with pytest.raises(ValueError, match="Invalid tool ID"):
+            _check_tool_id("")
+
+        with pytest.raises(ValueError, match="Invalid tool ID"):
+            _check_tool_id(None)  # type: ignore
+
+    def test_invalid_non_string(self):
+        """Test validation fails for non-string values."""
+        with pytest.raises(ValueError, match="Invalid tool ID"):
+            _check_tool_id(123)  # type: ignore
+
+        with pytest.raises(ValueError, match="Invalid tool ID"):
+            _check_tool_id(["tool"])  # type: ignore
+
+    def test_invalid_unknown_tool(self):
+        """Test validation fails for unknown tool IDs."""
+        with pytest.raises(ValueError, match="Invalid tool ID: unknown_tool"):
+            _check_tool_id("unknown_tool")
+
+    def test_invalid_mcp_format(self):
+        """Test validation fails for invalid MCP tool formats."""
+        # These IDs start with the MCP remote prefix but have invalid formats
+        mcp_remote_invalid_ids = [
+            "mcp::remote::",  # Missing server and tool
+            "mcp::remote::server",  # Missing tool
+            "mcp::remote::server::",  # Empty tool name
+            "mcp::remote::::tool",  # Empty server name
+            "mcp::remote::server::tool::extra",  # Too many parts
+        ]
+
+        for invalid_id in mcp_remote_invalid_ids:
+            with pytest.raises(ValueError, match="Invalid remote MCP tool ID"):
+                _check_tool_id(invalid_id)
+
+        # These IDs start with the MCP local prefix but have invalid formats
+        mcp_local_invalid_ids = [
+            "mcp::local::",  # Missing server and tool
+            "mcp::local::server",  # Missing tool
+            "mcp::local::server::",  # Empty tool name
+            "mcp::local::::tool",  # Empty server name
+            "mcp::local::server::tool::extra",  # Too many parts
+        ]
+
+        for invalid_id in mcp_local_invalid_ids:
+            with pytest.raises(ValueError, match="Invalid local MCP tool ID"):
+                _check_tool_id(invalid_id)
+
+        # This ID doesn't start with MCP prefix so gets generic error
+        with pytest.raises(ValueError, match="Invalid tool ID"):
+            _check_tool_id("mcp::wrong::server::tool")
+
+
+class TestMcpServerAndToolNameFromId:
+    """Test the mcp_server_and_tool_name_from_id function."""
+
+    def test_valid_mcp_ids(self):
+        """Test parsing valid MCP tool IDs."""
+        test_cases = [
+            # Remote MCP tools
+            ("mcp::remote::server1::tool1", ("server1", "tool1")),
+            ("mcp::remote::my_server::my_tool", ("my_server", "my_tool")),
+            ("mcp::remote::test::function_name", ("test", "function_name")),
+            # Local MCP tools
+            ("mcp::local::server1::tool1", ("server1", "tool1")),
+            ("mcp::local::my_server::my_tool", ("my_server", "my_tool")),
+            ("mcp::local::test::function_name", ("test", "function_name")),
+        ]
+
+        for tool_id, expected in test_cases:
+            result = mcp_server_and_tool_name_from_id(tool_id)
+            assert result == expected
+
+    def test_invalid_mcp_ids(self):
+        """Test parsing fails for invalid MCP tool IDs."""
+        # Test remote MCP tool ID errors
+        remote_invalid_ids = [
+            "mcp::remote::",  # Only 3 parts
+            "mcp::remote::server",  # Only 3 parts
+            "mcp::remote::server::tool::extra",  # 5 parts
+        ]
+
+        for invalid_id in remote_invalid_ids:
+            with pytest.raises(ValueError, match="Invalid remote MCP tool ID"):
+                mcp_server_and_tool_name_from_id(invalid_id)
+
+        # Test local MCP tool ID errors
+        local_invalid_ids = [
+            "mcp::local::",  # Only 3 parts
+            "mcp::local::server",  # Only 3 parts
+            "mcp::local::server::tool::extra",  # 5 parts
+        ]
+
+        for invalid_id in local_invalid_ids:
+            with pytest.raises(ValueError, match="Invalid local MCP tool ID"):
+                mcp_server_and_tool_name_from_id(invalid_id)
+
+        # Test generic MCP tool ID errors (not remote or local)
+        generic_invalid_ids = [
+            "not_mcp_format",  # Only 1 part
+            "single_part",  # Only 1 part
+            "",  # Empty string
+        ]
+
+        for invalid_id in generic_invalid_ids:
+            with pytest.raises(ValueError, match="Invalid MCP tool ID"):
+                mcp_server_and_tool_name_from_id(invalid_id)
+
+    def test_mcp_ids_with_wrong_prefix_still_parse(self):
+        """Test that IDs with wrong prefix but correct structure still parse (validation happens elsewhere)."""
+        # This function only checks structure (4 parts), not content
+        result = mcp_server_and_tool_name_from_id("mcp::wrong::server::tool")
+        assert result == ("server", "tool")
+
+
+class TestToolIdPydanticType:
+    """Test the ToolId pydantic type annotation."""
+
+    class _ModelWithToolId(BaseModel):
+        tool_id: ToolId
+
+    def test_valid_builtin_tools(self):
+        """Test ToolId validates built-in tools."""
+        for tool_id in KilnBuiltInToolId:
+            model = self._ModelWithToolId(tool_id=tool_id.value)
+            assert model.tool_id == tool_id.value
+
+    def test_valid_mcp_tools(self):
+        """Test ToolId validates MCP remote and local tools."""
+        valid_ids = [
+            # Remote MCP tools
+            "mcp::remote::server1::tool1",
+            "mcp::remote::my_server::my_tool",
+            # Local MCP tools
+            "mcp::local::server1::tool1",
+            "mcp::local::my_server::my_tool",
+        ]
+
+        for tool_id in valid_ids:
+            model = self._ModelWithToolId(tool_id=tool_id)
+            assert model.tool_id == tool_id
+
+    def test_invalid_tools_raise_validation_error(self):
+        """Test ToolId raises ValidationError for invalid tools."""
+        invalid_ids = [
+            "",
+            "unknown_tool",
+            "mcp::remote::",
+            "mcp::remote::server",
+            "mcp::local::",
+            "mcp::local::server",
+        ]
+
+        for invalid_id in invalid_ids:
+            with pytest.raises(ValidationError):
+                self._ModelWithToolId(tool_id=invalid_id)
+
+    def test_non_string_raises_validation_error(self):
+        """Test ToolId raises ValidationError for non-string values."""
+        with pytest.raises(ValidationError):
+            self._ModelWithToolId(tool_id=123)  # type: ignore
+
+        with pytest.raises(ValidationError):
+            self._ModelWithToolId(tool_id=None)  # type: ignore
+
+
+class TestConstants:
+    """Test module constants."""
+
+    def test_mcp_remote_tool_id_prefix(self):
+        """Test the MCP remote tool ID prefix constant."""
+        assert MCP_REMOTE_TOOL_ID_PREFIX == "mcp::remote::"
+
+    def test_mcp_local_tool_id_prefix(self):
+        """Test the MCP local tool ID prefix constant."""
+        assert MCP_LOCAL_TOOL_ID_PREFIX == "mcp::local::"
kiln_ai/datamodel/tool_id.py
@@ -0,0 +1,83 @@
+from enum import Enum
+from typing import Annotated
+
+from pydantic import AfterValidator
+
+ToolId = Annotated[
+    str,
+    AfterValidator(lambda v: _check_tool_id(v)),
+]
+"""
+A pydantic type that validates strings containing a valid tool ID.
+
+Tool IDs can be one of:
+- A kiln built-in tool name: kiln_tool::add_numbers
+- A remote MCP tool: mcp::remote::<server_id>::<tool_name>
+- A local MCP tool: mcp::local::<server_id>::<tool_name>
+- More coming soon like kiln_project_tool::rag::RAG_CONFIG_ID
+"""
+
+
+class KilnBuiltInToolId(str, Enum):
+    ADD_NUMBERS = "kiln_tool::add_numbers"
+    SUBTRACT_NUMBERS = "kiln_tool::subtract_numbers"
+    MULTIPLY_NUMBERS = "kiln_tool::multiply_numbers"
+    DIVIDE_NUMBERS = "kiln_tool::divide_numbers"
+
+
+MCP_REMOTE_TOOL_ID_PREFIX = "mcp::remote::"
+MCP_LOCAL_TOOL_ID_PREFIX = "mcp::local::"
+
+
+def _check_tool_id(id: str) -> str:
+    """
+    Check that the tool ID is valid.
+    """
+    if not id or not isinstance(id, str):
+        raise ValueError(f"Invalid tool ID: {id}")
+
+    # Build in tools
+    if id in KilnBuiltInToolId.__members__.values():
+        return id
+
+    # MCP remote tools must have format: mcp::remote::<server_id>::<tool_name>
+    if id.startswith(MCP_REMOTE_TOOL_ID_PREFIX):
+        server_id, tool_name = mcp_server_and_tool_name_from_id(id)
+        if not server_id or not tool_name:
+            raise ValueError(
+                f"Invalid remote MCP tool ID: {id}. Expected format: 'mcp::remote::<server_id>::<tool_name>'."
+            )
+        return id
+
+    # MCP local tools must have format: mcp::local::<server_id>::<tool_name>
+    if id.startswith(MCP_LOCAL_TOOL_ID_PREFIX):
+        server_id, tool_name = mcp_server_and_tool_name_from_id(id)
+        if not server_id or not tool_name:
+            raise ValueError(
+                f"Invalid local MCP tool ID: {id}. Expected format: 'mcp::local::<server_id>::<tool_name>'."
+            )
+        return id
+
+    raise ValueError(f"Invalid tool ID: {id}")
+
+
+def mcp_server_and_tool_name_from_id(id: str) -> tuple[str, str]:
+    """
+    Get the tool server ID and tool name from the ID.
+    """
+    parts = id.split("::")
+    if len(parts) != 4:
+        # Determine if it's remote or local for the error message
+        if id.startswith(MCP_REMOTE_TOOL_ID_PREFIX):
+            raise ValueError(
+                f"Invalid remote MCP tool ID: {id}. Expected format: 'mcp::remote::<server_id>::<tool_name>'."
+            )
+        elif id.startswith(MCP_LOCAL_TOOL_ID_PREFIX):
+            raise ValueError(
+                f"Invalid local MCP tool ID: {id}. Expected format: 'mcp::local::<server_id>::<tool_name>'."
+            )
+        else:
+            raise ValueError(
+                f"Invalid MCP tool ID: {id}. Expected format: 'mcp::(remote|local)::<server_id>::<tool_name>'."
+            )
+    return parts[2], parts[3]  # server_id, tool_name
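tool_id.py above is self-contained, so its intended use follows directly from the source: ToolId is an Annotated str whose AfterValidator runs _check_tool_id on any pydantic field typed with it. A short sketch (ToolRef is a hypothetical model, not part of the package):

```python
from pydantic import BaseModel, ValidationError

from kiln_ai.datamodel.tool_id import ToolId, mcp_server_and_tool_name_from_id


class ToolRef(BaseModel):  # hypothetical model for illustration
    tool_id: ToolId


ToolRef(tool_id="kiln_tool::add_numbers")  # valid built-in tool ID
ToolRef(tool_id="mcp::remote::srv::search")  # valid remote MCP tool ID

try:
    ToolRef(tool_id="unknown_tool")
except ValidationError:
    pass  # _check_tool_id rejects IDs that match no known scheme

# Split a structurally valid MCP ID into (server_id, tool_name):
assert mcp_server_and_tool_name_from_id("mcp::local::srv::add") == ("srv", "add")
```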
kiln_ai/tools/__init__.py
@@ -0,0 +1,8 @@
+from kiln_ai.tools.base_tool import KilnTool, KilnToolInterface
+from kiln_ai.tools.tool_registry import tool_from_id
+
+__all__ = [
+    "KilnTool",
+    "KilnToolInterface",
+    "tool_from_id",
+]
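tool_registry.py itself is not reproduced in this excerpt (see entry 61 in the file list), so the following is only a sketch of how the exported tool_from_id plausibly connects IDs to tool instances; the return type and the run() keyword names are assumptions:

```python
import asyncio

from kiln_ai.tools import tool_from_id


async def main() -> None:
    # Assumed: tool_from_id maps a validated tool ID to an object
    # implementing KilnToolInterface (run, name, toolcall_definition, ...).
    tool = tool_from_id("kiln_tool::add_numbers")
    print(await tool.toolcall_definition())
    print(await tool.run(a=2, b=3))  # parameter names are assumptions


asyncio.run(main())
```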
kiln_ai/tools/base_tool.py
@@ -0,0 +1,82 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+
+from kiln_ai.datamodel.json_schema import validate_schema_dict
+from kiln_ai.datamodel.tool_id import KilnBuiltInToolId, ToolId
+
+
+class KilnToolInterface(ABC):
+    """
+    Abstract interface defining the core API that all Kiln tools must implement.
+    This ensures consistency across all tool implementations.
+    """
+
+    @abstractmethod
+    async def run(self, **kwargs) -> Any:
+        """Execute the tool with the given parameters."""
+        pass
+
+    @abstractmethod
+    async def toolcall_definition(self) -> Dict[str, Any]:
+        """Return the OpenAI-compatible tool definition for this tool."""
+        pass
+
+    @abstractmethod
+    async def id(self) -> ToolId:
+        """Return a unique identifier for this tool."""
+        pass
+
+    @abstractmethod
+    async def name(self) -> str:
+        """Return the tool name (function name) of this tool."""
+        pass
+
+    @abstractmethod
+    async def description(self) -> str:
+        """Return a description of what this tool does."""
+        pass
+
+
+class KilnTool(KilnToolInterface):
+    """
+    Base helper class that provides common functionality for tool implementations.
+    Subclasses only need to implement run() and provide tool configuration.
+    """
+
+    def __init__(
+        self,
+        tool_id: KilnBuiltInToolId,
+        name: str,
+        description: str,
+        parameters_schema: Dict[str, Any],
+    ):
+        self._id = tool_id
+        self._name = name
+        self._description = description
+        validate_schema_dict(parameters_schema)
+        self._parameters_schema = parameters_schema
+
+    async def id(self) -> KilnBuiltInToolId:
+        return self._id
+
+    async def name(self) -> str:
+        return self._name
+
+    async def description(self) -> str:
+        return self._description
+
+    async def toolcall_definition(self) -> Dict[str, Any]:
+        """Generate OpenAI-compatible tool definition."""
+        return {
+            "type": "function",
+            "function": {
+                "name": await self.name(),
+                "description": await self.description(),
+                "parameters": self._parameters_schema,
+            },
+        }
+
+    @abstractmethod
+    async def run(self, **kwargs) -> str:
+        """Subclasses must implement the actual tool logic."""
+        pass
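Given base_tool.py above, a subclass only supplies its configuration and a run() implementation. A minimal illustrative tool follows; it is not the shipped AddTool from math_tools.py (whose source is not in this excerpt), and the schema shape assumes validate_schema_dict accepts standard JSON Schema:

```python
from kiln_ai.datamodel.tool_id import KilnBuiltInToolId
from kiln_ai.tools.base_tool import KilnTool


class ExampleAddTool(KilnTool):
    """Illustrative subclass; the package's own AddTool may differ."""

    def __init__(self):
        super().__init__(
            tool_id=KilnBuiltInToolId.ADD_NUMBERS,
            name="add_numbers",
            description="Add two numbers together.",
            parameters_schema={
                "type": "object",
                "properties": {
                    "a": {"type": "number"},
                    "b": {"type": "number"},
                },
                "required": ["a", "b"],
            },
        )

    async def run(self, a: float, b: float) -> str:
        # KilnTool.run is declared to return str, so stringify the sum.
        return str(a + b)
```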
kiln_ai/tools/built_in_tools/__init__.py
@@ -0,0 +1,13 @@
+from kiln_ai.tools.built_in_tools.math_tools import (
+    AddTool,
+    DivideTool,
+    MultiplyTool,
+    SubtractTool,
+)
+
+__all__ = [
+    "AddTool",
+    "DivideTool",
+    "MultiplyTool",
+    "SubtractTool",
+]
+ ]