kiln-ai 0.17.0__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/adapter_registry.py +28 -0
- kiln_ai/adapters/chat/chat_formatter.py +0 -1
- kiln_ai/adapters/data_gen/data_gen_prompts.py +121 -36
- kiln_ai/adapters/data_gen/data_gen_task.py +51 -38
- kiln_ai/adapters/data_gen/test_data_gen_task.py +318 -37
- kiln_ai/adapters/eval/base_eval.py +6 -7
- kiln_ai/adapters/eval/eval_runner.py +5 -1
- kiln_ai/adapters/eval/g_eval.py +17 -12
- kiln_ai/adapters/eval/test_base_eval.py +8 -2
- kiln_ai/adapters/eval/test_eval_runner.py +6 -12
- kiln_ai/adapters/eval/test_g_eval.py +115 -5
- kiln_ai/adapters/eval/test_g_eval_data.py +1 -1
- kiln_ai/adapters/fine_tune/base_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/dataset_formatter.py +1 -5
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +32 -20
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +1 -1
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +30 -21
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +2 -7
- kiln_ai/adapters/fine_tune/together_finetune.py +1 -1
- kiln_ai/adapters/ml_model_list.py +926 -125
- kiln_ai/adapters/model_adapters/base_adapter.py +11 -7
- kiln_ai/adapters/model_adapters/litellm_adapter.py +23 -1
- kiln_ai/adapters/model_adapters/test_base_adapter.py +1 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +70 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +13 -13
- kiln_ai/adapters/parsers/parser_registry.py +0 -2
- kiln_ai/adapters/parsers/r1_parser.py +0 -1
- kiln_ai/adapters/parsers/test_r1_parser.py +1 -1
- kiln_ai/adapters/provider_tools.py +20 -19
- kiln_ai/adapters/remote_config.py +113 -0
- kiln_ai/adapters/repair/repair_task.py +2 -7
- kiln_ai/adapters/test_adapter_registry.py +30 -2
- kiln_ai/adapters/test_ml_model_list.py +30 -0
- kiln_ai/adapters/test_prompt_adaptors.py +0 -4
- kiln_ai/adapters/test_provider_tools.py +18 -12
- kiln_ai/adapters/test_remote_config.py +456 -0
- kiln_ai/datamodel/basemodel.py +54 -28
- kiln_ai/datamodel/datamodel_enums.py +2 -0
- kiln_ai/datamodel/dataset_split.py +5 -3
- kiln_ai/datamodel/eval.py +35 -3
- kiln_ai/datamodel/finetune.py +2 -3
- kiln_ai/datamodel/project.py +3 -3
- kiln_ai/datamodel/prompt.py +2 -2
- kiln_ai/datamodel/prompt_id.py +4 -4
- kiln_ai/datamodel/task.py +6 -6
- kiln_ai/datamodel/task_output.py +1 -3
- kiln_ai/datamodel/task_run.py +0 -2
- kiln_ai/datamodel/test_basemodel.py +210 -18
- kiln_ai/datamodel/test_eval_model.py +152 -10
- kiln_ai/datamodel/test_model_perf.py +1 -1
- kiln_ai/datamodel/test_prompt_id.py +5 -1
- kiln_ai/datamodel/test_task.py +5 -0
- kiln_ai/utils/config.py +10 -0
- kiln_ai/utils/logging.py +4 -3
- {kiln_ai-0.17.0.dist-info → kiln_ai-0.19.0.dist-info}/METADATA +33 -3
- {kiln_ai-0.17.0.dist-info → kiln_ai-0.19.0.dist-info}/RECORD +58 -56
- {kiln_ai-0.17.0.dist-info → kiln_ai-0.19.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.17.0.dist-info → kiln_ai-0.19.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/base_adapter.py

@@ -3,10 +3,7 @@ from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from typing import Dict, Tuple

-from kiln_ai.adapters.chat.chat_formatter import (
-    ChatFormatter,
-    get_chat_formatter,
-)
+from kiln_ai.adapters.chat.chat_formatter import ChatFormatter, get_chat_formatter
 from kiln_ai.adapters.ml_model_list import (
     KilnModelProvider,
     StructuredOutputMode,

@@ -156,9 +153,16 @@ class BaseAdapter(metaclass=ABCMeta):
         )

         # Validate reasoning content is present (if reasoning)
-        if provider.reasoning_capable and (
-            not parsed_output.intermediate_outputs
-            or "reasoning" not in parsed_output.intermediate_outputs
+        if (
+            provider.reasoning_capable
+            and (
+                not parsed_output.intermediate_outputs
+                or "reasoning" not in parsed_output.intermediate_outputs
+            )
+            and not (
+                provider.reasoning_optional_for_structured_output
+                and self.has_structured_output()
+            )
         ):
             raise RuntimeError(
                 "Reasoning is required for this model, but no reasoning was returned."
kiln_ai/adapters/model_adapters/litellm_adapter.py

@@ -235,7 +235,7 @@ class LiteLlmAdapter(BaseAdapter):
         }

     def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
-        #
+        # Don't love having this logic here. But it's worth the usability improvement
         # so better to keep it than exclude it. Should figure out how I want to isolate
         # this sort of logic so it's config driven and can be overridden

@@ -251,6 +251,15 @@ class LiteLlmAdapter(BaseAdapter):
                 "exclude": False,
             }

+        if provider.gemini_reasoning_enabled:
+            extra_body["reasoning"] = {
+                "enabled": True,
+            }
+
+        if provider.name == ModelProviderName.openrouter:
+            # Ask OpenRouter to include usage in the response (cost)
+            extra_body["usage"] = {"include": True}
+
         if provider.anthropic_extended_thinking:
             extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

@@ -276,6 +285,10 @@ class LiteLlmAdapter(BaseAdapter):
             # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
             provider_options["require_parameters"] = False

+        # Siliconflow uses a bool flag for thinking, for some models
+        if provider.siliconflow_enable_thinking is not None:
+            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
+
         if len(provider_options) > 0:
             extra_body["provider"] = provider_options

@@ -321,6 +334,10 @@ class LiteLlmAdapter(BaseAdapter):
                 litellm_provider_name = "vertex_ai"
             case ModelProviderName.together_ai:
                 litellm_provider_name = "together_ai"
+            case ModelProviderName.cerebras:
+                litellm_provider_name = "cerebras"
+            case ModelProviderName.siliconflow_cn:
+                is_custom = True
             case ModelProviderName.openai_compatible:
                 is_custom = True
             case ModelProviderName.kiln_custom_registry:

@@ -386,7 +403,12 @@ class LiteLlmAdapter(BaseAdapter):

     def usage_from_response(self, response: ModelResponse) -> Usage | None:
         litellm_usage = response.get("usage", None)
+
+        # LiteLLM isn't consistent in how it returns the cost.
         cost = response._hidden_params.get("response_cost", None)
+        if cost is None and litellm_usage:
+            cost = litellm_usage.get("cost", None)
+
         if not litellm_usage and not cost:
             return None
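For orientation, here is a rough sketch (not taken from the package) of the extra_body dict the updated build_extra_body above would assemble for a hypothetical OpenRouter-served provider that also has gemini_reasoning_enabled set and a siliconflow_enable_thinking flag; the exact values are made up for illustration.

    # Sketch only: the shape extra_body takes under the assumptions stated above
    expected_extra_body = {
        "reasoning": {"enabled": True},  # provider.gemini_reasoning_enabled
        "usage": {"include": True},  # OpenRouter: ask for usage/cost in the response
        "enable_thinking": True,  # provider.siliconflow_enable_thinking
    }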
kiln_ai/adapters/model_adapters/test_base_adapter.py

@@ -4,7 +4,6 @@ import pytest

 from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
 from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
-from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id
 from kiln_ai.datamodel import Task
 from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 from kiln_ai.datamodel.task import RunConfig, RunConfigProperties

@@ -103,7 +102,7 @@ async def test_model_provider_invalid_provider_model_name(base_task):
     """Test error when model or provider name is missing"""
     # Test with missing model name
     with pytest.raises(ValueError, match="Input should be"):
-
+        MockAdapter(
             run_config=RunConfig(
                 task=base_task,
                 model_name="test_model",
kiln_ai/adapters/model_adapters/test_litellm_adapter.py

@@ -7,9 +7,7 @@ import pytest
 from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
 from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
-from kiln_ai.adapters.model_adapters.litellm_config import (
-    LiteLlmConfig,
-)
+from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 from kiln_ai.datamodel import Project, Task, Usage
 from kiln_ai.datamodel.task import RunConfigProperties

@@ -242,6 +240,8 @@ def test_tool_call_params_strict(config, mock_task):
         (ModelProviderName.huggingface, "huggingface"),
         (ModelProviderName.vertex, "vertex_ai"),
         (ModelProviderName.together_ai, "together_ai"),
+        # for openai-compatible providers, we expect openai as the provider name
+        (ModelProviderName.siliconflow_cn, "openai"),
     ],
 )
 def test_litellm_model_id_standard_providers(

@@ -352,6 +352,43 @@ def test_litellm_model_id_unknown_provider(config, mock_task):
         adapter.litellm_model_id()


+@pytest.mark.parametrize(
+    "provider_name,expected_usage_param",
+    [
+        (ModelProviderName.openrouter, {"usage": {"include": True}}),
+        (ModelProviderName.openai, {}),
+        (ModelProviderName.anthropic, {}),
+        (ModelProviderName.groq, {}),
+    ],
+)
+def test_build_extra_body_openrouter_usage(
+    config, mock_task, provider_name, expected_usage_param
+):
+    """Test build_extra_body includes usage parameter for OpenRouter providers"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Create a mock provider with the specified name and minimal required attributes
+    mock_provider = Mock()
+    mock_provider.name = provider_name
+    mock_provider.thinking_level = None
+    mock_provider.require_openrouter_reasoning = False
+    mock_provider.anthropic_extended_thinking = False
+    mock_provider.r1_openrouter_options = False
+    mock_provider.logprobs_openrouter_options = False
+    mock_provider.openrouter_skip_required_parameters = False
+
+    # Call build_extra_body
+    extra_body = adapter.build_extra_body(mock_provider)
+
+    # Verify the usage parameter is included only for OpenRouter
+    for key, value in expected_usage_param.items():
+        assert extra_body.get(key) == value
+
+    # Verify non-OpenRouter providers don't have the usage parameter
+    if provider_name != ModelProviderName.openrouter:
+        assert "usage" not in extra_body
+
+
 @pytest.mark.asyncio
 async def test_build_completion_kwargs_custom_temperature_top_p(config, mock_task):
     """Test build_completion_kwargs with custom temperature and top_p values"""

@@ -474,6 +511,17 @@ async def test_build_completion_kwargs(
         ({"prompt_tokens": 10}, None, None),
         # Invalid cost type (should be ignored)
         (None, "0.5", None),
+        # Cost in OpenRouter format
+        (
+            litellm.types.utils.Usage(
+                prompt_tokens=10,
+                completion_tokens=20,
+                total_tokens=30,
+                cost=0.5,
+            ),
+            None,
+            Usage(input_tokens=10, output_tokens=20, total_tokens=30, cost=0.5),
+        ),
     ],
 )
 def test_usage_from_response(config, mock_task, litellm_usage, cost, expected_usage):

@@ -504,3 +552,22 @@ def test_usage_from_response(config, mock_task, litellm_usage, cost, expected_usage):

     # Verify the response was queried correctly
     response.get.assert_called_once_with("usage", None)
+
+
+@pytest.mark.parametrize(
+    "enable_thinking",
+    [
+        True,
+        False,
+    ],
+)
+def test_build_extra_body_enable_thinking(config, mock_task, enable_thinking):
+    provider = Mock()
+    provider.name = ModelProviderName.siliconflow_cn
+    provider.siliconflow_enable_thinking = enable_thinking
+
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    extra_body = adapter.build_extra_body(provider)
+
+    assert extra_body["enable_thinking"] == enable_thinking
kiln_ai/adapters/model_adapters/test_structured_output.py

@@ -6,14 +6,8 @@ import pytest

 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.adapter_registry import adapter_for_task
-from kiln_ai.adapters.ml_model_list import (
-    built_in_models,
-)
-from kiln_ai.adapters.model_adapters.base_adapter import (
-    BaseAdapter,
-    RunOutput,
-    Usage,
-)
+from kiln_ai.adapters.ml_model_list import built_in_models
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput, Usage
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId

@@ -180,8 +174,14 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     # Check reasoning models
     assert a._model_provider is not None
     if a._model_provider.reasoning_capable:
-        assert "reasoning" in run.intermediate_outputs
-        assert isinstance(run.intermediate_outputs["reasoning"], str)
+        # some providers have reasoning_capable models that do not return the reasoning
+        # for structured output responses (they provide it only for non-structured output)
+        if a._model_provider.reasoning_optional_for_structured_output:
+            # models may be updated to include the reasoning in the future
+            assert "reasoning" not in run.intermediate_outputs
+        else:
+            assert "reasoning" in run.intermediate_outputs
+            assert isinstance(run.intermediate_outputs["reasoning"], str)


 def build_structured_input_test_task(tmp_path: Path):

@@ -245,7 +245,7 @@ async def run_structured_input_task_no_validation(
     try:
         run = await a.invoke({"a": 2, "b": 2, "c": 2})
         response = run.output.output
-        return response, a
+        return response, a, run
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(

@@ -260,7 +260,7 @@ async def run_structured_input_task(
     provider: str,
     prompt_id: PromptId,
 ):
-    response, a = await run_structured_input_task_no_validation(
+    response, a, run = await run_structured_input_task_no_validation(
         task, model_name, provider, prompt_id
     )
     assert response is not None

@@ -350,7 +350,7 @@ When asked for a final result, this is the format (for an equilateral example):
     """
     task.output_json_schema = json.dumps(triangle_schema)
     task.save_to_file()
-    response, adapter = await run_structured_input_task_no_validation(
+    response, adapter, _ = await run_structured_input_task_no_validation(
         task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
     )
kiln_ai/adapters/provider_tools.py

@@ -5,18 +5,13 @@ from typing import Dict, List
 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,
-    ModelName,
     ModelParserID,
     ModelProviderName,
     StructuredOutputMode,
     built_in_models,
 )
-from kiln_ai.adapters.model_adapters.litellm_config import (
-    LiteLlmConfig,
-)
-from kiln_ai.adapters.ollama_tools import (
-    get_ollama_connection,
-)
+from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
+from kiln_ai.adapters.ollama_tools import get_ollama_connection
 from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 from kiln_ai.datamodel.registry import project_from_id

@@ -75,30 +70,24 @@ def builtin_model_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider | None:
     """
-    Gets a model
+    Gets a model provider from the built-in list of models.

     Args:
         name: The name of the model to get
         provider_name: Optional specific provider to use (defaults to first available)

     Returns:
-        A
-
-    Raises:
-        ValueError: If the model or provider is not found, or if the provider is misconfigured
+        A KilnModelProvider, or None if not found
     """
-    if name not in ModelName.__members__:
-        return None
-
     # Select the model from built_in_models using the name
-    model = next(filter(lambda m: m.name == name, built_in_models))
+    model = next(filter(lambda m: m.name == name, built_in_models), None)
     if model is None:
-
+        return None

-    # If a provider is provided, select the provider
+    # If a provider is provided, select the appropriate provider. Otherwise, use the first available.
     provider: KilnModelProvider | None = None
     if model.providers is None or len(model.providers) == 0:
-
+        return None
     elif provider_name is None:
         provider = model.providers[0]
     else:

@@ -384,6 +373,10 @@ def provider_name_from_id(id: str) -> str:
             return "Google Vertex AI"
         case ModelProviderName.together_ai:
             return "Together AI"
+        case ModelProviderName.siliconflow_cn:
+            return "SiliconFlow"
+        case ModelProviderName.cerebras:
+            return "Cerebras"
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_enum_error(enum_id)

@@ -442,4 +435,12 @@ provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
         required_config_keys=["together_api_key"],
         message="Attempted to use Together without an API key set. \nGet your API key from https://together.ai/settings/keys",
     ),
+    ModelProviderName.siliconflow_cn: ModelProviderWarning(
+        required_config_keys=["siliconflow_cn_api_key"],
+        message="Attempted to use SiliconFlow without an API key set. \nGet your API key from https://cloud.siliconflow.cn/account/ak",
+    ),
+    ModelProviderName.cerebras: ModelProviderWarning(
+        required_config_keys=["cerebras_api_key"],
+        message="Attempted to use Cerebras without an API key set. \nGet your API key from https://cloud.cerebras.ai/platform",
+    ),
 }
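A minimal sketch (not part of the diff) of the behavior change in builtin_model_from above: an unknown model name now yields None instead of raising. It assumes the function is importable from kiln_ai.adapters.provider_tools as the file path suggests, that the lookup model name below does not exist, and that "gpt_4_1_nano" is still a built-in model name.

    from kiln_ai.adapters.provider_tools import builtin_model_from

    # Unknown model names no longer raise; they simply return None
    assert builtin_model_from("definitely_not_a_real_model") is None

    # Known models still resolve to their first available provider
    provider = builtin_model_from("gpt_4_1_nano")
    assert provider is not None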
kiln_ai/adapters/remote_config.py (new file)

@@ -0,0 +1,113 @@
+import argparse
+import json
+import logging
+import os
+import threading
+from pathlib import Path
+from typing import Any, List
+
+import requests
+from pydantic import ValidationError
+
+from .ml_model_list import KilnModel, KilnModelProvider, built_in_models
+
+logger = logging.getLogger(__name__)
+
+
+def serialize_config(models: List[KilnModel], path: str | Path) -> None:
+    data = {"model_list": [m.model_dump(mode="json") for m in models]}
+    Path(path).write_text(json.dumps(data, indent=2, sort_keys=True))
+
+
+def deserialize_config_at_path(path: str | Path) -> List[KilnModel]:
+    raw = json.loads(Path(path).read_text())
+    return deserialize_config_data(raw)
+
+
+def deserialize_config_data(config_data: Any) -> List[KilnModel]:
+    if not isinstance(config_data, dict):
+        raise ValueError(f"Remote config expected dict, got {type(config_data)}")
+
+    model_list = config_data.get("model_list", None)
+    if not isinstance(model_list, list):
+        raise ValueError(
+            f"Remote config expected list of models, got {type(model_list)}"
+        )
+
+    # We must be careful here, because some of the JSON data may be generated from a forward
+    # version of the code that has newer fields / versions of the fields, that may cause
+    # the current client this code is running on to fail to validate the item into a KilnModel.
+    models = []
+    for model_data in model_list:
+        # We skip any model that fails validation - the models that the client can support
+        # will be pulled from the remote config, but the user will need to update their
+        # client to the latest version to see the newer models that break backwards compatibility.
+        try:
+            providers_list = model_data.get("providers", [])
+
+            providers = []
+            for provider_data in providers_list:
+                try:
+                    provider = KilnModelProvider.model_validate(provider_data)
+                    providers.append(provider)
+                except ValidationError as e:
+                    logger.warning(
+                        "Failed to validate a model provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
+                        provider_data,
+                        e,
+                    )
+
+            # this ensures the model deserialization won't fail because of a bad provider
+            model_data["providers"] = []
+
+            # now we validate the model without its providers
+            model = KilnModel.model_validate(model_data)
+
+            # and we attach back the providers that passed our validation
+            model.providers = providers
+            models.append(model)
+        except ValidationError as e:
+            logger.warning(
+                "Failed to validate a model from remote config. Upgrade Kiln to use this model. Details %s: %s",
+                model_data,
+                e,
+            )
+    return models
+
+
+def load_from_url(url: str) -> List[KilnModel]:
+    response = requests.get(url, timeout=10)
+    response.raise_for_status()
+    data = response.json()
+    return deserialize_config_data(data)
+
+
+def dump_builtin_config(path: str | Path) -> None:
+    serialize_config(built_in_models, path)
+
+
+def load_remote_models(url: str) -> None:
+    if os.environ.get("KILN_SKIP_REMOTE_MODEL_LIST") == "true":
+        return
+
+    def fetch_and_replace() -> None:
+        try:
+            models = load_from_url(url)
+            built_in_models[:] = models
+        except Exception as exc:
+            # Do not crash startup, but surface the issue
+            logger.warning("Failed to fetch remote model list from %s: %s", url, exc)
+
+    thread = threading.Thread(target=fetch_and_replace, daemon=True)
+    thread.start()
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("path", help="output path")
+    args = parser.parse_args()
+    dump_builtin_config(args.path)
+
+
+if __name__ == "__main__":
+    main()
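For orientation, a hedged usage sketch (not shipped in the package) exercising the new remote_config helpers shown above. The output path and URL are placeholders, not the endpoint Kiln actually fetches from.

    from kiln_ai.adapters.remote_config import (
        deserialize_config_at_path,
        dump_builtin_config,
        load_remote_models,
    )

    # Snapshot the built-in model list to JSON, then round-trip it;
    # entries that fail validation are skipped with a warning.
    dump_builtin_config("model_list.json")
    models = deserialize_config_at_path("model_list.json")
    print(f"loaded {len(models)} models")

    # Kick off the background refresh thread; it logs a warning on failure
    # and is skipped entirely when KILN_SKIP_REMOTE_MODEL_LIST=true.
    load_remote_models("https://example.com/kiln_model_list.json")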
kiln_ai/adapters/repair/repair_task.py

@@ -1,17 +1,12 @@
 import json
-from typing import Type

 from pydantic import BaseModel, Field

-from kiln_ai.adapters.prompt_builders import (
-    BasePromptBuilder,
-    SavedPromptBuilder,
-    prompt_builder_from_id,
-)
+from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_from_id
 from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun


-#
+# We should add evaluator rating
 class RepairTaskInput(BaseModel):
     original_prompt: str
     original_input: str
kiln_ai/adapters/test_adapter_registry.py

@@ -16,6 +16,7 @@ def mock_config():
     with patch("kiln_ai.adapters.adapter_registry.Config") as mock:
         mock.shared.return_value.open_ai_api_key = "test-openai-key"
         mock.shared.return_value.open_router_api_key = "test-openrouter-key"
+        mock.shared.return_value.siliconflow_cn_api_key = "test-siliconflow-key"
         yield mock


@@ -85,6 +86,33 @@ def test_openrouter_adapter_creation(mock_config, basic_task):
     }


+def test_siliconflow_adapter_creation(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+            model_provider_name=ModelProviderName.siliconflow_cn,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
+    )
+
+    assert isinstance(adapter, LiteLlmAdapter)
+    assert (
+        adapter.config.run_config_properties.model_name
+        == "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
+    )
+    assert adapter.config.additional_body_options == {"api_key": "test-siliconflow-key"}
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.siliconflow_cn
+    )
+    assert adapter.config.default_headers == {
+        "HTTP-Referer": "https://getkiln.ai/siliconflow",
+        "X-Title": "KilnAI",
+    }
+
+
 @pytest.mark.parametrize(
     "provider",
     [

@@ -109,7 +137,7 @@ def test_openai_compatible_adapter_creation(mock_config, basic_task, provider):
     assert adapter.run_config.model_name == "test-model"


-#
+# We should run for all cases
 def test_custom_prompt_builder(mock_config, basic_task):
     adapter = adapter_for_task(
         kiln_task=basic_task,

@@ -124,7 +152,7 @@ def test_custom_prompt_builder(mock_config, basic_task):
     assert adapter.run_config.prompt_id == "simple_chain_of_thought_prompt_builder"


-#
+# We should run for all cases
 def test_tags_passed_through(mock_config, basic_task):
     tags = ["test-tag-1", "test-tag-2"]
     adapter = adapter_for_task(
kiln_ai/adapters/test_ml_model_list.py

@@ -2,6 +2,7 @@ import pytest

 from kiln_ai.adapters.ml_model_list import (
     ModelName,
+    built_in_models,
     default_structured_output_mode_for_model_provider,
     get_model_by_name,
 )

@@ -156,3 +157,32 @@ class TestDefaultStructuredOutputModeForModelProvider:
             provider=first_provider.name,
         )
         assert result == first_provider.structured_output_mode
+
+
+def test_uncensored():
+    """Test that uncensored is set correctly"""
+    model = get_model_by_name(ModelName.grok_3_mini)
+    for provider in model.providers:
+        assert provider.uncensored
+        assert not provider.suggested_for_uncensored_data_gen
+
+    model = get_model_by_name(ModelName.gpt_4_1_nano)
+    for provider in model.providers:
+        assert not provider.uncensored
+        assert not provider.suggested_for_uncensored_data_gen
+
+    model = get_model_by_name(ModelName.grok_4)
+    for provider in model.providers:
+        assert provider.uncensored
+        assert provider.suggested_for_uncensored_data_gen
+
+
+def test_no_reasoning_for_structured_output():
+    """Test that no reasoning is returned for structured output"""
+    # get all models
+    for model in built_in_models:
+        for provider in model.providers:
+            if provider.reasoning_optional_for_structured_output is not None:
+                assert provider.reasoning_capable, (
+                    f"{model.name} {provider.name} has reasoning_optional_for_structured_output but is not reasoning capable. This field should only be defined for models that are reasoning capable."
+                )

kiln_ai/adapters/test_prompt_adaptors.py

@@ -13,10 +13,6 @@ from kiln_ai.adapters.model_adapters.litellm_adapter import (
     LiteLlmConfig,
 )
 from kiln_ai.adapters.ollama_tools import ollama_online
-from kiln_ai.adapters.prompt_builders import (
-    BasePromptBuilder,
-    SimpleChainOfThoughtPromptBuilder,
-)
 from kiln_ai.datamodel import PromptId
 from kiln_ai.datamodel.task import RunConfigProperties