kiln-ai 0.16.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +22 -44
  3. kiln_ai/adapters/chat/__init__.py +8 -0
  4. kiln_ai/adapters/chat/chat_formatter.py +233 -0
  5. kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
  6. kiln_ai/adapters/data_gen/data_gen_prompts.py +121 -36
  7. kiln_ai/adapters/data_gen/data_gen_task.py +49 -36
  8. kiln_ai/adapters/data_gen/test_data_gen_task.py +330 -40
  9. kiln_ai/adapters/eval/base_eval.py +7 -6
  10. kiln_ai/adapters/eval/eval_runner.py +9 -2
  11. kiln_ai/adapters/eval/g_eval.py +40 -17
  12. kiln_ai/adapters/eval/test_base_eval.py +174 -17
  13. kiln_ai/adapters/eval/test_eval_runner.py +3 -0
  14. kiln_ai/adapters/eval/test_g_eval.py +116 -5
  15. kiln_ai/adapters/fine_tune/base_finetune.py +3 -8
  16. kiln_ai/adapters/fine_tune/dataset_formatter.py +135 -273
  17. kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
  18. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +287 -353
  19. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
  20. kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
  21. kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
  22. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +6 -11
  23. kiln_ai/adapters/fine_tune/together_finetune.py +13 -2
  24. kiln_ai/adapters/ml_model_list.py +370 -84
  25. kiln_ai/adapters/model_adapters/base_adapter.py +73 -26
  26. kiln_ai/adapters/model_adapters/litellm_adapter.py +88 -97
  27. kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
  28. kiln_ai/adapters/model_adapters/test_base_adapter.py +235 -61
  29. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +104 -21
  30. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -0
  31. kiln_ai/adapters/model_adapters/test_structured_output.py +44 -12
  32. kiln_ai/adapters/parsers/parser_registry.py +0 -2
  33. kiln_ai/adapters/parsers/r1_parser.py +0 -1
  34. kiln_ai/adapters/prompt_builders.py +0 -16
  35. kiln_ai/adapters/provider_tools.py +27 -9
  36. kiln_ai/adapters/remote_config.py +66 -0
  37. kiln_ai/adapters/repair/repair_task.py +1 -6
  38. kiln_ai/adapters/repair/test_repair_task.py +24 -3
  39. kiln_ai/adapters/test_adapter_registry.py +88 -28
  40. kiln_ai/adapters/test_ml_model_list.py +176 -0
  41. kiln_ai/adapters/test_prompt_adaptors.py +17 -7
  42. kiln_ai/adapters/test_prompt_builders.py +3 -16
  43. kiln_ai/adapters/test_provider_tools.py +69 -20
  44. kiln_ai/adapters/test_remote_config.py +100 -0
  45. kiln_ai/datamodel/__init__.py +0 -2
  46. kiln_ai/datamodel/datamodel_enums.py +38 -13
  47. kiln_ai/datamodel/eval.py +32 -0
  48. kiln_ai/datamodel/finetune.py +12 -8
  49. kiln_ai/datamodel/task.py +68 -7
  50. kiln_ai/datamodel/task_output.py +0 -2
  51. kiln_ai/datamodel/task_run.py +0 -2
  52. kiln_ai/datamodel/test_basemodel.py +2 -1
  53. kiln_ai/datamodel/test_dataset_split.py +0 -8
  54. kiln_ai/datamodel/test_eval_model.py +146 -4
  55. kiln_ai/datamodel/test_models.py +33 -10
  56. kiln_ai/datamodel/test_task.py +168 -2
  57. kiln_ai/utils/config.py +3 -2
  58. kiln_ai/utils/dataset_import.py +1 -1
  59. kiln_ai/utils/logging.py +166 -0
  60. kiln_ai/utils/test_config.py +23 -0
  61. kiln_ai/utils/test_dataset_import.py +30 -0
  62. {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/METADATA +2 -2
  63. kiln_ai-0.18.0.dist-info/RECORD +115 -0
  64. kiln_ai-0.16.0.dist-info/RECORD +0 -108
  65. {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/WHEEL +0 -0
  66. {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -17,7 +17,7 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId
-from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.datamodel.task import RunConfig, RunConfigProperties
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


@@ -51,6 +51,7 @@ class MockAdapter(BaseAdapter):
                 model_name="phi_3_5",
                 model_provider_name="ollama",
                 prompt_id="simple_chain_of_thought_prompt_builder",
+                structured_output_mode="json_schema",
             ),
         )
         self.response = response
@@ -146,7 +147,15 @@ def build_structured_output_test_task(tmp_path: Path):

 async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
     task = build_structured_output_test_task(tmp_path)
-    a = adapter_for_task(task, model_name=model_name, provider=provider)
+    a = adapter_for_task(
+        task,
+        run_config_properties=RunConfigProperties(
+            model_name=model_name,
+            model_provider_name=provider,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="unknown",
+        ),
+    )
     try:
         run = await a.invoke("Cows")  # a joke about cows
         parsed = json.loads(run.output.output)
@@ -197,10 +206,12 @@ def build_structured_input_test_task(tmp_path: Path):
     return task


-async def run_structured_input_test(tmp_path: Path, model_name: str, provider: str):
+async def run_structured_input_test(
+    tmp_path: Path, model_name: str, provider: str, prompt_id: PromptId
+):
     task = build_structured_input_test_task(tmp_path)
     try:
-        await run_structured_input_task(task, model_name, provider)
+        await run_structured_input_task(task, model_name, provider, prompt_id)
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
@@ -209,17 +220,20 @@ async def run_structured_input_test(tmp_path: Path, model_name: str, provider: s
         raise e


-async def run_structured_input_task(
+async def run_structured_input_task_no_validation(
     task: datamodel.Task,
     model_name: str,
     provider: str,
-    prompt_id: PromptId | None = None,
+    prompt_id: PromptId,
 ):
     a = adapter_for_task(
         task,
-        model_name=model_name,
-        provider=provider,
-        prompt_id=prompt_id,
+        run_config_properties=RunConfigProperties(
+            model_name=model_name,
+            model_provider_name=provider,
+            prompt_id=prompt_id,
+            structured_output_mode="unknown",
+        ),
     )
     with pytest.raises(ValueError):
         # not structured input in dictionary
@@ -231,18 +245,29 @@ async def run_structured_input_task(
     try:
         run = await a.invoke({"a": 2, "b": 2, "c": 2})
         response = run.output.output
+        return response, a, run
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
                 f"Skipping {model_name} {provider} because Ollama is not running"
             )
         raise e
+
+
+async def run_structured_input_task(
+    task: datamodel.Task,
+    model_name: str,
+    provider: str,
+    prompt_id: PromptId,
+):
+    response, a, run = await run_structured_input_task_no_validation(
+        task, model_name, provider, prompt_id
+    )
     assert response is not None
     if isinstance(response, str):
         assert "[[equilateral]]" in response
     else:
         assert response["is_equilateral"] is True
-
     expected_pb_name = "simple_prompt_builder"
     if prompt_id is not None:
         expected_pb_name = prompt_id
@@ -269,7 +294,9 @@ async def test_structured_input_gpt_4o_mini(tmp_path):
 async def test_all_built_in_models_structured_input(
     tmp_path, model_name, provider_name
 ):
-    await run_structured_input_test(tmp_path, model_name, provider_name)
+    await run_structured_input_test(
+        tmp_path, model_name, provider_name, "simple_prompt_builder"
+    )


 @pytest.mark.paid
@@ -323,6 +350,11 @@ When asked for a final result, this is the format (for an equilateral example):
 """
     task.output_json_schema = json.dumps(triangle_schema)
     task.save_to_file()
-    await run_structured_input_task(
+    response, adapter, _ = await run_structured_input_task_no_validation(
         task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
     )
+
+    formatted_response = json.loads(response)
+    assert formatted_response["is_equilateral"] is True
+    assert formatted_response["is_scalene"] is False
+    assert formatted_response["is_obtuse"] is False
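
The recurring change in the test hunks above is the adapter_for_task signature: the separate model_name / provider / prompt_id keyword arguments are gone, and callers pass a single RunConfigProperties. A minimal sketch of the new call pattern, with field values taken from the hunks above and import paths assumed rather than verified against the released wheel:

```python
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.datamodel import Task
from kiln_ai.datamodel.task import RunConfigProperties


async def run_once(task: Task) -> str:
    # Build an adapter from one RunConfigProperties object rather than the old
    # model_name/provider/prompt_id keyword arguments.
    adapter = adapter_for_task(
        task,
        run_config_properties=RunConfigProperties(
            model_name="llama_3_1_8b",
            model_provider_name="groq",
            prompt_id="simple_prompt_builder",
            structured_output_mode="json_schema",
        ),
    )
    run = await adapter.invoke("Cows")  # plain-text input, as in the tests above
    return run.output.output
```

Judging by these hunks, structured_output_mode appears to be a required field on RunConfigProperties; the tests pass "unknown" where no specific mode is under test and an explicit value such as "json_schema" otherwise.
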
@@ -1,5 +1,3 @@
-from typing import Type
-
 from kiln_ai.adapters.ml_model_list import ModelParserID
 from kiln_ai.adapters.parsers.base_parser import BaseParser
 from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
@@ -1,5 +1,4 @@
 from kiln_ai.adapters.parsers.base_parser import BaseParser
-from kiln_ai.adapters.parsers.json_parser import parse_json_string
 from kiln_ai.adapters.run_output import RunOutput


@@ -1,6 +1,4 @@
-import json
 from abc import ABCMeta, abstractmethod
-from typing import Dict

 from kiln_ai.datamodel import PromptGenerators, PromptId, Task, TaskRun
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
@@ -53,20 +51,6 @@ class BasePromptBuilder(metaclass=ABCMeta):
         """
         pass

-    def build_user_message(self, input: Dict | str) -> str:
-        """Build a user message from the input.
-
-        Args:
-            input (Union[Dict, str]): The input to format into a message.
-
-        Returns:
-            str: The formatted user message.
-        """
-        if isinstance(input, Dict):
-            return f"The input is:\n{json.dumps(input, indent=2, ensure_ascii=False)}"
-
-        return f"The input is:\n{input}"
-
     def chain_of_thought_prompt(self) -> str | None:
         """Build and return the chain of thought prompt string.

@@ -1,3 +1,4 @@
+import logging
 from dataclasses import dataclass
 from typing import Dict, List

@@ -16,11 +17,15 @@ from kiln_ai.adapters.model_adapters.litellm_config import (
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
-from kiln_ai.datamodel import Finetune, FinetuneDataStrategy, Task
+from kiln_ai.datamodel import Finetune, Task
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 from kiln_ai.datamodel.registry import project_from_id
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

+logger = logging.getLogger(__name__)
+

 async def provider_enabled(provider_name: ModelProviderName) -> bool:
     if provider_name == ModelProviderName.ollama:
@@ -163,6 +168,10 @@ def kiln_model_provider_from(
     # For custom registry, get the provider name and model name from the model id
     if provider_name == ModelProviderName.kiln_custom_registry:
         provider_name, name = parse_custom_model_id(name)
+    else:
+        logger.warning(
+            f"Unexpected model/provider pair. Will treat as custom model but check your model settings. Provider: {provider_name}/{name}"
+        )

     # Custom/untested model. Set untested, and build a ModelProvider at runtime
     if provider_name is None:
@@ -177,12 +186,15 @@ def kiln_model_provider_from(
         supports_data_gen=False,
         untested_model=True,
         model_id=name,
+        # We don't know the structured output mode for custom models, so we default to json_instructions which is the only one that works everywhere.
+        structured_output_mode=StructuredOutputMode.json_instructions,
     )


-def lite_llm_config(
-    model_id: str,
+def lite_llm_config_for_openai_compatible(
+    run_config_properties: RunConfigProperties,
 ) -> LiteLlmConfig:
+    model_id = run_config_properties.model_name
     try:
         openai_provider_name, model_id = model_id.split("::")
     except Exception:
@@ -206,10 +218,16 @@ def lite_llm_config(
             f"OpenAI compatible provider {openai_provider_name} has no base URL"
         )

+    # Update a copy of the run config properties to use the openai compatible provider
+    updated_run_config_properties = run_config_properties.model_copy(deep=True)
+    updated_run_config_properties.model_provider_name = (
+        ModelProviderName.openai_compatible
+    )
+    updated_run_config_properties.model_name = model_id
+
     return LiteLlmConfig(
         # OpenAI compatible, with a custom base URL
-        model_name=model_id,
-        provider_name=ModelProviderName.openai_compatible,
+        run_config_properties=updated_run_config_properties,
         base_url=base_url,
         additional_body_options={
             "api_key": api_key,
@@ -259,9 +277,9 @@ def finetune_from_id(model_id: str) -> Finetune:


 def parser_from_data_strategy(
-    data_strategy: FinetuneDataStrategy,
+    data_strategy: ChatStrategy,
 ) -> ModelParserID | None:
-    if data_strategy == FinetuneDataStrategy.final_and_intermediate_r1_compatible:
+    if data_strategy == ChatStrategy.single_turn_r1_thinking:
         return ModelParserID.r1_thinking
     return None

@@ -279,10 +297,10 @@ def finetune_provider_model(
         reasoning_capable=(
             fine_tune.data_strategy
             in [
-                FinetuneDataStrategy.final_and_intermediate,
-                FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+                ChatStrategy.single_turn_r1_thinking,
             ]
         ),
+        tuned_chat_strategy=fine_tune.data_strategy,
     )

     if provider == ModelProviderName.vertex and fine_tune.fine_tune_model_id:
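
For OpenAI-compatible providers, the renamed lite_llm_config_for_openai_compatible still encodes the provider inside the model id as "<provider>::<model>", and now threads the rest of the run configuration through a copied RunConfigProperties. A hedged sketch of the caller's side, assuming the function is exported from provider_tools (as the hunks above suggest) and that a provider named "my_gateway" has already been saved in the user's settings; neither name is verified against the released wheel:

```python
from kiln_ai.adapters.provider_tools import lite_llm_config_for_openai_compatible
from kiln_ai.datamodel.task import RunConfigProperties

# "my_gateway" is a hypothetical saved OpenAI-compatible provider; its base_url
# and api_key come from the user's saved settings, not from this call.
config = lite_llm_config_for_openai_compatible(
    RunConfigProperties(
        model_name="my_gateway::llama-3.1-8b-instruct",  # "<provider>::<model>" id
        model_provider_name="openai_compatible",
        prompt_id="simple_prompt_builder",
        structured_output_mode="json_instructions",
    )
)

# Per the hunk above, the returned config carries a deep copy of the properties
# rewritten to the bare model id and the openai_compatible provider.
assert config.run_config_properties.model_name == "llama-3.1-8b-instruct"
```
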
@@ -0,0 +1,66 @@
+import argparse
+import json
+import logging
+import os
+import threading
+from pathlib import Path
+from typing import List
+
+import requests
+
+from .ml_model_list import KilnModel, built_in_models
+
+logger = logging.getLogger(__name__)
+
+
+def serialize_config(models: List[KilnModel], path: str | Path) -> None:
+    data = {"model_list": [m.model_dump(mode="json") for m in models]}
+    Path(path).write_text(json.dumps(data, indent=2, sort_keys=True))
+
+
+def deserialize_config(path: str | Path) -> List[KilnModel]:
+    raw = json.loads(Path(path).read_text())
+    model_data = raw.get("model_list", raw if isinstance(raw, list) else [])
+    return [KilnModel.model_validate(item) for item in model_data]
+
+
+def load_from_url(url: str) -> List[KilnModel]:
+    response = requests.get(url, timeout=10)
+    response.raise_for_status()
+    data = response.json()
+    if isinstance(data, list):
+        model_data = data
+    else:
+        model_data = data.get("model_list", [])
+    return [KilnModel.model_validate(item) for item in model_data]
+
+
+def dump_builtin_config(path: str | Path) -> None:
+    serialize_config(built_in_models, path)
+
+
+def load_remote_models(url: str) -> None:
+    if os.environ.get("KILN_SKIP_REMOTE_MODEL_LIST") == "true":
+        return
+
+    def fetch_and_replace() -> None:
+        try:
+            models = load_from_url(url)
+            built_in_models[:] = models
+        except Exception as exc:
+            # Do not crash startup, but surface the issue
+            logger.warning("Failed to fetch remote model list from %s: %s", url, exc)
+
+    thread = threading.Thread(target=fetch_and_replace, daemon=True)
+    thread.start()
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("path", help="output path")
+    args = parser.parse_args()
+    dump_builtin_config(args.path)
+
+
+if __name__ == "__main__":
+    main()
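
The new remote_config module (kiln_ai/adapters/remote_config.py in the file list) can also be driven directly. A small usage sketch, assuming the remote document is JSON with a top-level "model_list" key; the URL is a placeholder, not an official endpoint:

```python
import os

from kiln_ai.adapters.remote_config import dump_builtin_config, load_remote_models

# Write the built-in model list to disk (what main() above does for a path argument).
dump_builtin_config("model_list.json")

# Setting this before the call turns load_remote_models into a no-op,
# e.g. for offline or hermetic test runs.
os.environ["KILN_SKIP_REMOTE_MODEL_LIST"] = "true"

# Otherwise the fetch runs in a daemon thread and replaces built_in_models in
# place; failures are logged as warnings rather than raised.
load_remote_models("https://example.com/kiln/model_list.json")  # placeholder URL
```
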
@@ -1,13 +1,8 @@
 import json
-from typing import Type

 from pydantic import BaseModel, Field

-from kiln_ai.adapters.prompt_builders import (
-    BasePromptBuilder,
-    SavedPromptBuilder,
-    prompt_builder_from_id,
-)
+from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_from_id
 from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun

@@ -21,6 +21,7 @@ from kiln_ai.datamodel import (
     TaskRequirement,
     TaskRun,
 )
+from kiln_ai.datamodel.task import RunConfigProperties

 json_joke_schema = """{
     "type": "object",
@@ -189,7 +190,15 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
     repair_task_input = RepairTaskRun.build_repair_task_input(**sample_repair_data)
     assert isinstance(repair_task_input, RepairTaskInput)

-    adapter = adapter_for_task(repair_task, model_name="llama_3_1_8b", provider="groq")
+    adapter = adapter_for_task(
+        repair_task,
+        RunConfigProperties(
+            model_name="llama_3_1_8b",
+            model_provider_name="groq",
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="default",
+        ),
+    )

     run = await adapter.invoke(repair_task_input.model_dump())
     assert run is not None
@@ -198,10 +207,13 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
     assert "setup" in parsed_output
     assert "punchline" in parsed_output
     assert run.output.source.properties == {
-        "adapter_name": "kiln_langchain_adapter",
+        "adapter_name": "kiln_openai_compatible_adapter",
         "model_name": "llama_3_1_8b",
         "model_provider": "groq",
         "prompt_id": "simple_prompt_builder",
+        "structured_output_mode": "default",
+        "temperature": 1.0,
+        "top_p": 1.0,
     }


@@ -224,7 +236,13 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     )

     adapter = adapter_for_task(
-        repair_task, model_name="llama_3_1_8b", provider="ollama"
+        repair_task,
+        RunConfigProperties(
+            model_name="llama_3_1_8b",
+            model_provider_name="ollama",
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     run = await adapter.invoke(repair_task_input.model_dump())
@@ -240,6 +258,9 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
         "model_name": "llama_3_1_8b",
         "model_provider": "ollama",
         "prompt_id": "simple_prompt_builder",
+        "structured_output_mode": "json_schema",
+        "temperature": 1.0,
+        "top_p": 1.0,
     }
     assert run.input_source.type == DataSourceType.human
     assert "created_by" in run.input_source.properties
@@ -7,8 +7,8 @@ from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
 from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
-from kiln_ai.adapters.prompt_builders import BasePromptBuilder
 from kiln_ai.adapters.provider_tools import kiln_model_provider_from
+from kiln_ai.datamodel.task import RunConfigProperties


 @pytest.fixture
@@ -35,18 +35,28 @@ def mock_finetune_from_id():
     with patch("kiln_ai.adapters.provider_tools.finetune_from_id") as mock:
         mock.return_value.provider = ModelProviderName.openai
         mock.return_value.fine_tune_model_id = "test-model"
+        mock.return_value.data_strategy = "final_only"
         yield mock


 def test_openai_adapter_creation(mock_config, basic_task):
     adapter = adapter_for_task(
-        kiln_task=basic_task, model_name="gpt-4", provider=ModelProviderName.openai
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="gpt-4",
+            model_provider_name=ModelProviderName.openai,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     assert isinstance(adapter, LiteLlmAdapter)
-    assert adapter.config.model_name == "gpt-4"
+    assert adapter.config.run_config_properties.model_name == "gpt-4"
     assert adapter.config.additional_body_options == {"api_key": "test-openai-key"}
-    assert adapter.config.provider_name == ModelProviderName.openai
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.openai
+    )
     assert adapter.config.base_url is None  # OpenAI url is default
     assert adapter.config.default_headers is None

@@ -54,14 +64,21 @@ def test_openai_adapter_creation(mock_config, basic_task):
 def test_openrouter_adapter_creation(mock_config, basic_task):
     adapter = adapter_for_task(
         kiln_task=basic_task,
-        model_name="anthropic/claude-3-opus",
-        provider=ModelProviderName.openrouter,
+        run_config_properties=RunConfigProperties(
+            model_name="anthropic/claude-3-opus",
+            model_provider_name=ModelProviderName.openrouter,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     assert isinstance(adapter, LiteLlmAdapter)
-    assert adapter.config.model_name == "anthropic/claude-3-opus"
+    assert adapter.config.run_config_properties.model_name == "anthropic/claude-3-opus"
     assert adapter.config.additional_body_options == {"api_key": "test-openrouter-key"}
-    assert adapter.config.provider_name == ModelProviderName.openrouter
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.openrouter
+    )
     assert adapter.config.default_headers == {
         "HTTP-Referer": "https://getkiln.ai/openrouter",
         "X-Title": "KilnAI",
@@ -79,7 +96,13 @@ def test_openrouter_adapter_creation(mock_config, basic_task):
 )
 def test_openai_compatible_adapter_creation(mock_config, basic_task, provider):
     adapter = adapter_for_task(
-        kiln_task=basic_task, model_name="test-model", provider=provider
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="test-model",
+            model_provider_name=provider,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     assert isinstance(adapter, LiteLlmAdapter)
@@ -90,9 +113,12 @@ def test_openai_compatible_adapter_creation(mock_config, basic_task, provider):
 def test_custom_prompt_builder(mock_config, basic_task):
     adapter = adapter_for_task(
         kiln_task=basic_task,
-        model_name="gpt-4",
-        provider=ModelProviderName.openai,
-        prompt_id="simple_chain_of_thought_prompt_builder",
+        run_config_properties=RunConfigProperties(
+            model_name="gpt-4",
+            model_provider_name=ModelProviderName.openai,
+            prompt_id="simple_chain_of_thought_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     assert adapter.run_config.prompt_id == "simple_chain_of_thought_prompt_builder"
@@ -103,8 +129,12 @@ def test_tags_passed_through(mock_config, basic_task):
     tags = ["test-tag-1", "test-tag-2"]
     adapter = adapter_for_task(
         kiln_task=basic_task,
-        model_name="gpt-4",
-        provider=ModelProviderName.openai,
+        run_config_properties=RunConfigProperties(
+            model_name="gpt-4",
+            model_provider_name=ModelProviderName.openai,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
         base_adapter_config=AdapterConfig(
             default_tags=tags,
         ),
@@ -114,13 +144,19 @@ def test_tags_passed_through(mock_config, basic_task):


 def test_invalid_provider(mock_config, basic_task):
-    with pytest.raises(ValueError, match="Unhandled enum value"):
+    with pytest.raises(ValueError, match="Input should be"):
         adapter_for_task(
-            kiln_task=basic_task, model_name="test-model", provider="invalid"
+            kiln_task=basic_task,
+            run_config_properties=RunConfigProperties(
+                model_name="test-model",
+                model_provider_name="invalid",
+                prompt_id="simple_prompt_builder",
+                structured_output_mode="json_schema",
+            ),
         )


-@patch("kiln_ai.adapters.adapter_registry.lite_llm_config")
+@patch("kiln_ai.adapters.adapter_registry.lite_llm_config_for_openai_compatible")
 def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_task):
     mock_compatible_config.return_value.model_name = "test-model"
     mock_compatible_config.return_value.additional_body_options = {
@@ -128,44 +164,68 @@ def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_ta
     }
     mock_compatible_config.return_value.base_url = "https://test.com/v1"
     mock_compatible_config.return_value.provider_name = "CustomProvider99"
+    mock_compatible_config.return_value.run_config_properties = RunConfigProperties(
+        model_name="provider::test-model",
+        model_provider_name=ModelProviderName.openai_compatible,
+        prompt_id="simple_prompt_builder",
+        structured_output_mode="json_schema",
+    )

     adapter = adapter_for_task(
         kiln_task=basic_task,
-        model_name="provider::test-model",
-        provider=ModelProviderName.openai_compatible,
+        run_config_properties=RunConfigProperties(
+            model_name="provider::test-model",
+            model_provider_name=ModelProviderName.openai_compatible,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     assert isinstance(adapter, LiteLlmAdapter)
-    mock_compatible_config.assert_called_once_with("provider::test-model")
+    mock_compatible_config.assert_called_once()
     assert adapter.config == mock_compatible_config.return_value


 def test_custom_openai_compatible_provider(mock_config, basic_task):
     adapter = adapter_for_task(
         kiln_task=basic_task,
-        model_name="openai::test-model",
-        provider=ModelProviderName.kiln_custom_registry,
+        run_config_properties=RunConfigProperties(
+            model_name="openai::test-model",
+            model_provider_name=ModelProviderName.kiln_custom_registry,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     assert isinstance(adapter, LiteLlmAdapter)
-    assert adapter.config.model_name == "openai::test-model"
+    assert adapter.config.run_config_properties.model_name == "openai::test-model"
     assert adapter.config.additional_body_options == {"api_key": "test-openai-key"}
     assert adapter.config.base_url is None  # openai is none
-    assert adapter.config.provider_name == ModelProviderName.kiln_custom_registry
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.kiln_custom_registry
+    )


 async def test_fine_tune_provider(mock_config, basic_task, mock_finetune_from_id):
     adapter = adapter_for_task(
         kiln_task=basic_task,
-        model_name="proj::task::tune",
-        provider=ModelProviderName.kiln_fine_tune,
+        run_config_properties=RunConfigProperties(
+            model_name="proj::task::tune",
+            model_provider_name=ModelProviderName.kiln_fine_tune,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     mock_finetune_from_id.assert_called_once_with("proj::task::tune")
     assert isinstance(adapter, LiteLlmAdapter)
-    assert adapter.config.provider_name == ModelProviderName.kiln_fine_tune
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.kiln_fine_tune
+    )
     # Kiln model name here, but the underlying openai model id below
-    assert adapter.config.model_name == "proj::task::tune"
+    assert adapter.config.run_config_properties.model_name == "proj::task::tune"

     provider = kiln_model_provider_from(
         "proj::task::tune", provider_name=ModelProviderName.kiln_fine_tune