kiln-ai 0.18.0__py3-none-any.whl → 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +2 -2
- kiln_ai/adapters/adapter_registry.py +46 -0
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/data_gen/data_gen_task.py +2 -2
- kiln_ai/adapters/data_gen/test_data_gen_task.py +7 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +3 -1
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -1
- kiln_ai/adapters/eval/test_eval_runner.py +6 -12
- kiln_ai/adapters/eval/test_g_eval.py +3 -4
- kiln_ai/adapters/eval/test_g_eval_data.py +1 -1
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/base_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +32 -20
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +30 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/ml_model_list.py +1009 -111
- kiln_ai/adapters/model_adapters/base_adapter.py +62 -28
- kiln_ai/adapters/model_adapters/litellm_adapter.py +397 -80
- kiln_ai/adapters/model_adapters/test_base_adapter.py +194 -18
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +428 -4
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +120 -14
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/parsers/test_r1_parser.py +1 -1
- kiln_ai/adapters/provider_tools.py +35 -20
- kiln_ai/adapters/remote_config.py +57 -10
- kiln_ai/adapters/repair/repair_task.py +1 -1
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +109 -2
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_model_list.py +51 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_provider_tools.py +73 -12
- kiln_ai/adapters/test_remote_config.py +470 -16
- kiln_ai/datamodel/__init__.py +23 -21
- kiln_ai/datamodel/basemodel.py +54 -28
- kiln_ai/datamodel/datamodel_enums.py +3 -0
- kiln_ai/datamodel/dataset_split.py +5 -3
- kiln_ai/datamodel/eval.py +4 -4
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/finetune.py +2 -2
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +11 -4
- kiln_ai/datamodel/prompt.py +2 -2
- kiln_ai/datamodel/prompt_id.py +4 -4
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +8 -83
- kiln_ai/datamodel/task_output.py +7 -2
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_basemodel.py +213 -21
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_model_perf.py +1 -1
- kiln_ai/datamodel/test_prompt_id.py +5 -1
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +20 -47
- kiln_ai/datamodel/test_tool_id.py +239 -0
- kiln_ai/datamodel/tool_id.py +83 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +243 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_tool_registry.py +473 -0
- kiln_ai/tools/tool_registry.py +64 -0
- kiln_ai/utils/config.py +32 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_open_ai_types.py +131 -0
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/METADATA +37 -6
- kiln_ai-0.20.1.dist-info/RECORD +138 -0
- kiln_ai-0.18.0.dist-info/RECORD +0 -115
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/test_prompt_id.py
CHANGED
@@ -29,7 +29,7 @@ def test_valid_saved_prompt_id():
 
 def test_valid_fine_tune_prompt_id():
     """Test that valid fine-tune prompt IDs are accepted"""
-    valid_id = "fine_tune_prompt::ft_123456"
+    valid_id = "fine_tune_prompt::project_123::task_456::ft_123456"
     model = ModelTester(prompt_id=valid_id)
     assert model.prompt_id == valid_id
 
@@ -53,6 +53,10 @@ def test_invalid_saved_prompt_id_format(invalid_id):
     [
         ("fine_tune_prompt::", "Invalid fine-tune prompt ID: fine_tune_prompt::"),
         ("fine_tune_prompt", "Invalid prompt ID: fine_tune_prompt"),
+        (
+            "fine_tune_prompt::ft_123456",
+            "Invalid fine-tune prompt ID: fine_tune_prompt::ft_123456",
+        ),
     ],
 )
 def test_invalid_fine_tune_prompt_id_format(invalid_id, expected_error):
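Note on the change above: fine-tune prompt IDs are now fully scoped as fine_tune_prompt::<project_id>::<task_id>::<fine_tune_id>, and the old two-part form is rejected. A minimal sketch of the new behaviour, assuming a pydantic model with a PromptId field like the ModelTester helper used in these tests (ExampleModel below is hypothetical and the PromptId import path is an assumption):

```python
# Hypothetical sketch mirroring the tests above; ExampleModel is not part of kiln-ai.
from pydantic import BaseModel, ValidationError

from kiln_ai.datamodel.prompt_id import PromptId  # assumed import path


class ExampleModel(BaseModel):
    prompt_id: PromptId


# Fully scoped fine-tune prompt ID: accepted in 0.20.x.
ExampleModel(prompt_id="fine_tune_prompt::project_123::task_456::ft_123456")

# Old two-part form: now rejected ("Invalid fine-tune prompt ID: ...").
try:
    ExampleModel(prompt_id="fine_tune_prompt::ft_123456")
except ValidationError as exc:
    print(exc)
```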
kiln_ai/datamodel/test_registry.py
CHANGED
@@ -3,14 +3,19 @@ from unittest.mock import Mock, patch
 import pytest
 
 from kiln_ai.datamodel import Project
-from kiln_ai.datamodel.registry import all_projects
+from kiln_ai.datamodel.registry import all_projects
+from kiln_ai.utils.project_utils import project_from_id
 
 
 @pytest.fixture
 def mock_config():
-    with
+    with (
+        patch("kiln_ai.datamodel.registry.Config") as mock_registry,
+        patch("kiln_ai.utils.project_utils.Config") as mock_utils,
+    ):
         config_instance = Mock()
-
+        mock_registry.shared.return_value = config_instance
+        mock_utils.shared.return_value = config_instance
        yield config_instance
 
 
kiln_ai/datamodel/test_task.py
CHANGED
@@ -3,22 +3,18 @@ from pydantic import ValidationError
 
 from kiln_ai.datamodel.datamodel_enums import StructuredOutputMode, TaskOutputRatingType
 from kiln_ai.datamodel.prompt_id import PromptGenerators
-from kiln_ai.datamodel.task import
+from kiln_ai.datamodel.task import RunConfigProperties, Task, TaskRunConfig
 from kiln_ai.datamodel.task_output import normalize_rating
 
 
 def test_runconfig_valid_creation():
-
-
-    config = RunConfig(
-        task=task,
+    config = RunConfigProperties(
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE,
         structured_output_mode="json_schema",
     )
 
-    assert config.task == task
     assert config.model_name == "gpt-4"
     assert config.model_provider_name == "openai"
     assert config.prompt_id == PromptGenerators.SIMPLE  # Check default value
@@ -26,13 +22,12 @@ def test_runconfig_valid_creation():
 
 def test_runconfig_missing_required_fields():
     with pytest.raises(ValidationError) as exc_info:
-
+        RunConfigProperties()  # type: ignore
 
     errors = exc_info.value.errors()
     assert (
-        len(errors) ==
+        len(errors) == 4
     )  # task, model_name, model_provider_name, and prompt_id are required
-    assert any(error["loc"][0] == "task" for error in errors)
     assert any(error["loc"][0] == "model_name" for error in errors)
     assert any(error["loc"][0] == "model_provider_name" for error in errors)
     assert any(error["loc"][0] == "prompt_id" for error in errors)
@@ -40,10 +35,7 @@ def test_runconfig_missing_required_fields():
 
 
 def test_runconfig_custom_prompt_id():
-
-
-    config = RunConfig(
-        task=task,
+    config = RunConfigProperties(
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE_CHAIN_OF_THOUGHT,
@@ -100,30 +92,18 @@ def test_task_run_config_missing_required_fields(sample_task):
     with pytest.raises(ValidationError) as exc_info:
         TaskRunConfig(
             run_config_properties=RunConfigProperties(
-
-            ),
+                model_name="gpt-4", model_provider_name="openai"
+            ),  # type: ignore
             parent=sample_task,
-        )
+        )  # type: ignore
     assert "Field required" in str(exc_info.value)
 
     # Test missing run_config
     with pytest.raises(ValidationError) as exc_info:
-        TaskRunConfig(name="Test Config", parent=sample_task)
+        TaskRunConfig(name="Test Config", parent=sample_task)  # type: ignore
     assert "Field required" in str(exc_info.value)
 
 
-def test_task_run_config_missing_task_in_run_config(sample_task):
-    with pytest.raises(
-        ValidationError, match="Input should be a valid dictionary or instance of Task"
-    ):
-        # Create a run config without a task
-        RunConfig(
-            model_name="gpt-4",
-            model_provider_name="openai",
-            task=None,  # type: ignore
-        )
-
-
 @pytest.mark.parametrize(
     "rating_type,rating,expected",
     [
@@ -165,10 +145,8 @@ def test_normalize_rating_errors(rating_type, rating):
 
 def test_run_config_defaults():
     """RunConfig should require top_p, temperature, and structured_output_mode to be set."""
-    task = Task(id="task1", name="Test Task", instruction="Do something")
 
-    config =
-        task=task,
+    config = RunConfigProperties(
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE,
@@ -180,11 +158,9 @@ def test_run_config_defaults():
 
 def test_run_config_valid_ranges():
     """RunConfig should accept valid ranges for top_p and temperature."""
-    task = Task(id="task1", name="Test Task", instruction="Do something")
 
     # Test valid values
-    config =
-        task=task,
+    config = RunConfigProperties(
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE,
@@ -201,10 +177,8 @@ def test_run_config_valid_ranges():
 @pytest.mark.parametrize("top_p", [0.0, 0.5, 1.0])
 def test_run_config_valid_top_p(top_p):
     """Test that RunConfig accepts valid top_p values (0-1)."""
-    task = Task(id="task1", name="Test Task", instruction="Do something")
 
-    config =
-        task=task,
+    config = RunConfigProperties(
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE,
@@ -219,11 +193,9 @@ def test_run_config_valid_top_p(top_p):
 @pytest.mark.parametrize("top_p", [-0.1, 1.1, 2.0])
 def test_run_config_invalid_top_p(top_p):
     """Test that RunConfig rejects invalid top_p values."""
-    task = Task(id="task1", name="Test Task", instruction="Do something")
 
     with pytest.raises(ValueError, match="top_p must be between 0 and 1"):
-
-            task=task,
+        RunConfigProperties(
             model_name="gpt-4",
             model_provider_name="openai",
             prompt_id=PromptGenerators.SIMPLE,
@@ -236,10 +208,8 @@ def test_run_config_invalid_top_p(top_p):
 @pytest.mark.parametrize("temperature", [0.0, 1.0, 2.0])
 def test_run_config_valid_temperature(temperature):
     """Test that RunConfig accepts valid temperature values (0-2)."""
-    task = Task(id="task1", name="Test Task", instruction="Do something")
 
-    config =
-        task=task,
+    config = RunConfigProperties(
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE,
@@ -254,11 +224,9 @@ def test_run_config_valid_temperature(temperature):
 @pytest.mark.parametrize("temperature", [-0.1, 2.1, 3.0])
 def test_run_config_invalid_temperature(temperature):
     """Test that RunConfig rejects invalid temperature values."""
-    task = Task(id="task1", name="Test Task", instruction="Do something")
 
     with pytest.raises(ValueError, match="temperature must be between 0 and 2"):
-
-            task=task,
+        RunConfigProperties(
             model_name="gpt-4",
             model_provider_name="openai",
             prompt_id=PromptGenerators.SIMPLE,
@@ -323,3 +291,8 @@ def test_run_config_upgrade_old_entries():
     assert parsed.name == "test name"
     assert parsed.created_by == "scosman"
     assert parsed.run_config_properties.structured_output_mode == "unknown"
+
+
+def test_task_name_unicode_name():
+    task = Task(name="你好", instruction="Do something")
+    assert task.name == "你好"
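Taken together, these test changes capture the datamodel shift in 0.20.x: a run configuration no longer carries a Task, and its fields move to RunConfigProperties. A minimal sketch of constructing one directly, using only the field names and value ranges exercised by the tests above (values are illustrative):

```python
from kiln_ai.datamodel.prompt_id import PromptGenerators
from kiln_ai.datamodel.task import RunConfigProperties

# Per the tests above there is no task field any more; top_p must be within
# 0-1 and temperature within 0-2 when provided.
props = RunConfigProperties(
    model_name="gpt-4",
    model_provider_name="openai",
    prompt_id=PromptGenerators.SIMPLE,
    structured_output_mode="json_schema",
)
```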
kiln_ai/datamodel/test_tool_id.py
ADDED
@@ -0,0 +1,239 @@
+import pytest
+from pydantic import BaseModel, ValidationError
+
+from kiln_ai.datamodel.tool_id import (
+    MCP_LOCAL_TOOL_ID_PREFIX,
+    MCP_REMOTE_TOOL_ID_PREFIX,
+    KilnBuiltInToolId,
+    ToolId,
+    _check_tool_id,
+    mcp_server_and_tool_name_from_id,
+)
+
+
+class TestKilnBuiltInToolId:
+    """Test the KilnBuiltInToolId enum."""
+
+    def test_enum_values(self):
+        """Test that enum has expected values."""
+        assert KilnBuiltInToolId.ADD_NUMBERS == "kiln_tool::add_numbers"
+        assert KilnBuiltInToolId.SUBTRACT_NUMBERS == "kiln_tool::subtract_numbers"
+        assert KilnBuiltInToolId.MULTIPLY_NUMBERS == "kiln_tool::multiply_numbers"
+        assert KilnBuiltInToolId.DIVIDE_NUMBERS == "kiln_tool::divide_numbers"
+        for enum_value in KilnBuiltInToolId.__members__.values():
+            assert _check_tool_id(enum_value) == enum_value
+
+    def test_enum_membership(self):
+        """Test enum membership checks."""
+        assert "kiln_tool::add_numbers" in KilnBuiltInToolId.__members__.values()
+        assert "invalid_tool" not in KilnBuiltInToolId.__members__.values()
+
+
+class TestCheckToolId:
+    """Test the _check_tool_id validation function."""
+
+    def test_valid_builtin_tools(self):
+        """Test validation of valid built-in tools."""
+        for tool_id in KilnBuiltInToolId:
+            result = _check_tool_id(tool_id.value)
+            assert result == tool_id.value
+
+    def test_valid_mcp_remote_tools(self):
+        """Test validation of valid MCP remote tools."""
+        valid_ids = [
+            "mcp::remote::server1::tool1",
+            "mcp::remote::my_server::my_tool",
+            "mcp::remote::test::function_name",
+        ]
+        for tool_id in valid_ids:
+            result = _check_tool_id(tool_id)
+            assert result == tool_id
+
+    def test_valid_mcp_local_tools(self):
+        """Test validation of valid MCP local tools."""
+        valid_ids = [
+            "mcp::local::server1::tool1",
+            "mcp::local::my_server::my_tool",
+            "mcp::local::test::function_name",
+        ]
+        for tool_id in valid_ids:
+            result = _check_tool_id(tool_id)
+            assert result == tool_id
+
+    def test_invalid_empty_or_none(self):
+        """Test validation fails for empty or None values."""
+        with pytest.raises(ValueError, match="Invalid tool ID"):
+            _check_tool_id("")
+
+        with pytest.raises(ValueError, match="Invalid tool ID"):
+            _check_tool_id(None)  # type: ignore
+
+    def test_invalid_non_string(self):
+        """Test validation fails for non-string values."""
+        with pytest.raises(ValueError, match="Invalid tool ID"):
+            _check_tool_id(123)  # type: ignore
+
+        with pytest.raises(ValueError, match="Invalid tool ID"):
+            _check_tool_id(["tool"])  # type: ignore
+
+    def test_invalid_unknown_tool(self):
+        """Test validation fails for unknown tool IDs."""
+        with pytest.raises(ValueError, match="Invalid tool ID: unknown_tool"):
+            _check_tool_id("unknown_tool")
+
+    def test_invalid_mcp_format(self):
+        """Test validation fails for invalid MCP tool formats."""
+        # These IDs start with the MCP remote prefix but have invalid formats
+        mcp_remote_invalid_ids = [
+            "mcp::remote::",  # Missing server and tool
+            "mcp::remote::server",  # Missing tool
+            "mcp::remote::server::",  # Empty tool name
+            "mcp::remote::::tool",  # Empty server name
+            "mcp::remote::server::tool::extra",  # Too many parts
+        ]
+
+        for invalid_id in mcp_remote_invalid_ids:
+            with pytest.raises(ValueError, match="Invalid remote MCP tool ID"):
+                _check_tool_id(invalid_id)
+
+        # These IDs start with the MCP local prefix but have invalid formats
+        mcp_local_invalid_ids = [
+            "mcp::local::",  # Missing server and tool
+            "mcp::local::server",  # Missing tool
+            "mcp::local::server::",  # Empty tool name
+            "mcp::local::::tool",  # Empty server name
+            "mcp::local::server::tool::extra",  # Too many parts
+        ]
+
+        for invalid_id in mcp_local_invalid_ids:
+            with pytest.raises(ValueError, match="Invalid local MCP tool ID"):
+                _check_tool_id(invalid_id)
+
+        # This ID doesn't start with MCP prefix so gets generic error
+        with pytest.raises(ValueError, match="Invalid tool ID"):
+            _check_tool_id("mcp::wrong::server::tool")
+
+
+class TestMcpServerAndToolNameFromId:
+    """Test the mcp_server_and_tool_name_from_id function."""
+
+    def test_valid_mcp_ids(self):
+        """Test parsing valid MCP tool IDs."""
+        test_cases = [
+            # Remote MCP tools
+            ("mcp::remote::server1::tool1", ("server1", "tool1")),
+            ("mcp::remote::my_server::my_tool", ("my_server", "my_tool")),
+            ("mcp::remote::test::function_name", ("test", "function_name")),
+            # Local MCP tools
+            ("mcp::local::server1::tool1", ("server1", "tool1")),
+            ("mcp::local::my_server::my_tool", ("my_server", "my_tool")),
+            ("mcp::local::test::function_name", ("test", "function_name")),
+        ]
+
+        for tool_id, expected in test_cases:
+            result = mcp_server_and_tool_name_from_id(tool_id)
+            assert result == expected
+
+    def test_invalid_mcp_ids(self):
+        """Test parsing fails for invalid MCP tool IDs."""
+        # Test remote MCP tool ID errors
+        remote_invalid_ids = [
+            "mcp::remote::",  # Only 3 parts
+            "mcp::remote::server",  # Only 3 parts
+            "mcp::remote::server::tool::extra",  # 5 parts
+        ]
+
+        for invalid_id in remote_invalid_ids:
+            with pytest.raises(ValueError, match="Invalid remote MCP tool ID"):
+                mcp_server_and_tool_name_from_id(invalid_id)
+
+        # Test local MCP tool ID errors
+        local_invalid_ids = [
+            "mcp::local::",  # Only 3 parts
+            "mcp::local::server",  # Only 3 parts
+            "mcp::local::server::tool::extra",  # 5 parts
+        ]
+
+        for invalid_id in local_invalid_ids:
+            with pytest.raises(ValueError, match="Invalid local MCP tool ID"):
+                mcp_server_and_tool_name_from_id(invalid_id)
+
+        # Test generic MCP tool ID errors (not remote or local)
+        generic_invalid_ids = [
+            "not_mcp_format",  # Only 1 part
+            "single_part",  # Only 1 part
+            "",  # Empty string
+        ]
+
+        for invalid_id in generic_invalid_ids:
+            with pytest.raises(ValueError, match="Invalid MCP tool ID"):
+                mcp_server_and_tool_name_from_id(invalid_id)
+
+    def test_mcp_ids_with_wrong_prefix_still_parse(self):
+        """Test that IDs with wrong prefix but correct structure still parse (validation happens elsewhere)."""
+        # This function only checks structure (4 parts), not content
+        result = mcp_server_and_tool_name_from_id("mcp::wrong::server::tool")
+        assert result == ("server", "tool")
+
+
+class TestToolIdPydanticType:
+    """Test the ToolId pydantic type annotation."""
+
+    class _ModelWithToolId(BaseModel):
+        tool_id: ToolId
+
+    def test_valid_builtin_tools(self):
+        """Test ToolId validates built-in tools."""
+        for tool_id in KilnBuiltInToolId:
+            model = self._ModelWithToolId(tool_id=tool_id.value)
+            assert model.tool_id == tool_id.value
+
+    def test_valid_mcp_tools(self):
+        """Test ToolId validates MCP remote and local tools."""
+        valid_ids = [
+            # Remote MCP tools
+            "mcp::remote::server1::tool1",
+            "mcp::remote::my_server::my_tool",
+            # Local MCP tools
+            "mcp::local::server1::tool1",
+            "mcp::local::my_server::my_tool",
+        ]
+
+        for tool_id in valid_ids:
+            model = self._ModelWithToolId(tool_id=tool_id)
+            assert model.tool_id == tool_id
+
+    def test_invalid_tools_raise_validation_error(self):
+        """Test ToolId raises ValidationError for invalid tools."""
+        invalid_ids = [
+            "",
+            "unknown_tool",
+            "mcp::remote::",
+            "mcp::remote::server",
+            "mcp::local::",
+            "mcp::local::server",
+        ]
+
+        for invalid_id in invalid_ids:
+            with pytest.raises(ValidationError):
+                self._ModelWithToolId(tool_id=invalid_id)
+
+    def test_non_string_raises_validation_error(self):
+        """Test ToolId raises ValidationError for non-string values."""
+        with pytest.raises(ValidationError):
+            self._ModelWithToolId(tool_id=123)  # type: ignore
+
+        with pytest.raises(ValidationError):
+            self._ModelWithToolId(tool_id=None)  # type: ignore
+
+
+class TestConstants:
+    """Test module constants."""
+
+    def test_mcp_remote_tool_id_prefix(self):
+        """Test the MCP remote tool ID prefix constant."""
+        assert MCP_REMOTE_TOOL_ID_PREFIX == "mcp::remote::"
+
+    def test_mcp_local_tool_id_prefix(self):
+        """Test the MCP local tool ID prefix constant."""
+        assert MCP_LOCAL_TOOL_ID_PREFIX == "mcp::local::"
kiln_ai/datamodel/tool_id.py
ADDED
@@ -0,0 +1,83 @@
+from enum import Enum
+from typing import Annotated
+
+from pydantic import AfterValidator
+
+ToolId = Annotated[
+    str,
+    AfterValidator(lambda v: _check_tool_id(v)),
+]
+"""
+A pydantic type that validates strings containing a valid tool ID.
+
+Tool IDs can be one of:
+- A kiln built-in tool name: kiln_tool::add_numbers
+- A remote MCP tool: mcp::remote::<server_id>::<tool_name>
+- A local MCP tool: mcp::local::<server_id>::<tool_name>
+- More coming soon like kiln_project_tool::rag::RAG_CONFIG_ID
+"""
+
+
+class KilnBuiltInToolId(str, Enum):
+    ADD_NUMBERS = "kiln_tool::add_numbers"
+    SUBTRACT_NUMBERS = "kiln_tool::subtract_numbers"
+    MULTIPLY_NUMBERS = "kiln_tool::multiply_numbers"
+    DIVIDE_NUMBERS = "kiln_tool::divide_numbers"
+
+
+MCP_REMOTE_TOOL_ID_PREFIX = "mcp::remote::"
+MCP_LOCAL_TOOL_ID_PREFIX = "mcp::local::"
+
+
+def _check_tool_id(id: str) -> str:
+    """
+    Check that the tool ID is valid.
+    """
+    if not id or not isinstance(id, str):
+        raise ValueError(f"Invalid tool ID: {id}")
+
+    # Build in tools
+    if id in KilnBuiltInToolId.__members__.values():
+        return id
+
+    # MCP remote tools must have format: mcp::remote::<server_id>::<tool_name>
+    if id.startswith(MCP_REMOTE_TOOL_ID_PREFIX):
+        server_id, tool_name = mcp_server_and_tool_name_from_id(id)
+        if not server_id or not tool_name:
+            raise ValueError(
+                f"Invalid remote MCP tool ID: {id}. Expected format: 'mcp::remote::<server_id>::<tool_name>'."
+            )
+        return id
+
+    # MCP local tools must have format: mcp::local::<server_id>::<tool_name>
+    if id.startswith(MCP_LOCAL_TOOL_ID_PREFIX):
+        server_id, tool_name = mcp_server_and_tool_name_from_id(id)
+        if not server_id or not tool_name:
+            raise ValueError(
+                f"Invalid local MCP tool ID: {id}. Expected format: 'mcp::local::<server_id>::<tool_name>'."
+            )
+        return id
+
+    raise ValueError(f"Invalid tool ID: {id}")
+
+
+def mcp_server_and_tool_name_from_id(id: str) -> tuple[str, str]:
+    """
+    Get the tool server ID and tool name from the ID.
+    """
+    parts = id.split("::")
+    if len(parts) != 4:
+        # Determine if it's remote or local for the error message
+        if id.startswith(MCP_REMOTE_TOOL_ID_PREFIX):
+            raise ValueError(
+                f"Invalid remote MCP tool ID: {id}. Expected format: 'mcp::remote::<server_id>::<tool_name>'."
+            )
+        elif id.startswith(MCP_LOCAL_TOOL_ID_PREFIX):
+            raise ValueError(
+                f"Invalid local MCP tool ID: {id}. Expected format: 'mcp::local::<server_id>::<tool_name>'."
+            )
+        else:
+            raise ValueError(
+                f"Invalid MCP tool ID: {id}. Expected format: 'mcp::(remote|local)::<server_id>::<tool_name>'."
+            )
+    return parts[2], parts[3]  # server_id, tool_name
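As a quick illustration of the ID grammar the module above defines, here is a short usage sketch that calls the added functions directly (the server and tool names are made up for the example):

```python
from kiln_ai.datamodel.tool_id import (
    KilnBuiltInToolId,
    _check_tool_id,
    mcp_server_and_tool_name_from_id,
)

# Built-in tool IDs pass validation unchanged.
assert _check_tool_id(KilnBuiltInToolId.ADD_NUMBERS.value) == "kiln_tool::add_numbers"

# MCP tool IDs must have exactly four "::"-separated parts; the last two are
# the server ID and tool name (example names, not a real server).
assert mcp_server_and_tool_name_from_id("mcp::remote::weather::get_forecast") == (
    "weather",
    "get_forecast",
)

# Anything else raises ValueError("Invalid tool ID: ...").
```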
kiln_ai/tools/base_tool.py
ADDED
@@ -0,0 +1,82 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+
+from kiln_ai.datamodel.json_schema import validate_schema_dict
+from kiln_ai.datamodel.tool_id import KilnBuiltInToolId, ToolId
+
+
+class KilnToolInterface(ABC):
+    """
+    Abstract interface defining the core API that all Kiln tools must implement.
+    This ensures consistency across all tool implementations.
+    """
+
+    @abstractmethod
+    async def run(self, **kwargs) -> Any:
+        """Execute the tool with the given parameters."""
+        pass
+
+    @abstractmethod
+    async def toolcall_definition(self) -> Dict[str, Any]:
+        """Return the OpenAI-compatible tool definition for this tool."""
+        pass
+
+    @abstractmethod
+    async def id(self) -> ToolId:
+        """Return a unique identifier for this tool."""
+        pass
+
+    @abstractmethod
+    async def name(self) -> str:
+        """Return the tool name (function name) of this tool."""
+        pass
+
+    @abstractmethod
+    async def description(self) -> str:
+        """Return a description of what this tool does."""
+        pass
+
+
+class KilnTool(KilnToolInterface):
+    """
+    Base helper class that provides common functionality for tool implementations.
+    Subclasses only need to implement run() and provide tool configuration.
+    """
+
+    def __init__(
+        self,
+        tool_id: KilnBuiltInToolId,
+        name: str,
+        description: str,
+        parameters_schema: Dict[str, Any],
+    ):
+        self._id = tool_id
+        self._name = name
+        self._description = description
+        validate_schema_dict(parameters_schema)
+        self._parameters_schema = parameters_schema
+
+    async def id(self) -> KilnBuiltInToolId:
+        return self._id
+
+    async def name(self) -> str:
+        return self._name
+
+    async def description(self) -> str:
+        return self._description
+
+    async def toolcall_definition(self) -> Dict[str, Any]:
+        """Generate OpenAI-compatible tool definition."""
+        return {
+            "type": "function",
+            "function": {
+                "name": await self.name(),
+                "description": await self.description(),
+                "parameters": self._parameters_schema,
+            },
+        }
+
+    @abstractmethod
+    async def run(self, **kwargs) -> str:
+        """Subclasses must implement the actual tool logic."""
+        pass