kiln-ai 0.18.0__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (89)
  1. kiln_ai/adapters/__init__.py +2 -2
  2. kiln_ai/adapters/adapter_registry.py +46 -0
  3. kiln_ai/adapters/chat/chat_formatter.py +8 -12
  4. kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
  5. kiln_ai/adapters/data_gen/data_gen_task.py +2 -2
  6. kiln_ai/adapters/data_gen/test_data_gen_task.py +7 -3
  7. kiln_ai/adapters/docker_model_runner_tools.py +119 -0
  8. kiln_ai/adapters/eval/base_eval.py +2 -2
  9. kiln_ai/adapters/eval/eval_runner.py +3 -1
  10. kiln_ai/adapters/eval/g_eval.py +2 -2
  11. kiln_ai/adapters/eval/test_base_eval.py +1 -1
  12. kiln_ai/adapters/eval/test_eval_runner.py +6 -12
  13. kiln_ai/adapters/eval/test_g_eval.py +3 -4
  14. kiln_ai/adapters/eval/test_g_eval_data.py +1 -1
  15. kiln_ai/adapters/fine_tune/__init__.py +1 -1
  16. kiln_ai/adapters/fine_tune/base_finetune.py +1 -0
  17. kiln_ai/adapters/fine_tune/fireworks_finetune.py +32 -20
  18. kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
  19. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +30 -21
  20. kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
  21. kiln_ai/adapters/ml_model_list.py +1009 -111
  22. kiln_ai/adapters/model_adapters/base_adapter.py +62 -28
  23. kiln_ai/adapters/model_adapters/litellm_adapter.py +397 -80
  24. kiln_ai/adapters/model_adapters/test_base_adapter.py +194 -18
  25. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +428 -4
  26. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
  27. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
  28. kiln_ai/adapters/model_adapters/test_structured_output.py +120 -14
  29. kiln_ai/adapters/parsers/__init__.py +1 -1
  30. kiln_ai/adapters/parsers/test_r1_parser.py +1 -1
  31. kiln_ai/adapters/provider_tools.py +35 -20
  32. kiln_ai/adapters/remote_config.py +57 -10
  33. kiln_ai/adapters/repair/repair_task.py +1 -1
  34. kiln_ai/adapters/repair/test_repair_task.py +12 -9
  35. kiln_ai/adapters/run_output.py +3 -0
  36. kiln_ai/adapters/test_adapter_registry.py +109 -2
  37. kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
  38. kiln_ai/adapters/test_ml_model_list.py +51 -1
  39. kiln_ai/adapters/test_prompt_adaptors.py +13 -6
  40. kiln_ai/adapters/test_provider_tools.py +73 -12
  41. kiln_ai/adapters/test_remote_config.py +470 -16
  42. kiln_ai/datamodel/__init__.py +23 -21
  43. kiln_ai/datamodel/basemodel.py +54 -28
  44. kiln_ai/datamodel/datamodel_enums.py +3 -0
  45. kiln_ai/datamodel/dataset_split.py +5 -3
  46. kiln_ai/datamodel/eval.py +4 -4
  47. kiln_ai/datamodel/external_tool_server.py +298 -0
  48. kiln_ai/datamodel/finetune.py +2 -2
  49. kiln_ai/datamodel/json_schema.py +25 -10
  50. kiln_ai/datamodel/project.py +11 -4
  51. kiln_ai/datamodel/prompt.py +2 -2
  52. kiln_ai/datamodel/prompt_id.py +4 -4
  53. kiln_ai/datamodel/registry.py +0 -15
  54. kiln_ai/datamodel/run_config.py +62 -0
  55. kiln_ai/datamodel/task.py +8 -83
  56. kiln_ai/datamodel/task_output.py +7 -2
  57. kiln_ai/datamodel/task_run.py +41 -0
  58. kiln_ai/datamodel/test_basemodel.py +213 -21
  59. kiln_ai/datamodel/test_eval_model.py +6 -6
  60. kiln_ai/datamodel/test_example_models.py +175 -0
  61. kiln_ai/datamodel/test_external_tool_server.py +691 -0
  62. kiln_ai/datamodel/test_model_perf.py +1 -1
  63. kiln_ai/datamodel/test_prompt_id.py +5 -1
  64. kiln_ai/datamodel/test_registry.py +8 -3
  65. kiln_ai/datamodel/test_task.py +20 -47
  66. kiln_ai/datamodel/test_tool_id.py +239 -0
  67. kiln_ai/datamodel/tool_id.py +83 -0
  68. kiln_ai/tools/__init__.py +8 -0
  69. kiln_ai/tools/base_tool.py +82 -0
  70. kiln_ai/tools/built_in_tools/__init__.py +13 -0
  71. kiln_ai/tools/built_in_tools/math_tools.py +124 -0
  72. kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
  73. kiln_ai/tools/mcp_server_tool.py +95 -0
  74. kiln_ai/tools/mcp_session_manager.py +243 -0
  75. kiln_ai/tools/test_base_tools.py +199 -0
  76. kiln_ai/tools/test_mcp_server_tool.py +457 -0
  77. kiln_ai/tools/test_mcp_session_manager.py +1585 -0
  78. kiln_ai/tools/test_tool_registry.py +473 -0
  79. kiln_ai/tools/tool_registry.py +64 -0
  80. kiln_ai/utils/config.py +32 -0
  81. kiln_ai/utils/open_ai_types.py +94 -0
  82. kiln_ai/utils/project_utils.py +17 -0
  83. kiln_ai/utils/test_config.py +138 -1
  84. kiln_ai/utils/test_open_ai_types.py +131 -0
  85. {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/METADATA +37 -6
  86. kiln_ai-0.20.1.dist-info/RECORD +138 -0
  87. kiln_ai-0.18.0.dist-info/RECORD +0 -115
  88. {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/WHEEL +0 -0
  89. {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/test_saving_adapter_results.py

@@ -13,7 +13,7 @@ from kiln_ai.datamodel import (
     Task,
     Usage,
 )
-from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config
 
 
@@ -41,8 +41,8 @@ def test_task(tmp_path):
 @pytest.fixture
 def adapter(test_task):
     return MockAdapter(
-        run_config=RunConfig(
-            task=test_task,
+        task=test_task,
+        run_config=RunConfigProperties(
             model_name="phi_3_5",
             model_provider_name="ollama",
             prompt_id="simple_chain_of_thought_prompt_builder",
@@ -240,8 +240,8 @@ async def test_autosave_true(test_task, adapter):
 def test_properties_for_task_output_custom_values(test_task):
     """Test that _properties_for_task_output includes custom temperature, top_p, and structured_output_mode"""
     adapter = MockAdapter(
-        run_config=RunConfig(
-            task=test_task,
+        task=test_task,
+        run_config=RunConfigProperties(
             model_name="gpt-4",
             model_provider_name="openai",
             prompt_id="simple_prompt_builder",

kiln_ai/adapters/model_adapters/test_structured_output.py

@@ -1,23 +1,19 @@
 import json
 from pathlib import Path
 from typing import Dict
+from unittest.mock import Mock, patch
 
 import pytest
+from litellm.types.utils import ModelResponse
 
 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.adapter_registry import adapter_for_task
-from kiln_ai.adapters.ml_model_list import (
-    built_in_models,
-)
-from kiln_ai.adapters.model_adapters.base_adapter import (
-    BaseAdapter,
-    RunOutput,
-    Usage,
-)
+from kiln_ai.adapters.ml_model_list import built_in_models
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput, Usage
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId
-from kiln_ai.datamodel.task import RunConfig, RunConfigProperties
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema
 
 
@@ -46,8 +42,8 @@ async def test_structured_output_ollama(tmp_path, model_name):
 class MockAdapter(BaseAdapter):
     def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None):
         super().__init__(
-            run_config=RunConfig(
-                task=kiln_task,
+            task=kiln_task,
+            run_config=RunConfigProperties(
                 model_name="phi_3_5",
                 model_provider_name="ollama",
                 prompt_id="simple_chain_of_thought_prompt_builder",
@@ -180,8 +176,14 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     # Check reasoning models
     assert a._model_provider is not None
     if a._model_provider.reasoning_capable:
-        assert "reasoning" in run.intermediate_outputs
-        assert isinstance(run.intermediate_outputs["reasoning"], str)
+        # some providers have reasoning_capable models that do not return the reasoning
+        # for structured output responses (they provide it only for non-structured output)
+        if a._model_provider.reasoning_optional_for_structured_output:
+            # models may be updated to include the reasoning in the future
+            assert "reasoning" not in run.intermediate_outputs
+        else:
+            assert "reasoning" in run.intermediate_outputs
+            assert isinstance(run.intermediate_outputs["reasoning"], str)
 
 
 def build_structured_input_test_task(tmp_path: Path):
@@ -259,6 +261,7 @@ async def run_structured_input_task(
     model_name: str,
     provider: str,
     prompt_id: PromptId,
+    verify_trace_cot: bool = False,
 ):
     response, a, run = await run_structured_input_task_no_validation(
         task, model_name, provider, prompt_id
@@ -282,6 +285,32 @@ async def run_structured_input_task(
         assert "reasoning" in run.intermediate_outputs
         assert isinstance(run.intermediate_outputs["reasoning"], str)
 
+    # Check the trace
+    trace = run.trace
+    assert trace is not None
+    if verify_trace_cot:
+        assert len(trace) == 5
+        assert trace[0]["role"] == "system"
+        assert "You are an assistant which classifies a triangle" in trace[0]["content"]
+        assert trace[1]["role"] == "user"
+        assert trace[2]["role"] == "assistant"
+        assert trace[2].get("tool_calls") is None
+        assert trace[3]["role"] == "user"
+        assert trace[4]["role"] == "assistant"
+        assert trace[4].get("tool_calls") is None
+    else:
+        assert len(trace) == 3
+        assert trace[0]["role"] == "system"
+        assert "You are an assistant which classifies a triangle" in trace[0]["content"]
+        assert trace[1]["role"] == "user"
+        json_content = json.loads(trace[1]["content"])
+        assert json_content["a"] == 2
+        assert json_content["b"] == 2
+        assert json_content["c"] == 2
+        assert trace[2]["role"] == "assistant"
+        assert trace[2].get("tool_calls") is None
+        assert "[[equilateral]]" in trace[2]["content"]
+
 
 @pytest.mark.paid
 async def test_structured_input_gpt_4o_mini(tmp_path):
@@ -299,14 +328,91 @@ async def test_all_built_in_models_structured_input(
     )
 
 
+async def test_all_built_in_models_structured_input_mocked(tmp_path):
+    mock_response = ModelResponse(
+        model="gpt-4o-mini",
+        choices=[
+            {
+                "message": {
+                    "content": "The answer is [[equilateral]]",
+                }
+            }
+        ],
+    )
+
+    # Mock the Config.shared() method to return a mock config with required attributes
+    mock_config = Mock()
+    mock_config.open_ai_api_key = "mock_api_key"
+    mock_config.user_id = "test_user"
+
+    with (
+        patch(
+            "litellm.acompletion",
+            side_effect=[mock_response],
+        ),
+        patch("kiln_ai.utils.config.Config.shared", return_value=mock_config),
+    ):
+        await run_structured_input_test(
+            tmp_path, "llama_3_1_8b", "groq", "simple_prompt_builder"
+        )
+
+
 @pytest.mark.paid
 @pytest.mark.ollama
 @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
 async def test_structured_input_cot_prompt_builder(tmp_path, model_name, provider_name):
     task = build_structured_input_test_task(tmp_path)
     await run_structured_input_task(
-        task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+        task,
+        model_name,
+        provider_name,
+        "simple_chain_of_thought_prompt_builder",
+        verify_trace_cot=True,
+    )
+
+
+async def test_structured_input_cot_prompt_builder_mocked(tmp_path):
+    task = build_structured_input_test_task(tmp_path)
+    mock_response_1 = ModelResponse(
+        model="gpt-4o-mini",
+        choices=[
+            {
+                "message": {
+                    "content": "I'm thinking real hard... oh!",
+                }
+            }
+        ],
     )
+    mock_response_2 = ModelResponse(
+        model="gpt-4o-mini",
+        choices=[
+            {
+                "message": {
+                    "content": "After thinking, I've decided the answer is [[equilateral]]",
+                }
+            }
+        ],
+    )
+
+    # Mock the Config.shared() method to return a mock config with required attributes
+    mock_config = Mock()
+    mock_config.open_ai_api_key = "mock_api_key"
+    mock_config.user_id = "test_user"
+
+    with (
+        patch(
+            "litellm.acompletion",
+            side_effect=[mock_response_1, mock_response_2],
+        ),
+        patch("kiln_ai.utils.config.Config.shared", return_value=mock_config),
+    ):
+        await run_structured_input_task(
+            task,
+            "llama_3_1_8b",
+            "groq",
+            "simple_chain_of_thought_prompt_builder",
+            verify_trace_cot=True,
+        )
 
 
 @pytest.mark.paid

kiln_ai/adapters/parsers/__init__.py

@@ -7,4 +7,4 @@ Parsing utilities for JSON and models with custom output formats (R1, etc.)
 
 from . import base_parser, json_parser, r1_parser
 
-__all__ = ["r1_parser", "base_parser", "json_parser"]
+__all__ = ["base_parser", "json_parser", "r1_parser"]

kiln_ai/adapters/parsers/test_r1_parser.py

@@ -46,7 +46,7 @@ def test_response_with_whitespace(parser):
     assert parsed.output.strip() == "This is the result"
 
 
-def test_empty_thinking_content(parser):
+def test_empty_thinking_content_multiline(parser):
     response = RunOutput(
         output="""
 <think>
kiln_ai/adapters/provider_tools.py

@@ -2,27 +2,25 @@ import logging
 from dataclasses import dataclass
 from typing import Dict, List
 
+from kiln_ai.adapters.docker_model_runner_tools import (
+    get_docker_model_runner_connection,
+)
 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,
-    ModelName,
     ModelParserID,
     ModelProviderName,
     StructuredOutputMode,
     built_in_models,
 )
-from kiln_ai.adapters.model_adapters.litellm_config import (
-    LiteLlmConfig,
-)
-from kiln_ai.adapters.ollama_tools import (
-    get_ollama_connection,
-)
+from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
+from kiln_ai.adapters.ollama_tools import get_ollama_connection
 from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.datamodel_enums import ChatStrategy
-from kiln_ai.datamodel.registry import project_from_id
 from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+from kiln_ai.utils.project_utils import project_from_id
 
 logger = logging.getLogger(__name__)
 
@@ -37,6 +35,15 @@ async def provider_enabled(provider_name: ModelProviderName) -> bool:
         except Exception:
             return False
 
+    if provider_name == ModelProviderName.docker_model_runner:
+        try:
+            conn = await get_docker_model_runner_connection()
+            return conn is not None and (
+                len(conn.supported_models) > 0 or len(conn.untested_models) > 0
+            )
+        except Exception:
+            return False
+
     provider_warning = provider_warnings.get(provider_name)
     if provider_warning is None:
         return False
@@ -75,30 +82,24 @@ def builtin_model_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider | None:
     """
-    Gets a model and provider from the built-in list of models.
+    Gets a model provider from the built-in list of models.
 
     Args:
         name: The name of the model to get
         provider_name: Optional specific provider to use (defaults to first available)
 
     Returns:
-        A tuple of (provider, model)
-
-    Raises:
-        ValueError: If the model or provider is not found, or if the provider is misconfigured
+        A KilnModelProvider, or None if not found
     """
-    if name not in ModelName.__members__:
-        return None
-
     # Select the model from built_in_models using the name
-    model = next(filter(lambda m: m.name == name, built_in_models))
+    model = next(filter(lambda m: m.name == name, built_in_models), None)
     if model is None:
-        raise ValueError(f"Model {name} not found")
+        return None
 
-    # If a provider is provided, select the provider from the model's provider_config
+    # If a provider is provided, select the appropriate provider. Otherwise, use the first available.
     provider: KilnModelProvider | None = None
     if model.providers is None or len(model.providers) == 0:
-        raise ValueError(f"Model {name} has no providers")
+        return None
     elif provider_name is None:
         provider = model.providers[0]
     else:
@@ -384,6 +385,12 @@ def provider_name_from_id(id: str) -> str:
             return "Google Vertex AI"
         case ModelProviderName.together_ai:
             return "Together AI"
+        case ModelProviderName.siliconflow_cn:
+            return "SiliconFlow"
+        case ModelProviderName.cerebras:
+            return "Cerebras"
+        case ModelProviderName.docker_model_runner:
+            return "Docker Model Runner"
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_enum_error(enum_id)
@@ -442,4 +449,12 @@ provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
         required_config_keys=["together_api_key"],
         message="Attempted to use Together without an API key set. \nGet your API key from https://together.ai/settings/keys",
     ),
+    ModelProviderName.siliconflow_cn: ModelProviderWarning(
+        required_config_keys=["siliconflow_cn_api_key"],
+        message="Attempted to use SiliconFlow without an API key set. \nGet your API key from https://cloud.siliconflow.cn/account/ak",
+    ),
+    ModelProviderName.cerebras: ModelProviderWarning(
+        required_config_keys=["cerebras_api_key"],
+        message="Attempted to use Cerebras without an API key set. \nGet your API key from https://cloud.cerebras.ai/platform",
+    ),
 }
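
Note on the provider_tools change above: builtin_model_from now returns None for an unknown model (or a model with no providers) instead of raising ValueError, so callers are expected to handle the None case. A minimal sketch of the new calling pattern; the model name below is made up for illustration:

from kiln_ai.adapters.provider_tools import builtin_model_from

# Unknown models, or models with no providers, now yield None rather than raising ValueError.
provider = builtin_model_from("some_model_name_not_in_built_in_models")
if provider is None:
    print("model not found in built_in_models")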
kiln_ai/adapters/remote_config.py

@@ -4,11 +4,12 @@ import logging
 import os
 import threading
 from pathlib import Path
-from typing import List
+from typing import Any, List
 
 import requests
+from pydantic import ValidationError
 
-from .ml_model_list import KilnModel, built_in_models
+from .ml_model_list import KilnModel, KilnModelProvider, built_in_models
 
 logger = logging.getLogger(__name__)
 
@@ -18,21 +19,67 @@ def serialize_config(models: List[KilnModel], path: str | Path) -> None:
     Path(path).write_text(json.dumps(data, indent=2, sort_keys=True))
 
 
-def deserialize_config(path: str | Path) -> List[KilnModel]:
+def deserialize_config_at_path(path: str | Path) -> List[KilnModel]:
     raw = json.loads(Path(path).read_text())
-    model_data = raw.get("model_list", raw if isinstance(raw, list) else [])
-    return [KilnModel.model_validate(item) for item in model_data]
+    return deserialize_config_data(raw)
+
+
+def deserialize_config_data(config_data: Any) -> List[KilnModel]:
+    if not isinstance(config_data, dict):
+        raise ValueError(f"Remote config expected dict, got {type(config_data)}")
+
+    model_list = config_data.get("model_list", None)
+    if not isinstance(model_list, list):
+        raise ValueError(
+            f"Remote config expected list of models, got {type(model_list)}"
+        )
+
+    # We must be careful here, because some of the JSON data may be generated from a forward
+    # version of the code that has newer fields / versions of the fields, that may cause
+    # the current client this code is running on to fail to validate the item into a KilnModel.
+    models = []
+    for model_data in model_list:
+        # We skip any model that fails validation - the models that the client can support
+        # will be pulled from the remote config, but the user will need to update their
+        # client to the latest version to see the newer models that break backwards compatibility.
+        try:
+            providers_list = model_data.get("providers", [])
+
+            providers = []
+            for provider_data in providers_list:
+                try:
+                    provider = KilnModelProvider.model_validate(provider_data)
+                    providers.append(provider)
+                except ValidationError as e:
+                    logger.warning(
+                        "Failed to validate a model provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
+                        provider_data,
+                        e,
+                    )
+
+            # this ensures the model deserialization won't fail because of a bad provider
+            model_data["providers"] = []
+
+            # now we validate the model without its providers
+            model = KilnModel.model_validate(model_data)
+
+            # and we attach back the providers that passed our validation
+            model.providers = providers
+            models.append(model)
+        except ValidationError as e:
+            logger.warning(
+                "Failed to validate a model from remote config. Upgrade Kiln to use this model. Details %s: %s",
+                model_data,
+                e,
+            )
+    return models
 
 
 def load_from_url(url: str) -> List[KilnModel]:
     response = requests.get(url, timeout=10)
     response.raise_for_status()
     data = response.json()
-    if isinstance(data, list):
-        model_data = data
-    else:
-        model_data = data.get("model_list", [])
-    return [KilnModel.model_validate(item) for item in model_data]
+    return deserialize_config_data(data)
 
 
 def dump_builtin_config(path: str | Path) -> None:
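
A brief illustration of the stricter remote-config handling above, limited to behavior visible in this hunk: payloads that are not a dict containing a model_list list now raise ValueError, while individual models or providers that fail Pydantic validation are logged and skipped rather than failing the whole load.

from kiln_ai.adapters.remote_config import deserialize_config_data

# The old bare-list format is now rejected outright.
try:
    deserialize_config_data(["not", "a", "dict"])
except ValueError as err:
    print(err)  # Remote config expected dict, got <class 'list'>

# A dict with an empty model_list deserializes to an empty list of KilnModel.
assert deserialize_config_data({"model_list": []}) == []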
kiln_ai/adapters/repair/repair_task.py

@@ -6,7 +6,7 @@ from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_f
 from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun
 
 
-# TODO add evaluator rating
+# We should add evaluator rating
 class RepairTaskInput(BaseModel):
     original_prompt: str
     original_input: str

kiln_ai/adapters/repair/test_repair_task.py

@@ -229,21 +229,20 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
         "rating": 8,
     }
 
+    run_config = RunConfigProperties(
+        model_name="llama_3_1_8b",
+        model_provider_name="ollama",
+        prompt_id="simple_prompt_builder",
+        structured_output_mode="json_schema",
+    )
+
     with patch.object(LiteLlmAdapter, "_run", new_callable=AsyncMock) as mock_run:
         mock_run.return_value = (
            RunOutput(output=mocked_output, intermediate_outputs=None),
            None,
        )
 
-        adapter = adapter_for_task(
-            repair_task,
-            RunConfigProperties(
-                model_name="llama_3_1_8b",
-                model_provider_name="ollama",
-                prompt_id="simple_prompt_builder",
-                structured_output_mode="json_schema",
-            ),
-        )
+        adapter = adapter_for_task(repair_task, run_config)
 
         run = await adapter.invoke(repair_task_input.model_dump())
 
@@ -264,6 +263,10 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
         }
         assert run.input_source.type == DataSourceType.human
         assert "created_by" in run.input_source.properties
+        assert run.output.source is not None
+        assert run.output.source.run_config is not None
+        saved_run_config = run.output.source.run_config.model_dump()
+        assert saved_run_config == run_config.model_dump()
 
         # Verify that the mock was called
         mock_run.assert_called_once()

kiln_ai/adapters/run_output.py

@@ -3,9 +3,12 @@ from typing import Dict
 
 from litellm.types.utils import ChoiceLogprobs
 
+from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam
+
 
 @dataclass
 class RunOutput:
     output: Dict | str
     intermediate_outputs: Dict[str, str] | None
     output_logprobs: ChoiceLogprobs | None = None
+    trace: list[ChatCompletionMessageParam] | None = None
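
For context on the new field above: RunOutput.trace optionally carries the chat transcript of a run as OpenAI-style message dicts, defaulting to None. A minimal sketch of constructing one; the message contents are illustrative, borrowed from the test assertions elsewhere in this diff:

from kiln_ai.adapters.run_output import RunOutput

# trace defaults to None; when set it is a list of ChatCompletionMessageParam dicts
# covering the system/user/assistant turns of the run.
out = RunOutput(
    output="[[equilateral]]",
    intermediate_outputs=None,
    trace=[
        {"role": "system", "content": "You are an assistant which classifies a triangle"},
        {"role": "user", "content": '{"a": 2, "b": 2, "c": 2}'},
        {"role": "assistant", "content": "[[equilateral]]"},
    ],
)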
kiln_ai/adapters/test_adapter_registry.py

@@ -8,6 +8,7 @@ from kiln_ai.adapters.ml_model_list import ModelProviderName
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
 from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
 from kiln_ai.adapters.provider_tools import kiln_model_provider_from
+from kiln_ai.datamodel.datamodel_enums import StructuredOutputMode
 from kiln_ai.datamodel.task import RunConfigProperties
 
 
@@ -16,6 +17,10 @@ def mock_config():
     with patch("kiln_ai.adapters.adapter_registry.Config") as mock:
         mock.shared.return_value.open_ai_api_key = "test-openai-key"
         mock.shared.return_value.open_router_api_key = "test-openrouter-key"
+        mock.shared.return_value.siliconflow_cn_api_key = "test-siliconflow-key"
+        mock.shared.return_value.docker_model_runner_base_url = (
+            "http://localhost:12434/engines/llama.cpp"
+        )
         yield mock
 
 
@@ -85,6 +90,33 @@ def test_openrouter_adapter_creation(mock_config, basic_task):
     }
 
 
+def test_siliconflow_adapter_creation(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+            model_provider_name=ModelProviderName.siliconflow_cn,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
+    )
+
+    assert isinstance(adapter, LiteLlmAdapter)
+    assert (
+        adapter.config.run_config_properties.model_name
+        == "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
+    )
+    assert adapter.config.additional_body_options == {"api_key": "test-siliconflow-key"}
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.siliconflow_cn
+    )
+    assert adapter.config.default_headers == {
+        "HTTP-Referer": "https://kiln.tech/siliconflow",
+        "X-Title": "KilnAI",
+    }
+
+
 @pytest.mark.parametrize(
     "provider",
     [
@@ -109,7 +141,7 @@ def test_openai_compatible_adapter_creation(mock_config, basic_task, provider):
     assert adapter.run_config.model_name == "test-model"
 
 
-# TODO should run for all cases
+# We should run for all cases
 def test_custom_prompt_builder(mock_config, basic_task):
     adapter = adapter_for_task(
         kiln_task=basic_task,
@@ -124,7 +156,7 @@ def test_custom_prompt_builder(mock_config, basic_task):
     assert adapter.run_config.prompt_id == "simple_chain_of_thought_prompt_builder"
 
 
-# TODO should run for all cases
+# We should run for all cases
 def test_tags_passed_through(mock_config, basic_task):
     tags = ["test-tag-1", "test-tag-2"]
     adapter = adapter_for_task(
@@ -232,3 +264,78 @@ async def test_fine_tune_provider(mock_config, basic_task, mock_finetune_from_id
     )
     # The actual model name from the fine tune object
     assert provider.model_id == "test-model"
+
+
+def test_docker_model_runner_adapter_creation(mock_config, basic_task):
+    """Test Docker Model Runner adapter creation with default and custom base URL."""
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="llama_3_2_3b",
+            model_provider_name=ModelProviderName.docker_model_runner,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode=StructuredOutputMode.json_schema,
+        ),
+    )
+
+    assert isinstance(adapter, LiteLlmAdapter)
+    assert adapter.config.run_config_properties.model_name == "llama_3_2_3b"
+    assert adapter.config.additional_body_options == {"api_key": "DMR"}
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.docker_model_runner
+    )
+    assert adapter.config.base_url == "http://localhost:12434/engines/llama.cpp/v1"
+    assert adapter.config.default_headers is None
+
+
+def test_docker_model_runner_adapter_creation_with_custom_url(mock_config, basic_task):
+    """Test Docker Model Runner adapter creation with custom base URL."""
+    mock_config.shared.return_value.docker_model_runner_base_url = (
+        "http://custom:8080/engines/llama.cpp"
+    )
+
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="llama_3_2_3b",
+            model_provider_name=ModelProviderName.docker_model_runner,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode=StructuredOutputMode.json_schema,
+        ),
+    )
+
+    assert isinstance(adapter, LiteLlmAdapter)
+    assert adapter.config.run_config_properties.model_name == "llama_3_2_3b"
+    assert adapter.config.additional_body_options == {"api_key": "DMR"}
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.docker_model_runner
+    )
+    assert adapter.config.base_url == "http://custom:8080/engines/llama.cpp/v1"
+    assert adapter.config.default_headers is None
+
+
+def test_docker_model_runner_adapter_creation_with_none_url(mock_config, basic_task):
+    """Test Docker Model Runner adapter creation when config URL is None."""
+    mock_config.shared.return_value.docker_model_runner_base_url = None
+
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="llama_3_2_3b",
+            model_provider_name=ModelProviderName.docker_model_runner,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode=StructuredOutputMode.json_schema,
+        ),
+    )
+
+    assert isinstance(adapter, LiteLlmAdapter)
+    assert adapter.config.run_config_properties.model_name == "llama_3_2_3b"
+    assert adapter.config.additional_body_options == {"api_key": "DMR"}
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.docker_model_runner
+    )
+    assert adapter.config.base_url == "http://localhost:12434/engines/llama.cpp/v1"
+    assert adapter.config.default_headers is None