kiln-ai 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of kiln-ai might be problematic.

Files changed (54)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +22 -44
  3. kiln_ai/adapters/chat/__init__.py +8 -0
  4. kiln_ai/adapters/chat/chat_formatter.py +234 -0
  5. kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
  6. kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
  7. kiln_ai/adapters/eval/base_eval.py +8 -6
  8. kiln_ai/adapters/eval/eval_runner.py +4 -1
  9. kiln_ai/adapters/eval/g_eval.py +23 -5
  10. kiln_ai/adapters/eval/test_base_eval.py +166 -15
  11. kiln_ai/adapters/eval/test_eval_runner.py +3 -0
  12. kiln_ai/adapters/eval/test_g_eval.py +1 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +138 -272
  15. kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
  16. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +287 -353
  17. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
  18. kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
  19. kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
  20. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
  21. kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
  22. kiln_ai/adapters/ml_model_list.py +80 -43
  23. kiln_ai/adapters/model_adapters/base_adapter.py +73 -26
  24. kiln_ai/adapters/model_adapters/litellm_adapter.py +79 -97
  25. kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
  26. kiln_ai/adapters/model_adapters/test_base_adapter.py +235 -60
  27. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +56 -21
  28. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -0
  29. kiln_ai/adapters/model_adapters/test_structured_output.py +44 -12
  30. kiln_ai/adapters/prompt_builders.py +0 -16
  31. kiln_ai/adapters/provider_tools.py +27 -9
  32. kiln_ai/adapters/repair/test_repair_task.py +24 -3
  33. kiln_ai/adapters/test_adapter_registry.py +88 -28
  34. kiln_ai/adapters/test_ml_model_list.py +158 -0
  35. kiln_ai/adapters/test_prompt_adaptors.py +17 -3
  36. kiln_ai/adapters/test_prompt_builders.py +3 -16
  37. kiln_ai/adapters/test_provider_tools.py +69 -20
  38. kiln_ai/datamodel/__init__.py +0 -2
  39. kiln_ai/datamodel/datamodel_enums.py +38 -13
  40. kiln_ai/datamodel/finetune.py +12 -7
  41. kiln_ai/datamodel/task.py +68 -7
  42. kiln_ai/datamodel/test_basemodel.py +2 -1
  43. kiln_ai/datamodel/test_dataset_split.py +0 -8
  44. kiln_ai/datamodel/test_models.py +33 -10
  45. kiln_ai/datamodel/test_task.py +168 -2
  46. kiln_ai/utils/config.py +3 -2
  47. kiln_ai/utils/dataset_import.py +1 -1
  48. kiln_ai/utils/logging.py +165 -0
  49. kiln_ai/utils/test_config.py +23 -0
  50. kiln_ai/utils/test_dataset_import.py +30 -0
  51. {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
  52. {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/RECORD +54 -49
  53. {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
  54. {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/adapters/model_adapters/test_litellm_adapter.py

@@ -11,6 +11,7 @@ from kiln_ai.adapters.model_adapters.litellm_config import (
     LiteLlmConfig,
 )
 from kiln_ai.datamodel import Project, Task, Usage
+from kiln_ai.datamodel.task import RunConfigProperties


 @pytest.fixture
@@ -41,8 +42,12 @@ def mock_task(tmp_path):
 def config():
     return LiteLlmConfig(
         base_url="https://api.test.com",
-        model_name="test-model",
-        provider_name="openrouter",
+        run_config_properties=RunConfigProperties(
+            model_name="test-model",
+            model_provider_name="openrouter",
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
         default_headers={"X-Test": "test"},
         additional_body_options={"api_key": "test_key"},
     )
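
Net effect of the two hunks above: LiteLlmConfig no longer takes model_name and provider_name directly; the model, provider, prompt, and structured-output settings now live on a nested RunConfigProperties. A minimal sketch of the 0.17.0-style construction, with a placeholder base URL and API key rather than values from the package:

from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.task import RunConfigProperties

# Run-level settings are grouped under run_config_properties as of 0.17.0.
config = LiteLlmConfig(
    base_url="https://api.example.com/v1",  # placeholder endpoint
    run_config_properties=RunConfigProperties(
        model_name="test-model",
        model_provider_name="openrouter",
        prompt_id="simple_prompt_builder",
        structured_output_mode="json_schema",
    ),
    additional_body_options={"api_key": "example-key"},  # placeholder key
)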

@@ -52,7 +57,6 @@ def test_initialization(config, mock_task):
     adapter = LiteLlmAdapter(
         config=config,
         kiln_task=mock_task,
-        prompt_id="simple_prompt_builder",
         base_adapter_config=AdapterConfig(default_tags=["test-tag"]),
     )

@@ -60,8 +64,11 @@ def test_initialization(config, mock_task):
     assert adapter.run_config.task == mock_task
     assert adapter.run_config.prompt_id == "simple_prompt_builder"
     assert adapter.base_adapter_config.default_tags == ["test-tag"]
-    assert adapter.run_config.model_name == config.model_name
-    assert adapter.run_config.model_provider_name == config.provider_name
+    assert adapter.run_config.model_name == config.run_config_properties.model_name
+    assert (
+        adapter.run_config.model_provider_name
+        == config.run_config_properties.model_provider_name
+    )
     assert adapter.config.additional_body_options["api_key"] == "test_key"
     assert adapter._api_base == config.base_url
     assert adapter._headers == config.default_headers
@@ -72,8 +79,11 @@ def test_adapter_info(config, mock_task):

     assert adapter.adapter_name() == "kiln_openai_compatible_adapter"

-    assert adapter.run_config.model_name == config.model_name
-    assert adapter.run_config.model_provider_name == config.provider_name
+    assert adapter.run_config.model_name == config.run_config_properties.model_name
+    assert (
+        adapter.run_config.model_provider_name
+        == config.run_config_properties.model_provider_name
+    )
     assert adapter.run_config.prompt_id == "simple_prompt_builder"


@@ -96,14 +106,12 @@ async def test_response_format_options_unstructured(config, mock_task):
 )
 @pytest.mark.asyncio
 async def test_response_format_options_json_mode(config, mock_task, mode):
+    config.run_config_properties.structured_output_mode = mode
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)

     with (
         patch.object(adapter, "has_structured_output", return_value=True),
-        patch.object(adapter, "model_provider") as mock_provider,
     ):
-        mock_provider.return_value.structured_output_mode = mode
-
         options = await adapter.response_format_options()
         assert options == {"response_format": {"type": "json_object"}}

@@ -117,14 +125,12 @@ async def test_response_format_options_json_mode(config, mock_task, mode):
 )
 @pytest.mark.asyncio
 async def test_response_format_options_function_calling(config, mock_task, mode):
+    config.run_config_properties.structured_output_mode = mode
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)

     with (
         patch.object(adapter, "has_structured_output", return_value=True),
-        patch.object(adapter, "model_provider") as mock_provider,
     ):
-        mock_provider.return_value.structured_output_mode = mode
-
         options = await adapter.response_format_options()
         assert "tools" in options
         # full tool structure validated below
@@ -139,30 +145,26 @@ async def test_response_format_options_function_calling(config, mock_task, mode)
 )
 @pytest.mark.asyncio
 async def test_response_format_options_json_instructions(config, mock_task, mode):
+    config.run_config_properties.structured_output_mode = mode
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)

     with (
         patch.object(adapter, "has_structured_output", return_value=True),
-        patch.object(adapter, "model_provider") as mock_provider,
     ):
-        mock_provider.return_value.structured_output_mode = (
-            StructuredOutputMode.json_instructions
-        )
         options = await adapter.response_format_options()
         assert options == {}


 @pytest.mark.asyncio
 async def test_response_format_options_json_schema(config, mock_task):
+    config.run_config_properties.structured_output_mode = (
+        StructuredOutputMode.json_schema
+    )
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)

     with (
         patch.object(adapter, "has_structured_output", return_value=True),
-        patch.object(adapter, "model_provider") as mock_provider,
     ):
-        mock_provider.return_value.structured_output_mode = (
-            StructuredOutputMode.json_schema
-        )
         options = await adapter.response_format_options()
         assert options == {
             "response_format": {
@@ -350,6 +352,32 @@ def test_litellm_model_id_unknown_provider(config, mock_task):
         adapter.litellm_model_id()


+@pytest.mark.asyncio
+async def test_build_completion_kwargs_custom_temperature_top_p(config, mock_task):
+    """Test build_completion_kwargs with custom temperature and top_p values"""
+    # Create config with custom temperature and top_p
+    config.run_config_properties.temperature = 0.7
+    config.run_config_properties.top_p = 0.9
+
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+    mock_provider = Mock()
+    messages = [{"role": "user", "content": "Hello"}]
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch.object(adapter, "litellm_model_id", return_value="openai/test-model"),
+        patch.object(adapter, "build_extra_body", return_value={}),
+        patch.object(adapter, "response_format_options", return_value={}),
+    ):
+        kwargs = await adapter.build_completion_kwargs(mock_provider, messages, None)
+
+        # Verify custom temperature and top_p are passed through
+        assert kwargs["temperature"] == 0.7
+        assert kwargs["top_p"] == 0.9
+        # Verify drop_params is set correctly
+        assert kwargs["drop_params"] is True
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "top_logprobs,response_format,extra_body",
@@ -391,6 +419,13 @@ async def test_build_completion_kwargs(
     assert kwargs["messages"] == messages
     assert kwargs["api_base"] == config.base_url

+    # Verify temperature and top_p are included with default values
+    assert kwargs["temperature"] == 1.0  # Default from RunConfigProperties
+    assert kwargs["top_p"] == 1.0  # Default from RunConfigProperties
+
+    # Verify drop_params is set correctly
+    assert kwargs["drop_params"] is True
+
     # Verify optional parameters
     if top_logprobs is not None:
         assert kwargs["logprobs"] is True

kiln_ai/adapters/model_adapters/test_saving_adapter_results.py

@@ -46,6 +46,7 @@ def adapter(test_task):
             model_name="phi_3_5",
             model_provider_name="ollama",
             prompt_id="simple_chain_of_thought_prompt_builder",
+            structured_output_mode="json_schema",
         ),
     )

@@ -102,6 +103,9 @@ def test_save_run_isolation(test_task, adapter):
         reloaded_output.source.properties["prompt_id"]
         == "simple_chain_of_thought_prompt_builder"
     )
+    assert reloaded_output.source.properties["structured_output_mode"] == "json_schema"
+    assert reloaded_output.source.properties["temperature"] == 1.0
+    assert reloaded_output.source.properties["top_p"] == 1.0
     # Run again, with same input and different output. Should create a new TaskRun.
     different_run_output = RunOutput(
         output="Different output", intermediate_outputs=None
@@ -228,3 +232,40 @@ async def test_autosave_true(test_task, adapter):
         output.source.properties["prompt_id"]
         == "simple_chain_of_thought_prompt_builder"
     )
+    assert output.source.properties["structured_output_mode"] == "json_schema"
+    assert output.source.properties["temperature"] == 1.0
+    assert output.source.properties["top_p"] == 1.0
+
+
+def test_properties_for_task_output_custom_values(test_task):
+    """Test that _properties_for_task_output includes custom temperature, top_p, and structured_output_mode"""
+    adapter = MockAdapter(
+        run_config=RunConfig(
+            task=test_task,
+            model_name="gpt-4",
+            model_provider_name="openai",
+            prompt_id="simple_prompt_builder",
+            temperature=0.7,
+            top_p=0.9,
+            structured_output_mode="json_schema",
+        ),
+    )
+
+    input_data = "Test input"
+    output_data = "Test output"
+    run_output = RunOutput(output=output_data, intermediate_outputs=None)
+
+    task_run = adapter.generate_run(
+        input=input_data, input_source=None, run_output=run_output
+    )
+    task_run.save_to_file()
+
+    # Verify custom values are preserved in properties
+    output = task_run.output
+    assert output.source.properties["adapter_name"] == "mock_adapter"
+    assert output.source.properties["model_name"] == "gpt-4"
+    assert output.source.properties["model_provider"] == "openai"
+    assert output.source.properties["prompt_id"] == "simple_prompt_builder"
+    assert output.source.properties["structured_output_mode"] == "json_schema"
+    assert output.source.properties["temperature"] == 0.7
+    assert output.source.properties["top_p"] == 0.9
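
The new test also spells out the full property set now recorded on a saved run's output source. For reference, the assertions above imply this shape (all keys and values are taken from the test; "mock_adapter" is the adapter name of the test double used above):

# Properties written to output.source.properties when a TaskRun is saved in 0.17.0,
# per test_properties_for_task_output_custom_values above.
expected_properties = {
    "adapter_name": "mock_adapter",
    "model_name": "gpt-4",
    "model_provider": "openai",
    "prompt_id": "simple_prompt_builder",
    "structured_output_mode": "json_schema",
    "temperature": 0.7,
    "top_p": 0.9,
}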

kiln_ai/adapters/model_adapters/test_structured_output.py

@@ -17,7 +17,7 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId
-from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.datamodel.task import RunConfig, RunConfigProperties
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


@@ -51,6 +51,7 @@ class MockAdapter(BaseAdapter):
                 model_name="phi_3_5",
                 model_provider_name="ollama",
                 prompt_id="simple_chain_of_thought_prompt_builder",
+                structured_output_mode="json_schema",
             ),
         )
         self.response = response
@@ -146,7 +147,15 @@ def build_structured_output_test_task(tmp_path: Path):

 async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
     task = build_structured_output_test_task(tmp_path)
-    a = adapter_for_task(task, model_name=model_name, provider=provider)
+    a = adapter_for_task(
+        task,
+        run_config_properties=RunConfigProperties(
+            model_name=model_name,
+            model_provider_name=provider,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="unknown",
+        ),
+    )
     try:
         run = await a.invoke("Cows")  # a joke about cows
         parsed = json.loads(run.output.output)
@@ -197,10 +206,12 @@ def build_structured_input_test_task(tmp_path: Path):
     return task


-async def run_structured_input_test(tmp_path: Path, model_name: str, provider: str):
+async def run_structured_input_test(
+    tmp_path: Path, model_name: str, provider: str, prompt_id: PromptId
+):
     task = build_structured_input_test_task(tmp_path)
     try:
-        await run_structured_input_task(task, model_name, provider)
+        await run_structured_input_task(task, model_name, provider, prompt_id)
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
@@ -209,17 +220,20 @@ async def run_structured_input_test(tmp_path: Path, model_name: str, provider: s
         raise e


-async def run_structured_input_task(
+async def run_structured_input_task_no_validation(
     task: datamodel.Task,
     model_name: str,
     provider: str,
-    prompt_id: PromptId | None = None,
+    prompt_id: PromptId,
 ):
     a = adapter_for_task(
         task,
-        model_name=model_name,
-        provider=provider,
-        prompt_id=prompt_id,
+        run_config_properties=RunConfigProperties(
+            model_name=model_name,
+            model_provider_name=provider,
+            prompt_id=prompt_id,
+            structured_output_mode="unknown",
+        ),
     )
     with pytest.raises(ValueError):
         # not structured input in dictionary
@@ -231,18 +245,29 @@ async def run_structured_input_task(
     try:
         run = await a.invoke({"a": 2, "b": 2, "c": 2})
         response = run.output.output
+        return response, a
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
                 f"Skipping {model_name} {provider} because Ollama is not running"
             )
         raise e
+
+
+async def run_structured_input_task(
+    task: datamodel.Task,
+    model_name: str,
+    provider: str,
+    prompt_id: PromptId,
+):
+    response, a = await run_structured_input_task_no_validation(
+        task, model_name, provider, prompt_id
+    )
     assert response is not None
     if isinstance(response, str):
         assert "[[equilateral]]" in response
     else:
         assert response["is_equilateral"] is True
-
     expected_pb_name = "simple_prompt_builder"
     if prompt_id is not None:
         expected_pb_name = prompt_id
@@ -269,7 +294,9 @@ async def test_structured_input_gpt_4o_mini(tmp_path):
 async def test_all_built_in_models_structured_input(
     tmp_path, model_name, provider_name
 ):
-    await run_structured_input_test(tmp_path, model_name, provider_name)
+    await run_structured_input_test(
+        tmp_path, model_name, provider_name, "simple_prompt_builder"
+    )


 @pytest.mark.paid
@@ -323,6 +350,11 @@ When asked for a final result, this is the format (for an equilateral example):
 """
     task.output_json_schema = json.dumps(triangle_schema)
     task.save_to_file()
-    await run_structured_input_task(
+    response, adapter = await run_structured_input_task_no_validation(
         task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
     )
+
+    formatted_response = json.loads(response)
+    assert formatted_response["is_equilateral"] is True
+    assert formatted_response["is_scalene"] is False
+    assert formatted_response["is_obtuse"] is False
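
These tests also switch to the new adapter_for_task signature, which takes a RunConfigProperties instead of separate model_name, provider, and prompt_id arguments. A hedged sketch of that call outside the test suite; the import path for adapter_for_task is assumed from the adapter_registry module in the file list, and the model and provider ids are illustrative:

from kiln_ai.adapters.adapter_registry import adapter_for_task  # assumed import path
from kiln_ai.datamodel.task import RunConfigProperties

# 0.17.0: one RunConfigProperties object replaces the old keyword arguments.
adapter = adapter_for_task(
    task,  # an existing kiln_ai.datamodel.Task
    run_config_properties=RunConfigProperties(
        model_name="llama_3_1_8b",
        model_provider_name="groq",
        prompt_id="simple_prompt_builder",
        structured_output_mode="json_schema",
    ),
)
# In an async context: run = await adapter.invoke({"a": 2, "b": 2, "c": 2})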

kiln_ai/adapters/prompt_builders.py

@@ -1,6 +1,4 @@
-import json
 from abc import ABCMeta, abstractmethod
-from typing import Dict

 from kiln_ai.datamodel import PromptGenerators, PromptId, Task, TaskRun
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
@@ -53,20 +51,6 @@ class BasePromptBuilder(metaclass=ABCMeta):
         """
         pass

-    def build_user_message(self, input: Dict | str) -> str:
-        """Build a user message from the input.
-
-        Args:
-            input (Union[Dict, str]): The input to format into a message.
-
-        Returns:
-            str: The formatted user message.
-        """
-        if isinstance(input, Dict):
-            return f"The input is:\n{json.dumps(input, indent=2, ensure_ascii=False)}"
-
-        return f"The input is:\n{input}"
-
     def chain_of_thought_prompt(self) -> str | None:
         """Build and return the chain of thought prompt string.


kiln_ai/adapters/provider_tools.py

@@ -1,3 +1,4 @@
+import logging
 from dataclasses import dataclass
 from typing import Dict, List

@@ -16,11 +17,15 @@ from kiln_ai.adapters.model_adapters.litellm_config import (
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
-from kiln_ai.datamodel import Finetune, FinetuneDataStrategy, Task
+from kiln_ai.datamodel import Finetune, Task
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 from kiln_ai.datamodel.registry import project_from_id
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

+logger = logging.getLogger(__name__)
+

 async def provider_enabled(provider_name: ModelProviderName) -> bool:
     if provider_name == ModelProviderName.ollama:
@@ -163,6 +168,10 @@ def kiln_model_provider_from(
     # For custom registry, get the provider name and model name from the model id
     if provider_name == ModelProviderName.kiln_custom_registry:
         provider_name, name = parse_custom_model_id(name)
+    else:
+        logger.warning(
+            f"Unexpected model/provider pair. Will treat as custom model but check your model settings. Provider: {provider_name}/{name}"
+        )

     # Custom/untested model. Set untested, and build a ModelProvider at runtime
     if provider_name is None:
@@ -177,12 +186,15 @@ def kiln_model_provider_from(
         supports_data_gen=False,
         untested_model=True,
         model_id=name,
+        # We don't know the structured output mode for custom models, so we default to json_instructions which is the only one that works everywhere.
+        structured_output_mode=StructuredOutputMode.json_instructions,
     )


-def lite_llm_config(
-    model_id: str,
+def lite_llm_config_for_openai_compatible(
+    run_config_properties: RunConfigProperties,
 ) -> LiteLlmConfig:
+    model_id = run_config_properties.model_name
     try:
         openai_provider_name, model_id = model_id.split("::")
     except Exception:
@@ -206,10 +218,16 @@ def lite_llm_config(
             f"OpenAI compatible provider {openai_provider_name} has no base URL"
         )

+    # Update a copy of the run config properties to use the openai compatible provider
+    updated_run_config_properties = run_config_properties.model_copy(deep=True)
+    updated_run_config_properties.model_provider_name = (
+        ModelProviderName.openai_compatible
+    )
+    updated_run_config_properties.model_name = model_id
+
     return LiteLlmConfig(
         # OpenAI compatible, with a custom base URL
-        model_name=model_id,
-        provider_name=ModelProviderName.openai_compatible,
+        run_config_properties=updated_run_config_properties,
         base_url=base_url,
         additional_body_options={
             "api_key": api_key,

@@ -259,9 +277,9 @@ def finetune_from_id(model_id: str) -> Finetune:


 def parser_from_data_strategy(
-    data_strategy: FinetuneDataStrategy,
+    data_strategy: ChatStrategy,
 ) -> ModelParserID | None:
-    if data_strategy == FinetuneDataStrategy.final_and_intermediate_r1_compatible:
+    if data_strategy == ChatStrategy.single_turn_r1_thinking:
         return ModelParserID.r1_thinking
     return None

@@ -279,10 +297,10 @@ def finetune_provider_model(
         reasoning_capable=(
             fine_tune.data_strategy
             in [
-                FinetuneDataStrategy.final_and_intermediate,
-                FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+                ChatStrategy.single_turn_r1_thinking,
             ]
         ),
+        tuned_chat_strategy=fine_tune.data_strategy,
     )

     if provider == ModelProviderName.vertex and fine_tune.fine_tune_model_id:
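
These hunks track the datamodel rename from FinetuneDataStrategy to ChatStrategy (see kiln_ai/datamodel/datamodel_enums.py in the file list): the R1-compatible strategy is now ChatStrategy.single_turn_r1_thinking and remains the only strategy mapped to the R1 thinking parser. A small sketch of that mapping, assuming parser_from_data_strategy is imported from kiln_ai.adapters.provider_tools:

from kiln_ai.adapters.provider_tools import parser_from_data_strategy  # assumed path
from kiln_ai.datamodel.datamodel_enums import ChatStrategy

# Per the hunk above, only the R1-style single-turn thinking strategy needs a parser
# (ModelParserID.r1_thinking); other strategies return None.
assert parser_from_data_strategy(ChatStrategy.single_turn_r1_thinking) is not None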

kiln_ai/adapters/repair/test_repair_task.py

@@ -21,6 +21,7 @@ from kiln_ai.datamodel import (
     TaskRequirement,
     TaskRun,
 )
+from kiln_ai.datamodel.task import RunConfigProperties

 json_joke_schema = """{
   "type": "object",
@@ -189,7 +190,15 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
     repair_task_input = RepairTaskRun.build_repair_task_input(**sample_repair_data)
     assert isinstance(repair_task_input, RepairTaskInput)

-    adapter = adapter_for_task(repair_task, model_name="llama_3_1_8b", provider="groq")
+    adapter = adapter_for_task(
+        repair_task,
+        RunConfigProperties(
+            model_name="llama_3_1_8b",
+            model_provider_name="groq",
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="default",
+        ),
+    )

     run = await adapter.invoke(repair_task_input.model_dump())
     assert run is not None
@@ -198,10 +207,13 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
     assert "setup" in parsed_output
     assert "punchline" in parsed_output
     assert run.output.source.properties == {
-        "adapter_name": "kiln_langchain_adapter",
+        "adapter_name": "kiln_openai_compatible_adapter",
         "model_name": "llama_3_1_8b",
         "model_provider": "groq",
         "prompt_id": "simple_prompt_builder",
+        "structured_output_mode": "default",
+        "temperature": 1.0,
+        "top_p": 1.0,
     }


@@ -224,7 +236,13 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     )

     adapter = adapter_for_task(
-        repair_task, model_name="llama_3_1_8b", provider="ollama"
+        repair_task,
+        RunConfigProperties(
+            model_name="llama_3_1_8b",
+            model_provider_name="ollama",
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     run = await adapter.invoke(repair_task_input.model_dump())
@@ -240,6 +258,9 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
         "model_name": "llama_3_1_8b",
         "model_provider": "ollama",
         "prompt_id": "simple_prompt_builder",
+        "structured_output_mode": "json_schema",
+        "temperature": 1.0,
+        "top_p": 1.0,
     }
     assert run.input_source.type == DataSourceType.human
     assert "created_by" in run.input_source.properties