kiln-ai 0.18.0__py3-none-any.whl → 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/__init__.py +2 -2
- kiln_ai/adapters/adapter_registry.py +46 -0
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/data_gen/data_gen_task.py +2 -2
- kiln_ai/adapters/data_gen/test_data_gen_task.py +7 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +3 -1
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -1
- kiln_ai/adapters/eval/test_eval_runner.py +6 -12
- kiln_ai/adapters/eval/test_g_eval.py +3 -4
- kiln_ai/adapters/eval/test_g_eval_data.py +1 -1
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/base_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +32 -20
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +30 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/ml_model_list.py +1009 -111
- kiln_ai/adapters/model_adapters/base_adapter.py +62 -28
- kiln_ai/adapters/model_adapters/litellm_adapter.py +397 -80
- kiln_ai/adapters/model_adapters/test_base_adapter.py +194 -18
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +428 -4
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +120 -14
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/parsers/test_r1_parser.py +1 -1
- kiln_ai/adapters/provider_tools.py +35 -20
- kiln_ai/adapters/remote_config.py +57 -10
- kiln_ai/adapters/repair/repair_task.py +1 -1
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +109 -2
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_model_list.py +51 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_provider_tools.py +73 -12
- kiln_ai/adapters/test_remote_config.py +470 -16
- kiln_ai/datamodel/__init__.py +23 -21
- kiln_ai/datamodel/basemodel.py +54 -28
- kiln_ai/datamodel/datamodel_enums.py +3 -0
- kiln_ai/datamodel/dataset_split.py +5 -3
- kiln_ai/datamodel/eval.py +4 -4
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/finetune.py +2 -2
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +11 -4
- kiln_ai/datamodel/prompt.py +2 -2
- kiln_ai/datamodel/prompt_id.py +4 -4
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +8 -83
- kiln_ai/datamodel/task_output.py +7 -2
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_basemodel.py +213 -21
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_model_perf.py +1 -1
- kiln_ai/datamodel/test_prompt_id.py +5 -1
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +20 -47
- kiln_ai/datamodel/test_tool_id.py +239 -0
- kiln_ai/datamodel/tool_id.py +83 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +243 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_tool_registry.py +473 -0
- kiln_ai/tools/tool_registry.py +64 -0
- kiln_ai/utils/config.py +32 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_open_ai_types.py +131 -0
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/METADATA +37 -6
- kiln_ai-0.20.1.dist-info/RECORD +138 -0
- kiln_ai-0.18.0.dist-info/RECORD +0 -115
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/licenses/LICENSE.txt +0 -0
@@ -13,7 +13,7 @@ from kiln_ai.datamodel import (
     Task,
     Usage,
 )
-from kiln_ai.datamodel.task import
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config

@@ -41,8 +41,8 @@ def test_task(tmp_path):
 @pytest.fixture
 def adapter(test_task):
     return MockAdapter(
-
-
+        task=test_task,
+        run_config=RunConfigProperties(
             model_name="phi_3_5",
             model_provider_name="ollama",
             prompt_id="simple_chain_of_thought_prompt_builder",
@@ -240,8 +240,8 @@ async def test_autosave_true(test_task, adapter):
 def test_properties_for_task_output_custom_values(test_task):
     """Test that _properties_for_task_output includes custom temperature, top_p, and structured_output_mode"""
     adapter = MockAdapter(
-
-
+        task=test_task,
+        run_config=RunConfigProperties(
             model_name="gpt-4",
             model_provider_name="openai",
             prompt_id="simple_prompt_builder",
@@ -1,23 +1,19 @@
 import json
 from pathlib import Path
 from typing import Dict
+from unittest.mock import Mock, patch

 import pytest
+from litellm.types.utils import ModelResponse

 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.adapter_registry import adapter_for_task
-from kiln_ai.adapters.ml_model_list import (
-    built_in_models,
-)
-from kiln_ai.adapters.model_adapters.base_adapter import (
-    BaseAdapter,
-    RunOutput,
-    Usage,
-)
+from kiln_ai.adapters.ml_model_list import built_in_models
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput, Usage
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId
-from kiln_ai.datamodel.task import
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema

@@ -46,8 +42,8 @@ async def test_structured_output_ollama(tmp_path, model_name):
 class MockAdapter(BaseAdapter):
     def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None):
         super().__init__(
-
-
+            task=kiln_task,
+            run_config=RunConfigProperties(
                 model_name="phi_3_5",
                 model_provider_name="ollama",
                 prompt_id="simple_chain_of_thought_prompt_builder",
@@ -180,8 +176,14 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     # Check reasoning models
     assert a._model_provider is not None
     if a._model_provider.reasoning_capable:
-        assert "reasoning" in run.intermediate_outputs
-        assert isinstance(run.intermediate_outputs["reasoning"], str)
+        # some providers have reasoning_capable models that do not return the reasoning
+        # for structured output responses (they provide it only for non-structured output)
+        if a._model_provider.reasoning_optional_for_structured_output:
+            # models may be updated to include the reasoning in the future
+            assert "reasoning" not in run.intermediate_outputs
+        else:
+            assert "reasoning" in run.intermediate_outputs
+            assert isinstance(run.intermediate_outputs["reasoning"], str)


 def build_structured_input_test_task(tmp_path: Path):
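For context on the branch above: reasoning_optional_for_structured_output is a flag on KilnModelProvider marking reasoning-capable models that omit reasoning when returning structured output. A minimal sketch of such a provider entry, with illustrative field values that are assumptions rather than entries from the real model list (kiln_ai >= 0.20.1 assumed):

from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName

# Illustrative provider entry: reasoning-capable, but reasoning is only
# surfaced for non-structured responses (field values are assumptions).
provider = KilnModelProvider(
    name=ModelProviderName.openrouter,
    reasoning_capable=True,
    reasoning_optional_for_structured_output=True,
)
assert provider.reasoning_capable
assert provider.reasoning_optional_for_structured_output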
@@ -259,6 +261,7 @@ async def run_structured_input_task(
     model_name: str,
     provider: str,
     prompt_id: PromptId,
+    verify_trace_cot: bool = False,
 ):
     response, a, run = await run_structured_input_task_no_validation(
         task, model_name, provider, prompt_id
@@ -282,6 +285,32 @@ async def run_structured_input_task(
     assert "reasoning" in run.intermediate_outputs
     assert isinstance(run.intermediate_outputs["reasoning"], str)

+    # Check the trace
+    trace = run.trace
+    assert trace is not None
+    if verify_trace_cot:
+        assert len(trace) == 5
+        assert trace[0]["role"] == "system"
+        assert "You are an assistant which classifies a triangle" in trace[0]["content"]
+        assert trace[1]["role"] == "user"
+        assert trace[2]["role"] == "assistant"
+        assert trace[2].get("tool_calls") is None
+        assert trace[3]["role"] == "user"
+        assert trace[4]["role"] == "assistant"
+        assert trace[4].get("tool_calls") is None
+    else:
+        assert len(trace) == 3
+        assert trace[0]["role"] == "system"
+        assert "You are an assistant which classifies a triangle" in trace[0]["content"]
+        assert trace[1]["role"] == "user"
+        json_content = json.loads(trace[1]["content"])
+        assert json_content["a"] == 2
+        assert json_content["b"] == 2
+        assert json_content["c"] == 2
+        assert trace[2]["role"] == "assistant"
+        assert trace[2].get("tool_calls") is None
+        assert "[[equilateral]]" in trace[2]["content"]
+

 @pytest.mark.paid
 async def test_structured_input_gpt_4o_mini(tmp_path):
@@ -299,14 +328,91 @@ async def test_all_built_in_models_structured_input(
     )


+async def test_all_built_in_models_structured_input_mocked(tmp_path):
+    mock_response = ModelResponse(
+        model="gpt-4o-mini",
+        choices=[
+            {
+                "message": {
+                    "content": "The answer is [[equilateral]]",
+                }
+            }
+        ],
+    )
+
+    # Mock the Config.shared() method to return a mock config with required attributes
+    mock_config = Mock()
+    mock_config.open_ai_api_key = "mock_api_key"
+    mock_config.user_id = "test_user"
+
+    with (
+        patch(
+            "litellm.acompletion",
+            side_effect=[mock_response],
+        ),
+        patch("kiln_ai.utils.config.Config.shared", return_value=mock_config),
+    ):
+        await run_structured_input_test(
+            tmp_path, "llama_3_1_8b", "groq", "simple_prompt_builder"
+        )
+
+
 @pytest.mark.paid
 @pytest.mark.ollama
 @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
 async def test_structured_input_cot_prompt_builder(tmp_path, model_name, provider_name):
     task = build_structured_input_test_task(tmp_path)
     await run_structured_input_task(
-        task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+        task,
+        model_name,
+        provider_name,
+        "simple_chain_of_thought_prompt_builder",
+        verify_trace_cot=True,
+    )
+
+
+async def test_structured_input_cot_prompt_builder_mocked(tmp_path):
+    task = build_structured_input_test_task(tmp_path)
+    mock_response_1 = ModelResponse(
+        model="gpt-4o-mini",
+        choices=[
+            {
+                "message": {
+                    "content": "I'm thinking real hard... oh!",
+                }
+            }
+        ],
     )
+    mock_response_2 = ModelResponse(
+        model="gpt-4o-mini",
+        choices=[
+            {
+                "message": {
+                    "content": "After thinking, I've decided the answer is [[equilateral]]",
+                }
+            }
+        ],
+    )
+
+    # Mock the Config.shared() method to return a mock config with required attributes
+    mock_config = Mock()
+    mock_config.open_ai_api_key = "mock_api_key"
+    mock_config.user_id = "test_user"
+
+    with (
+        patch(
+            "litellm.acompletion",
+            side_effect=[mock_response_1, mock_response_2],
+        ),
+        patch("kiln_ai.utils.config.Config.shared", return_value=mock_config),
+    ):
+        await run_structured_input_task(
+            task,
+            "llama_3_1_8b",
+            "groq",
+            "simple_chain_of_thought_prompt_builder",
+            verify_trace_cot=True,
+        )


 @pytest.mark.paid
@@ -2,27 +2,25 @@ import logging
 from dataclasses import dataclass
 from typing import Dict, List

+from kiln_ai.adapters.docker_model_runner_tools import (
+    get_docker_model_runner_connection,
+)
 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,
-    ModelName,
     ModelParserID,
     ModelProviderName,
     StructuredOutputMode,
     built_in_models,
 )
-from kiln_ai.adapters.model_adapters.litellm_config import (
-    LiteLlmConfig,
-)
-from kiln_ai.adapters.ollama_tools import (
-    get_ollama_connection,
-)
+from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
+from kiln_ai.adapters.ollama_tools import get_ollama_connection
 from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.datamodel_enums import ChatStrategy
-from kiln_ai.datamodel.registry import project_from_id
 from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+from kiln_ai.utils.project_utils import project_from_id

 logger = logging.getLogger(__name__)
@@ -37,6 +35,15 @@ async def provider_enabled(provider_name: ModelProviderName) -> bool:
         except Exception:
             return False

+    if provider_name == ModelProviderName.docker_model_runner:
+        try:
+            conn = await get_docker_model_runner_connection()
+            return conn is not None and (
+                len(conn.supported_models) > 0 or len(conn.untested_models) > 0
+            )
+        except Exception:
+            return False
+
     provider_warning = provider_warnings.get(provider_name)
     if provider_warning is None:
         return False
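With the branch above, provider_enabled reports Docker Model Runner as enabled only when a connection exists and it lists at least one supported or untested model; connection errors yield False rather than raising. A hedged usage sketch, assuming kiln_ai >= 0.20.1 and a local Docker Model Runner endpoint:

import asyncio

from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.provider_tools import provider_enabled

# False on connection failure or when no models are available.
enabled = asyncio.run(provider_enabled(ModelProviderName.docker_model_runner))
print("Docker Model Runner ready:", enabled)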
@@ -75,30 +82,24 @@ def builtin_model_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider | None:
     """
-    Gets a model
+    Gets a model provider from the built-in list of models.

     Args:
         name: The name of the model to get
         provider_name: Optional specific provider to use (defaults to first available)

     Returns:
-        A
-
-    Raises:
-        ValueError: If the model or provider is not found, or if the provider is misconfigured
+        A KilnModelProvider, or None if not found
     """
-    if name not in ModelName.__members__:
-        return None
-
     # Select the model from built_in_models using the name
-    model = next(filter(lambda m: m.name == name, built_in_models))
+    model = next(filter(lambda m: m.name == name, built_in_models), None)
     if model is None:
-
+        return None

-    # If a provider is provided, select the provider
+    # If a provider is provided, select the appropriate provider. Otherwise, use the first available.
     provider: KilnModelProvider | None = None
     if model.providers is None or len(model.providers) == 0:
-
+        return None
     elif provider_name is None:
         provider = model.providers[0]
     else:
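The hunk above changes the builtin_model_from contract: the ModelName.__members__ guard and the documented ValueError are gone, and unknown model names or empty provider lists now return None. A hedged sketch of the new behavior (kiln_ai >= 0.20.1 assumed):

from kiln_ai.adapters.provider_tools import builtin_model_from

# Unknown names no longer raise; callers must handle None.
assert builtin_model_from("not_a_real_model") is None

# Without provider_name, the model's first provider is returned (or None).
provider = builtin_model_from("llama_3_1_8b")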
@@ -384,6 +385,12 @@ def provider_name_from_id(id: str) -> str:
             return "Google Vertex AI"
         case ModelProviderName.together_ai:
             return "Together AI"
+        case ModelProviderName.siliconflow_cn:
+            return "SiliconFlow"
+        case ModelProviderName.cerebras:
+            return "Cerebras"
+        case ModelProviderName.docker_model_runner:
+            return "Docker Model Runner"
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_enum_error(enum_id)
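Display names for the three new providers resolve through the same exhaustive match; a quick sketch (kiln_ai >= 0.20.1 assumed):

from kiln_ai.adapters.provider_tools import provider_name_from_id

assert provider_name_from_id("siliconflow_cn") == "SiliconFlow"
assert provider_name_from_id("cerebras") == "Cerebras"
assert provider_name_from_id("docker_model_runner") == "Docker Model Runner"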
@@ -442,4 +449,12 @@ provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
         required_config_keys=["together_api_key"],
         message="Attempted to use Together without an API key set. \nGet your API key from https://together.ai/settings/keys",
     ),
+    ModelProviderName.siliconflow_cn: ModelProviderWarning(
+        required_config_keys=["siliconflow_cn_api_key"],
+        message="Attempted to use SiliconFlow without an API key set. \nGet your API key from https://cloud.siliconflow.cn/account/ak",
+    ),
+    ModelProviderName.cerebras: ModelProviderWarning(
+        required_config_keys=["cerebras_api_key"],
+        message="Attempted to use Cerebras without an API key set. \nGet your API key from https://cloud.cerebras.ai/platform",
+    ),
 }
@@ -4,11 +4,12 @@ import logging
 import os
 import threading
 from pathlib import Path
-from typing import List
+from typing import Any, List

 import requests
+from pydantic import ValidationError

-from .ml_model_list import KilnModel, built_in_models
+from .ml_model_list import KilnModel, KilnModelProvider, built_in_models

 logger = logging.getLogger(__name__)
@@ -18,21 +19,67 @@ def serialize_config(models: List[KilnModel], path: str | Path) -> None:
     Path(path).write_text(json.dumps(data, indent=2, sort_keys=True))


-def
+def deserialize_config_at_path(path: str | Path) -> List[KilnModel]:
     raw = json.loads(Path(path).read_text())
-
-
+    return deserialize_config_data(raw)
+
+
+def deserialize_config_data(config_data: Any) -> List[KilnModel]:
+    if not isinstance(config_data, dict):
+        raise ValueError(f"Remote config expected dict, got {type(config_data)}")
+
+    model_list = config_data.get("model_list", None)
+    if not isinstance(model_list, list):
+        raise ValueError(
+            f"Remote config expected list of models, got {type(model_list)}"
+        )
+
+    # We must be careful here, because some of the JSON data may be generated from a forward
+    # version of the code that has newer fields / versions of the fields, that may cause
+    # the current client this code is running on to fail to validate the item into a KilnModel.
+    models = []
+    for model_data in model_list:
+        # We skip any model that fails validation - the models that the client can support
+        # will be pulled from the remote config, but the user will need to update their
+        # client to the latest version to see the newer models that break backwards compatibility.
+        try:
+            providers_list = model_data.get("providers", [])
+
+            providers = []
+            for provider_data in providers_list:
+                try:
+                    provider = KilnModelProvider.model_validate(provider_data)
+                    providers.append(provider)
+                except ValidationError as e:
+                    logger.warning(
+                        "Failed to validate a model provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
+                        provider_data,
+                        e,
+                    )
+
+            # this ensures the model deserialization won't fail because of a bad provider
+            model_data["providers"] = []
+
+            # now we validate the model without its providers
+            model = KilnModel.model_validate(model_data)
+
+            # and we attach back the providers that passed our validation
+            model.providers = providers
+            models.append(model)
+        except ValidationError as e:
+            logger.warning(
+                "Failed to validate a model from remote config. Upgrade Kiln to use this model. Details %s: %s",
+                model_data,
+                e,
+            )
+    return models


 def load_from_url(url: str) -> List[KilnModel]:
     response = requests.get(url, timeout=10)
     response.raise_for_status()
     data = response.json()
-    if isinstance(data, list):
-        model_data = data
-    else:
-        model_data = data.get("model_list", [])
-    return [KilnModel.model_validate(item) for item in model_data]
+    return deserialize_config_data(data)


 def dump_builtin_config(path: str | Path) -> None:
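The forward-compatibility pattern in deserialize_config_data above — validate providers individually, drop the ones this client can't parse, validate the model with providers emptied, then re-attach the survivors — can be shown in isolation. A self-contained sketch with stand-in Pydantic models rather than Kiln's real KilnModel/KilnModelProvider classes:

from enum import Enum
from typing import List

from pydantic import BaseModel, ValidationError


class ProviderName(str, Enum):
    openai = "openai"


class Provider(BaseModel):
    name: ProviderName


class Model(BaseModel):
    name: str
    providers: List[Provider] = []


def tolerant_load(model_list: list) -> List[Model]:
    models: List[Model] = []
    for model_data in model_list:
        try:
            providers = []
            for provider_data in model_data.get("providers", []):
                try:
                    providers.append(Provider.model_validate(provider_data))
                except ValidationError:
                    pass  # provider from a newer config version: skip it
            model_data["providers"] = []  # validate the model without providers
            model = Model.model_validate(model_data)
            model.providers = providers  # re-attach the providers that parsed
            models.append(model)
        except ValidationError:
            pass  # model from a newer config version: skip it
    return models


# "quantum" is not a known ProviderName, so that provider is dropped,
# but the model itself still loads with its one valid provider.
loaded = tolerant_load(
    [{"name": "m1", "providers": [{"name": "openai"}, {"name": "quantum"}]}]
)
assert len(loaded) == 1 and len(loaded[0].providers) == 1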
@@ -6,7 +6,7 @@ from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_f
 from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun


-#
+# We should add evaluator rating
 class RepairTaskInput(BaseModel):
     original_prompt: str
     original_input: str
@@ -229,21 +229,20 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
         "rating": 8,
     }

+    run_config = RunConfigProperties(
+        model_name="llama_3_1_8b",
+        model_provider_name="ollama",
+        prompt_id="simple_prompt_builder",
+        structured_output_mode="json_schema",
+    )
+
     with patch.object(LiteLlmAdapter, "_run", new_callable=AsyncMock) as mock_run:
         mock_run.return_value = (
            RunOutput(output=mocked_output, intermediate_outputs=None),
            None,
        )

-        adapter = adapter_for_task(
-            repair_task,
-            RunConfigProperties(
-                model_name="llama_3_1_8b",
-                model_provider_name="ollama",
-                prompt_id="simple_prompt_builder",
-                structured_output_mode="json_schema",
-            ),
-        )
+        adapter = adapter_for_task(repair_task, run_config)

         run = await adapter.invoke(repair_task_input.model_dump())
@@ -264,6 +263,10 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
         }
         assert run.input_source.type == DataSourceType.human
         assert "created_by" in run.input_source.properties
+        assert run.output.source is not None
+        assert run.output.source.run_config is not None
+        saved_run_config = run.output.source.run_config.model_dump()
+        assert saved_run_config == run_config.model_dump()

         # Verify that the mock was called
         mock_run.assert_called_once()
kiln_ai/adapters/run_output.py CHANGED
@@ -3,9 +3,12 @@ from typing import Dict

 from litellm.types.utils import ChoiceLogprobs

+from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam
+

 @dataclass
 class RunOutput:
     output: Dict | str
     intermediate_outputs: Dict[str, str] | None
     output_logprobs: ChoiceLogprobs | None = None
+    trace: list[ChatCompletionMessageParam] | None = None
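RunOutput now optionally carries the full chat trace alongside the final output and intermediate outputs. A hedged construction sketch (kiln_ai >= 0.20.1 assumed; as a dataclass, message shape is not validated at runtime):

from kiln_ai.adapters.run_output import RunOutput

run_output = RunOutput(
    output="The answer is [[equilateral]]",
    intermediate_outputs={"reasoning": "All sides are equal."},
    trace=[
        {"role": "system", "content": "You are an assistant which classifies a triangle."},
        {"role": "user", "content": '{"a": 2, "b": 2, "c": 2}'},
        {"role": "assistant", "content": "The answer is [[equilateral]]"},
    ],
)
assert run_output.trace is not None
assert run_output.trace[-1]["role"] == "assistant"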
@@ -8,6 +8,7 @@ from kiln_ai.adapters.ml_model_list import ModelProviderName
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
 from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
 from kiln_ai.adapters.provider_tools import kiln_model_provider_from
+from kiln_ai.datamodel.datamodel_enums import StructuredOutputMode
 from kiln_ai.datamodel.task import RunConfigProperties

@@ -16,6 +17,10 @@ def mock_config():
     with patch("kiln_ai.adapters.adapter_registry.Config") as mock:
         mock.shared.return_value.open_ai_api_key = "test-openai-key"
         mock.shared.return_value.open_router_api_key = "test-openrouter-key"
+        mock.shared.return_value.siliconflow_cn_api_key = "test-siliconflow-key"
+        mock.shared.return_value.docker_model_runner_base_url = (
+            "http://localhost:12434/engines/llama.cpp"
+        )
         yield mock

@@ -85,6 +90,33 @@ def test_openrouter_adapter_creation(mock_config, basic_task):
     }


+def test_siliconflow_adapter_creation(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+            model_provider_name=ModelProviderName.siliconflow_cn,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
+    )
+
+    assert isinstance(adapter, LiteLlmAdapter)
+    assert (
+        adapter.config.run_config_properties.model_name
+        == "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
+    )
+    assert adapter.config.additional_body_options == {"api_key": "test-siliconflow-key"}
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.siliconflow_cn
+    )
+    assert adapter.config.default_headers == {
+        "HTTP-Referer": "https://kiln.tech/siliconflow",
+        "X-Title": "KilnAI",
+    }
+
+
 @pytest.mark.parametrize(
     "provider",
     [
@@ -109,7 +141,7 @@ def test_openai_compatible_adapter_creation(mock_config, basic_task, provider):
     assert adapter.run_config.model_name == "test-model"


-#
+# We should run for all cases
 def test_custom_prompt_builder(mock_config, basic_task):
     adapter = adapter_for_task(
         kiln_task=basic_task,
@@ -124,7 +156,7 @@ def test_custom_prompt_builder(mock_config, basic_task):
     assert adapter.run_config.prompt_id == "simple_chain_of_thought_prompt_builder"


-#
+# We should run for all cases
 def test_tags_passed_through(mock_config, basic_task):
     tags = ["test-tag-1", "test-tag-2"]
     adapter = adapter_for_task(
@@ -232,3 +264,78 @@ async def test_fine_tune_provider(mock_config, basic_task, mock_finetune_from_id
     )
     # The actual model name from the fine tune object
     assert provider.model_id == "test-model"
+
+
+def test_docker_model_runner_adapter_creation(mock_config, basic_task):
+    """Test Docker Model Runner adapter creation with default and custom base URL."""
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="llama_3_2_3b",
+            model_provider_name=ModelProviderName.docker_model_runner,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode=StructuredOutputMode.json_schema,
+        ),
+    )
+
+    assert isinstance(adapter, LiteLlmAdapter)
+    assert adapter.config.run_config_properties.model_name == "llama_3_2_3b"
+    assert adapter.config.additional_body_options == {"api_key": "DMR"}
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.docker_model_runner
+    )
+    assert adapter.config.base_url == "http://localhost:12434/engines/llama.cpp/v1"
+    assert adapter.config.default_headers is None
+
+
+def test_docker_model_runner_adapter_creation_with_custom_url(mock_config, basic_task):
+    """Test Docker Model Runner adapter creation with custom base URL."""
+    mock_config.shared.return_value.docker_model_runner_base_url = (
+        "http://custom:8080/engines/llama.cpp"
+    )
+
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="llama_3_2_3b",
+            model_provider_name=ModelProviderName.docker_model_runner,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode=StructuredOutputMode.json_schema,
+        ),
+    )
+
+    assert isinstance(adapter, LiteLlmAdapter)
+    assert adapter.config.run_config_properties.model_name == "llama_3_2_3b"
+    assert adapter.config.additional_body_options == {"api_key": "DMR"}
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.docker_model_runner
+    )
+    assert adapter.config.base_url == "http://custom:8080/engines/llama.cpp/v1"
+    assert adapter.config.default_headers is None
+
+
+def test_docker_model_runner_adapter_creation_with_none_url(mock_config, basic_task):
+    """Test Docker Model Runner adapter creation when config URL is None."""
+    mock_config.shared.return_value.docker_model_runner_base_url = None
+
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="llama_3_2_3b",
+            model_provider_name=ModelProviderName.docker_model_runner,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode=StructuredOutputMode.json_schema,
+        ),
+    )
+
+    assert isinstance(adapter, LiteLlmAdapter)
+    assert adapter.config.run_config_properties.model_name == "llama_3_2_3b"
+    assert adapter.config.additional_body_options == {"api_key": "DMR"}
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.docker_model_runner
+    )
+    assert adapter.config.base_url == "http://localhost:12434/engines/llama.cpp/v1"
+    assert adapter.config.default_headers is None