kiln-ai 0.18.0__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/adapter_registry.py +28 -0
- kiln_ai/adapters/data_gen/data_gen_task.py +2 -2
- kiln_ai/adapters/data_gen/test_data_gen_task.py +7 -3
- kiln_ai/adapters/eval/test_eval_runner.py +6 -12
- kiln_ai/adapters/eval/test_g_eval_data.py +1 -1
- kiln_ai/adapters/fine_tune/base_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +32 -20
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +30 -21
- kiln_ai/adapters/ml_model_list.py +635 -83
- kiln_ai/adapters/model_adapters/base_adapter.py +11 -7
- kiln_ai/adapters/model_adapters/litellm_adapter.py +14 -1
- kiln_ai/adapters/model_adapters/test_base_adapter.py +1 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +22 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +10 -10
- kiln_ai/adapters/parsers/test_r1_parser.py +1 -1
- kiln_ai/adapters/provider_tools.py +20 -19
- kiln_ai/adapters/remote_config.py +57 -10
- kiln_ai/adapters/repair/repair_task.py +1 -1
- kiln_ai/adapters/test_adapter_registry.py +30 -2
- kiln_ai/adapters/test_ml_model_list.py +12 -0
- kiln_ai/adapters/test_provider_tools.py +18 -12
- kiln_ai/adapters/test_remote_config.py +372 -16
- kiln_ai/datamodel/basemodel.py +54 -28
- kiln_ai/datamodel/datamodel_enums.py +2 -0
- kiln_ai/datamodel/dataset_split.py +5 -3
- kiln_ai/datamodel/eval.py +3 -3
- kiln_ai/datamodel/finetune.py +2 -2
- kiln_ai/datamodel/project.py +3 -3
- kiln_ai/datamodel/prompt.py +2 -2
- kiln_ai/datamodel/prompt_id.py +4 -4
- kiln_ai/datamodel/task.py +6 -6
- kiln_ai/datamodel/task_output.py +1 -1
- kiln_ai/datamodel/test_basemodel.py +210 -18
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_model_perf.py +1 -1
- kiln_ai/datamodel/test_prompt_id.py +5 -1
- kiln_ai/datamodel/test_task.py +5 -0
- kiln_ai/utils/config.py +10 -0
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.19.0.dist-info}/METADATA +32 -2
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.19.0.dist-info}/RECORD +42 -42
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.19.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.19.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/base_adapter.py

@@ -3,10 +3,7 @@ from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from typing import Dict, Tuple
 
-from kiln_ai.adapters.chat.chat_formatter import (
-    ChatFormatter,
-    get_chat_formatter,
-)
+from kiln_ai.adapters.chat.chat_formatter import ChatFormatter, get_chat_formatter
 from kiln_ai.adapters.ml_model_list import (
     KilnModelProvider,
     StructuredOutputMode,

@@ -156,9 +153,16 @@ class BaseAdapter(metaclass=ABCMeta):
         )
 
         # Validate reasoning content is present (if reasoning)
-        if provider.reasoning_capable and (
-            not parsed_output.intermediate_outputs
-            or "reasoning" not in parsed_output.intermediate_outputs
+        if (
+            provider.reasoning_capable
+            and (
+                not parsed_output.intermediate_outputs
+                or "reasoning" not in parsed_output.intermediate_outputs
+            )
+            and not (
+                provider.reasoning_optional_for_structured_output
+                and self.has_structured_output()
+            )
         ):
             raise RuntimeError(
                 "Reasoning is required for this model, but no reasoning was returned."
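Note on the hunk above: the validation now raises only when a reasoning-capable provider returns no reasoning and the new structured-output exemption does not apply. A minimal sketch of that predicate, with names invented here for illustration (the real check lives in BaseAdapter and reads these values from the provider and the parsed output):

def reasoning_missing(
    reasoning_capable: bool,
    intermediate_outputs: dict | None,
    reasoning_optional_for_structured_output: bool,
    has_structured_output: bool,
) -> bool:
    # Reasoning is absent when there are no intermediate outputs or no "reasoning" key.
    missing = not intermediate_outputs or "reasoning" not in intermediate_outputs
    # New in 0.19.0: structured output can be exempt from the reasoning requirement.
    exempt = reasoning_optional_for_structured_output and has_structured_output
    return reasoning_capable and missing and not exempt

# A reasoning-capable provider that omits reasoning for structured output now passes:
assert reasoning_missing(True, {}, True, True) is False
# Without the exemption, the adapter still raises RuntimeError as before:
assert reasoning_missing(True, {}, False, True) is True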
kiln_ai/adapters/model_adapters/litellm_adapter.py

@@ -235,7 +235,7 @@ class LiteLlmAdapter(BaseAdapter):
         }
 
     def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
-        # …
+        # Don't love having this logic here. But it's worth the usability improvement
         # so better to keep it than exclude it. Should figure out how I want to isolate
         # this sort of logic so it's config driven and can be overridden
 

@@ -251,6 +251,11 @@ class LiteLlmAdapter(BaseAdapter):
                 "exclude": False,
             }
 
+        if provider.gemini_reasoning_enabled:
+            extra_body["reasoning"] = {
+                "enabled": True,
+            }
+
         if provider.name == ModelProviderName.openrouter:
             # Ask OpenRouter to include usage in the response (cost)
             extra_body["usage"] = {"include": True}

@@ -280,6 +285,10 @@ class LiteLlmAdapter(BaseAdapter):
             # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
             provider_options["require_parameters"] = False
 
+        # Siliconflow uses a bool flag for thinking, for some models
+        if provider.siliconflow_enable_thinking is not None:
+            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
+
         if len(provider_options) > 0:
             extra_body["provider"] = provider_options
 

@@ -325,6 +334,10 @@ class LiteLlmAdapter(BaseAdapter):
                 litellm_provider_name = "vertex_ai"
             case ModelProviderName.together_ai:
                 litellm_provider_name = "together_ai"
+            case ModelProviderName.cerebras:
+                litellm_provider_name = "cerebras"
+            case ModelProviderName.siliconflow_cn:
+                is_custom = True
             case ModelProviderName.openai_compatible:
                 is_custom = True
             case ModelProviderName.kiln_custom_registry:
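Note on the two build_extra_body additions above: the adapter now emits provider-specific request options. A condensed sketch, assuming only that the provider carries the two new fields shown in the diff (gemini_reasoning_enabled and siliconflow_enable_thinking); everything else is illustrative:

from typing import Any

def sketch_extra_body(
    gemini_reasoning_enabled: bool,
    siliconflow_enable_thinking: bool | None,
) -> dict[str, Any]:
    extra_body: dict[str, Any] = {}
    # Gemini-style reasoning is opt-in via a nested object.
    if gemini_reasoning_enabled:
        extra_body["reasoning"] = {"enabled": True}
    # SiliconFlow's flag is tri-state: None means "don't send the flag at all".
    if siliconflow_enable_thinking is not None:
        extra_body["enable_thinking"] = siliconflow_enable_thinking
    return extra_body

assert sketch_extra_body(True, None) == {"reasoning": {"enabled": True}}
assert sketch_extra_body(False, False) == {"enable_thinking": False}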
kiln_ai/adapters/model_adapters/test_base_adapter.py

@@ -102,7 +102,7 @@ async def test_model_provider_invalid_provider_model_name(base_task):
     """Test error when model or provider name is missing"""
     # Test with missing model name
     with pytest.raises(ValueError, match="Input should be"):
-        …
+        MockAdapter(
             run_config=RunConfig(
                 task=base_task,
                 model_name="test_model",
kiln_ai/adapters/model_adapters/test_litellm_adapter.py

@@ -7,9 +7,7 @@ import pytest
 from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
 from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
-from kiln_ai.adapters.model_adapters.litellm_config import (
-    LiteLlmConfig,
-)
+from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 from kiln_ai.datamodel import Project, Task, Usage
 from kiln_ai.datamodel.task import RunConfigProperties
 

@@ -242,6 +240,8 @@ def test_tool_call_params_strict(config, mock_task):
         (ModelProviderName.huggingface, "huggingface"),
         (ModelProviderName.vertex, "vertex_ai"),
         (ModelProviderName.together_ai, "together_ai"),
+        # for openai-compatible providers, we expect openai as the provider name
+        (ModelProviderName.siliconflow_cn, "openai"),
     ],
 )
 def test_litellm_model_id_standard_providers(

@@ -552,3 +552,22 @@ def test_usage_from_response(config, mock_task, litellm_usage, cost, expected_usage):
 
     # Verify the response was queried correctly
     response.get.assert_called_once_with("usage", None)
+
+
+@pytest.mark.parametrize(
+    "enable_thinking",
+    [
+        True,
+        False,
+    ],
+)
+def test_build_extra_body_enable_thinking(config, mock_task, enable_thinking):
+    provider = Mock()
+    provider.name = ModelProviderName.siliconflow_cn
+    provider.siliconflow_enable_thinking = enable_thinking
+
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    extra_body = adapter.build_extra_body(provider)
+
+    assert extra_body["enable_thinking"] == enable_thinking
kiln_ai/adapters/model_adapters/test_structured_output.py

@@ -6,14 +6,8 @@ import pytest
 
 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.adapter_registry import adapter_for_task
-from kiln_ai.adapters.ml_model_list import (
-    built_in_models,
-)
-from kiln_ai.adapters.model_adapters.base_adapter import (
-    BaseAdapter,
-    RunOutput,
-    Usage,
-)
+from kiln_ai.adapters.ml_model_list import built_in_models
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput, Usage
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId

@@ -180,8 +174,14 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     # Check reasoning models
     assert a._model_provider is not None
     if a._model_provider.reasoning_capable:
-        assert "reasoning" in run.intermediate_outputs
-        assert isinstance(run.intermediate_outputs["reasoning"], str)
+        # some providers have reasoning_capable models that do not return the reasoning
+        # for structured output responses (they provide it only for non-structured output)
+        if a._model_provider.reasoning_optional_for_structured_output:
+            # models may be updated to include the reasoning in the future
+            assert "reasoning" not in run.intermediate_outputs
+        else:
+            assert "reasoning" in run.intermediate_outputs
+            assert isinstance(run.intermediate_outputs["reasoning"], str)
 
 
 def build_structured_input_test_task(tmp_path: Path):
kiln_ai/adapters/provider_tools.py

@@ -5,18 +5,13 @@ from typing import Dict, List
 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,
-    ModelName,
     ModelParserID,
     ModelProviderName,
     StructuredOutputMode,
     built_in_models,
 )
-from kiln_ai.adapters.model_adapters.litellm_config import (
-    LiteLlmConfig,
-)
-from kiln_ai.adapters.ollama_tools import (
-    get_ollama_connection,
-)
+from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
+from kiln_ai.adapters.ollama_tools import get_ollama_connection
 from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 from kiln_ai.datamodel.registry import project_from_id

@@ -75,30 +70,24 @@ def builtin_model_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider | None:
     """
-    Gets a model …
+    Gets a model provider from the built-in list of models.
 
     Args:
         name: The name of the model to get
         provider_name: Optional specific provider to use (defaults to first available)
 
     Returns:
-        A …
-
-    Raises:
-        ValueError: If the model or provider is not found, or if the provider is misconfigured
+        A KilnModelProvider, or None if not found
     """
-    if name not in ModelName.__members__:
-        return None
-
     # Select the model from built_in_models using the name
-    model = next(filter(lambda m: m.name == name, built_in_models))
+    model = next(filter(lambda m: m.name == name, built_in_models), None)
     if model is None:
-        raise ValueError(…)
+        return None
 
-    # If a provider is provided, select the provider
+    # If a provider is provided, select the appropriate provider. Otherwise, use the first available.
     provider: KilnModelProvider | None = None
     if model.providers is None or len(model.providers) == 0:
-        raise ValueError(f"Model {name} has no providers")
+        return None
     elif provider_name is None:
         provider = model.providers[0]
     else:

@@ -384,6 +373,10 @@ def provider_name_from_id(id: str) -> str:
             return "Google Vertex AI"
         case ModelProviderName.together_ai:
             return "Together AI"
+        case ModelProviderName.siliconflow_cn:
+            return "SiliconFlow"
+        case ModelProviderName.cerebras:
+            return "Cerebras"
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_enum_error(enum_id)

@@ -442,4 +435,12 @@ provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
         required_config_keys=["together_api_key"],
         message="Attempted to use Together without an API key set. \nGet your API key from https://together.ai/settings/keys",
     ),
+    ModelProviderName.siliconflow_cn: ModelProviderWarning(
+        required_config_keys=["siliconflow_cn_api_key"],
+        message="Attempted to use SiliconFlow without an API key set. \nGet your API key from https://cloud.siliconflow.cn/account/ak",
+    ),
+    ModelProviderName.cerebras: ModelProviderWarning(
+        required_config_keys=["cerebras_api_key"],
+        message="Attempted to use Cerebras without an API key set. \nGet your API key from https://cloud.cerebras.ai/platform",
+    ),
 }
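Note: with the ValueError paths removed, builtin_model_from degrades gracefully when the local model list is older than the remote one; callers check for None instead of catching exceptions. A usage sketch ("gpt_99" is a made-up model id):

from kiln_ai.adapters.provider_tools import builtin_model_from

# Hypothetical id for a model this client build doesn't know about yet.
provider = builtin_model_from("gpt_99")
if provider is None:
    # Pre-0.19.0 some of these paths raised; now the caller chooses the fallback.
    print("Model not in the built-in list; consider upgrading Kiln.")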
kiln_ai/adapters/remote_config.py

@@ -4,11 +4,12 @@ import logging
 import os
 import threading
 from pathlib import Path
-from typing import List
+from typing import Any, List
 
 import requests
+from pydantic import ValidationError
 
-from .ml_model_list import KilnModel, built_in_models
+from .ml_model_list import KilnModel, KilnModelProvider, built_in_models
 
 logger = logging.getLogger(__name__)
 

@@ -18,21 +19,67 @@ def serialize_config(models: List[KilnModel], path: str | Path) -> None:
     Path(path).write_text(json.dumps(data, indent=2, sort_keys=True))
 
 
-def deserialize_config(path: str | Path) -> List[KilnModel]:
+def deserialize_config_at_path(path: str | Path) -> List[KilnModel]:
     raw = json.loads(Path(path).read_text())
-    …
-    …
+    return deserialize_config_data(raw)
+
+
+def deserialize_config_data(config_data: Any) -> List[KilnModel]:
+    if not isinstance(config_data, dict):
+        raise ValueError(f"Remote config expected dict, got {type(config_data)}")
+
+    model_list = config_data.get("model_list", None)
+    if not isinstance(model_list, list):
+        raise ValueError(
+            f"Remote config expected list of models, got {type(model_list)}"
+        )
+
+    # We must be careful here, because some of the JSON data may be generated from a forward
+    # version of the code that has newer fields / versions of the fields, that may cause
+    # the current client this code is running on to fail to validate the item into a KilnModel.
+    models = []
+    for model_data in model_list:
+        # We skip any model that fails validation - the models that the client can support
+        # will be pulled from the remote config, but the user will need to update their
+        # client to the latest version to see the newer models that break backwards compatibility.
+        try:
+            providers_list = model_data.get("providers", [])
+
+            providers = []
+            for provider_data in providers_list:
+                try:
+                    provider = KilnModelProvider.model_validate(provider_data)
+                    providers.append(provider)
+                except ValidationError as e:
+                    logger.warning(
+                        "Failed to validate a model provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
+                        provider_data,
+                        e,
+                    )
+
+            # this ensures the model deserialization won't fail because of a bad provider
+            model_data["providers"] = []
+
+            # now we validate the model without its providers
+            model = KilnModel.model_validate(model_data)
+
+            # and we attach back the providers that passed our validation
+            model.providers = providers
+            models.append(model)
+        except ValidationError as e:
+            logger.warning(
+                "Failed to validate a model from remote config. Upgrade Kiln to use this model. Details %s: %s",
+                model_data,
+                e,
+            )
+    return models
 
 
 def load_from_url(url: str) -> List[KilnModel]:
     response = requests.get(url, timeout=10)
     response.raise_for_status()
     data = response.json()
-    if isinstance(data, list):
-        model_data = data
-    else:
-        model_data = data.get("model_list", [])
-    return [KilnModel.model_validate(item) for item in model_data]
+    return deserialize_config_data(data)
 
 
 def dump_builtin_config(path: str | Path) -> None:
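Note: the per-item try/except above is what makes the remote config forward compatible: entries emitted by a newer server that this client cannot validate are logged and dropped rather than failing the whole fetch. A small sketch of that behavior; the payload is invented, and the required KilnModel fields are assumptions based on ml_model_list.py:

from kiln_ai.adapters.remote_config import deserialize_config_data

payload = {
    "model_list": [
        # Assumed-valid entry; field names guessed from ml_model_list.py.
        {
            "family": "gpt",
            "name": "gpt_4_1",
            "friendly_name": "GPT 4.1",
            "providers": [],
        },
        # Malformed entry (wrong type, missing fields), standing in for a future schema.
        {"family": 123},
    ]
}

models = deserialize_config_data(payload)
# The malformed entry is logged with a warning and skipped, not fatal.
assert len(models) == 1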
kiln_ai/adapters/repair/repair_task.py

@@ -6,7 +6,7 @@ from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_from_id
 from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun
 
 
-# …
+# We should add evaluator rating
 class RepairTaskInput(BaseModel):
     original_prompt: str
     original_input: str
kiln_ai/adapters/test_adapter_registry.py

@@ -16,6 +16,7 @@ def mock_config():
     with patch("kiln_ai.adapters.adapter_registry.Config") as mock:
         mock.shared.return_value.open_ai_api_key = "test-openai-key"
         mock.shared.return_value.open_router_api_key = "test-openrouter-key"
+        mock.shared.return_value.siliconflow_cn_api_key = "test-siliconflow-key"
         yield mock
 
 

@@ -85,6 +86,33 @@ def test_openrouter_adapter_creation(mock_config, basic_task):
     }
 
 
+def test_siliconflow_adapter_creation(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+            model_provider_name=ModelProviderName.siliconflow_cn,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
+    )
+
+    assert isinstance(adapter, LiteLlmAdapter)
+    assert (
+        adapter.config.run_config_properties.model_name
+        == "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
+    )
+    assert adapter.config.additional_body_options == {"api_key": "test-siliconflow-key"}
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.siliconflow_cn
+    )
+    assert adapter.config.default_headers == {
+        "HTTP-Referer": "https://getkiln.ai/siliconflow",
+        "X-Title": "KilnAI",
+    }
+
+
 @pytest.mark.parametrize(
     "provider",
     [

@@ -109,7 +137,7 @@ def test_openai_compatible_adapter_creation(mock_config, basic_task, provider):
     assert adapter.run_config.model_name == "test-model"
 
 
-# …
+# We should run for all cases
 def test_custom_prompt_builder(mock_config, basic_task):
     adapter = adapter_for_task(
         kiln_task=basic_task,

@@ -124,7 +152,7 @@ def test_custom_prompt_builder(mock_config, basic_task):
     assert adapter.run_config.prompt_id == "simple_chain_of_thought_prompt_builder"
 
 
-# …
+# We should run for all cases
 def test_tags_passed_through(mock_config, basic_task):
     tags = ["test-tag-1", "test-tag-2"]
     adapter = adapter_for_task(
kiln_ai/adapters/test_ml_model_list.py

@@ -2,6 +2,7 @@ import pytest
 
 from kiln_ai.adapters.ml_model_list import (
     ModelName,
+    built_in_models,
     default_structured_output_mode_for_model_provider,
     get_model_by_name,
 )

@@ -174,3 +175,14 @@ def test_uncensored():
         for provider in model.providers:
             assert provider.uncensored
             assert provider.suggested_for_uncensored_data_gen
+
+
+def test_no_reasoning_for_structured_output():
+    """Test that no reasoning is returned for structured output"""
+    # get all models
+    for model in built_in_models:
+        for provider in model.providers:
+            if provider.reasoning_optional_for_structured_output is not None:
+                assert provider.reasoning_capable, (
+                    f"{model.name} {provider.name} has reasoning_optional_for_structured_output but is not reasoning capable. This field should only be defined for models that are reasoning capable."
+                )
kiln_ai/adapters/test_provider_tools.py

@@ -25,11 +25,7 @@ from kiln_ai.adapters.provider_tools import (
     provider_name_from_id,
     provider_warnings,
 )
-from kiln_ai.datamodel import (
-    Finetune,
-    StructuredOutputMode,
-    Task,
-)
+from kiln_ai.datamodel import Finetune, StructuredOutputMode, Task
 from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 from kiln_ai.datamodel.task import RunConfigProperties
 

@@ -199,6 +195,7 @@ def test_provider_name_from_id_case_sensitivity():
         (ModelProviderName.ollama, "Ollama"),
         (ModelProviderName.openai, "OpenAI"),
         (ModelProviderName.fireworks_ai, "Fireworks AI"),
+        (ModelProviderName.siliconflow_cn, "SiliconFlow"),
         (ModelProviderName.kiln_fine_tune, "Fine Tuned Models"),
         (ModelProviderName.kiln_custom_registry, "Custom Models"),
     ],

@@ -420,6 +417,17 @@ async def test_builtin_model_from_invalid_provider(mock_config):
     assert provider is None
 
 
+@pytest.mark.asyncio
+async def test_builtin_model_future_proof():
+    """Test handling of a model that doesn't exist yet but could be added over the air"""
+    with patch("kiln_ai.adapters.provider_tools.built_in_models") as mock_models:
+        mock_models.__iter__.return_value = []
+
+        # should not find it, but should not raise an error
+        result = builtin_model_from("gpt_99")
+        assert result is None
+
+
 @pytest.mark.asyncio
 async def test_builtin_model_from_model_no_providers():
     """Test handling of a model with no providers"""

@@ -433,10 +441,8 @@ async def test_builtin_model_from_model_no_providers():
         )
         mock_models.__iter__.return_value = [mock_model]
 
-        with pytest.raises(ValueError) as exc_info:
-            builtin_model_from(ModelName.phi_3_5)
-
-        assert str(exc_info.value) == f"Model {ModelName.phi_3_5} has no providers"
+        result = builtin_model_from(ModelName.phi_3_5)
+        assert result is None
 
 
 @pytest.mark.asyncio

@@ -461,7 +467,7 @@ def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune):
     assert provider.model_id == "ft:gpt-3.5-turbo:custom:model-123"
     assert provider.structured_output_mode == StructuredOutputMode.json_schema
     assert provider.reasoning_capable is False
-    assert provider.parser …
+    assert provider.parser is None
 
 
 def test_finetune_provider_model_success_final_and_intermediate(

@@ -476,7 +482,7 @@ def test_finetune_provider_model_success_final_and_intermediate(
     assert provider.model_id == "ft:gpt-3.5-turbo:custom:model-123"
     assert provider.structured_output_mode == StructuredOutputMode.json_schema
     assert provider.reasoning_capable is False
-    assert provider.parser …
+    assert provider.parser is None
 
 
 def test_finetune_provider_model_success_r1_compatible(

@@ -590,7 +596,7 @@ def test_finetune_provider_model_structured_mode(
     assert provider.model_id == "fireworks-model-123"
     assert provider.structured_output_mode == expected_mode
     assert provider.reasoning_capable is False
-    assert provider.parser …
+    assert provider.parser is None
 
 
 def test_openai_compatible_provider_config(mock_shared_config):
|