kiln-ai 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (72)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +22 -44
  3. kiln_ai/adapters/chat/__init__.py +8 -0
  4. kiln_ai/adapters/chat/chat_formatter.py +234 -0
  5. kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
  6. kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
  7. kiln_ai/adapters/eval/base_eval.py +8 -6
  8. kiln_ai/adapters/eval/eval_runner.py +9 -65
  9. kiln_ai/adapters/eval/g_eval.py +26 -8
  10. kiln_ai/adapters/eval/test_base_eval.py +166 -15
  11. kiln_ai/adapters/eval/test_eval_runner.py +3 -0
  12. kiln_ai/adapters/eval/test_g_eval.py +1 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +153 -197
  15. kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
  16. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +402 -211
  17. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
  18. kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
  19. kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
  20. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
  21. kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
  22. kiln_ai/adapters/ml_model_list.py +556 -45
  23. kiln_ai/adapters/model_adapters/base_adapter.py +100 -35
  24. kiln_ai/adapters/model_adapters/litellm_adapter.py +116 -100
  25. kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
  26. kiln_ai/adapters/model_adapters/test_base_adapter.py +299 -52
  27. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +121 -22
  28. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +44 -2
  29. kiln_ai/adapters/model_adapters/test_structured_output.py +48 -18
  30. kiln_ai/adapters/parsers/base_parser.py +0 -3
  31. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  32. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  33. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  34. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  35. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  36. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  37. kiln_ai/adapters/prompt_builders.py +14 -17
  38. kiln_ai/adapters/provider_tools.py +39 -4
  39. kiln_ai/adapters/repair/test_repair_task.py +27 -5
  40. kiln_ai/adapters/test_adapter_registry.py +88 -28
  41. kiln_ai/adapters/test_ml_model_list.py +158 -0
  42. kiln_ai/adapters/test_prompt_adaptors.py +17 -3
  43. kiln_ai/adapters/test_prompt_builders.py +27 -19
  44. kiln_ai/adapters/test_provider_tools.py +130 -12
  45. kiln_ai/datamodel/__init__.py +2 -2
  46. kiln_ai/datamodel/datamodel_enums.py +43 -4
  47. kiln_ai/datamodel/dataset_filters.py +69 -1
  48. kiln_ai/datamodel/dataset_split.py +4 -0
  49. kiln_ai/datamodel/eval.py +8 -0
  50. kiln_ai/datamodel/finetune.py +13 -7
  51. kiln_ai/datamodel/prompt_id.py +1 -0
  52. kiln_ai/datamodel/task.py +68 -7
  53. kiln_ai/datamodel/task_output.py +1 -1
  54. kiln_ai/datamodel/task_run.py +39 -7
  55. kiln_ai/datamodel/test_basemodel.py +5 -8
  56. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  57. kiln_ai/datamodel/test_dataset_split.py +2 -8
  58. kiln_ai/datamodel/test_example_models.py +54 -0
  59. kiln_ai/datamodel/test_models.py +80 -9
  60. kiln_ai/datamodel/test_task.py +168 -2
  61. kiln_ai/utils/async_job_runner.py +106 -0
  62. kiln_ai/utils/config.py +3 -2
  63. kiln_ai/utils/dataset_import.py +81 -19
  64. kiln_ai/utils/logging.py +165 -0
  65. kiln_ai/utils/test_async_job_runner.py +199 -0
  66. kiln_ai/utils/test_config.py +23 -0
  67. kiln_ai/utils/test_dataset_import.py +272 -10
  68. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
  69. kiln_ai-0.17.0.dist-info/RECORD +113 -0
  70. kiln_ai-0.15.0.dist-info/RECORD +0 -104
  71. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
  72. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/test_base_adapter.py

@@ -3,16 +3,18 @@ from unittest.mock import MagicMock, patch
 import pytest
 
 from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
-from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
+from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id
 from kiln_ai.datamodel import Task
-from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
+from kiln_ai.datamodel.task import RunConfig, RunConfigProperties
 
 
 class MockAdapter(BaseAdapter):
     """Concrete implementation of BaseAdapter for testing"""
 
     async def _run(self, input):
-        return None
+        return None, None
 
     def adapter_name(self) -> str:
         return "test"
@@ -36,12 +38,29 @@ def adapter(base_task)
         run_config=RunConfig(
             task=base_task,
             model_name="test_model",
-            model_provider_name="test_provider",
+            model_provider_name="openai",
             prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
         ),
     )
 
 
+@pytest.fixture
+def mock_formatter():
+    formatter = MagicMock()
+    formatter.format_input.return_value = {"formatted": "input"}
+    return formatter
+
+
+@pytest.fixture
+def mock_parser():
+    parser = MagicMock()
+    parser.parse_output.return_value = RunOutput(
+        output="test output", intermediate_outputs={}
+    )
+    return parser
+
+
 async def test_model_provider_uses_cache(adapter, mock_provider):
     """Test that cached provider is returned if it exists"""
     # Set up cached provider
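
The fixture change above reflects that `RunConfig` now carries `structured_output_mode` and expects a recognized provider name. A hedged construction sketch follows; the minimal `Task(name=..., instruction=...)` arguments are an assumption, and the values mirror the fixture rather than a real deployment.

```python
# Sketch of a 0.17-style RunConfig; field names are taken from the fixture above.
from kiln_ai.datamodel import Task
from kiln_ai.datamodel.task import RunConfig

task = Task(name="demo task", instruction="Say hello politely")  # assumed minimal Task fields
run_config = RunConfig(
    task=task,
    model_name="test_model",            # placeholder model name
    model_provider_name="openai",       # must be a known provider id
    prompt_id="simple_prompt_builder",
    structured_output_mode="json_schema",
)
```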
@@ -71,7 +90,7 @@ async def test_model_provider_loads_and_caches(adapter, mock_provider)
     # First call should load and cache
     provider1 = adapter.model_provider()
     assert provider1 == mock_provider
-    mock_loader.assert_called_once_with("test_model", "test_provider")
+    mock_loader.assert_called_once_with("test_model", "openai")
 
     # Second call should use cache
     mock_loader.reset_mock()
@@ -80,29 +99,30 @@ async def test_model_provider_loads_and_caches(adapter, mock_provider)
     mock_loader.assert_not_called()
 
 
-async def test_model_provider_missing_names(base_task):
+async def test_model_provider_invalid_provider_model_name(base_task):
+    """Test error when model or provider name is missing"""
+    # Test with missing model name
+    with pytest.raises(ValueError, match="Input should be"):
+        adapter = MockAdapter(
+            run_config=RunConfig(
+                task=base_task,
+                model_name="test_model",
+                model_provider_name="invalid",
+                prompt_id="simple_prompt_builder",
+            ),
+        )
+
+
+async def test_model_provider_missing_model_names(base_task):
     """Test error when model or provider name is missing"""
     # Test with missing model name
     adapter = MockAdapter(
         run_config=RunConfig(
             task=base_task,
             model_name="",
-            model_provider_name="",
-            prompt_id="simple_prompt_builder",
-        ),
-    )
-    with pytest.raises(
-        ValueError, match="model_name and model_provider_name must be provided"
-    ):
-        await adapter.model_provider()
-
-    # Test with missing provider name
-    adapter = MockAdapter(
-        run_config=RunConfig(
-            task=base_task,
-            model_name="test_model",
-            model_provider_name="",
+            model_provider_name="openai",
             prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
         ),
     )
     with pytest.raises(
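
The renamed test above relies on provider validation happening when the run config is built (pydantic's "Input should be …" enum error), rather than later inside `model_provider()`. A hedged reproduction, reusing the assumed minimal `Task` from the earlier sketch:

```python
# Sketch: an unrecognised provider id is rejected at RunConfig construction time.
# "Input should be" is pydantic's enum validation message; ValidationError subclasses ValueError.
import pytest

from kiln_ai.datamodel import Task
from kiln_ai.datamodel.task import RunConfig

task = Task(name="demo task", instruction="Say hello politely")  # assumed minimal Task fields
with pytest.raises(ValueError, match="Input should be"):
    RunConfig(
        task=task,
        model_name="test_model",
        model_provider_name="not_a_real_provider",  # not in the provider enum
        prompt_id="simple_prompt_builder",
        structured_output_mode="json_schema",
    )
```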
@@ -121,7 +141,7 @@ async def test_model_provider_not_found(adapter)
 
     with pytest.raises(
         ValueError,
-        match="model_provider_name test_provider not found for model test_model",
+        match="not found for model test_model",
     ):
         await adapter.model_provider()
 
@@ -151,11 +171,7 @@ async def test_prompt_builder_json_instructions(
     adapter.prompt_builder = mock_prompt_builder
     adapter.model_provider_name = "openai"
     adapter.has_structured_output = MagicMock(return_value=output_schema)
-
-    # provider mock
-    provider = MagicMock()
-    provider.structured_output_mode = structured_output_mode
-    adapter.model_provider = MagicMock(return_value=provider)
+    adapter.run_config.structured_output_mode = structured_output_mode
 
     # Test
     adapter.build_prompt()
@@ -164,36 +180,267 @@
     )
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "cot_prompt,has_structured_output,reasoning_capable,expected",
+    "formatter_id,expected_input,expected_calls",
     [
-        # COT and normal LLM
-        ("think carefully", False, False, ("cot_two_call", "think carefully")),
-        # Structured output with thinking-capable LLM
-        ("think carefully", True, True, ("cot_as_message", "think carefully")),
-        # Structured output with normal LLM
-        ("think carefully", True, False, ("cot_two_call", "think carefully")),
-        # Basic cases - no COT
-        (None, True, True, ("basic", None)),
-        (None, False, False, ("basic", None)),
-        (None, True, False, ("basic", None)),
-        (None, False, True, ("basic", None)),
-        # Edge case - COT prompt exists but structured output is False and reasoning_capable is True
-        ("think carefully", False, True, ("cot_as_message", "think carefully")),
+        (None, {"original": "input"}, 0),  # No formatter
+        ("test_formatter", {"formatted": "input"}, 1),  # With formatter
     ],
 )
-async def test_run_strategy(
-    adapter, cot_prompt, has_structured_output, reasoning_capable, expected
+async def test_input_formatting(
+    adapter, mock_formatter, mock_parser, formatter_id, expected_input, expected_calls
 ):
-    """Test that run_strategy returns correct strategy based on conditions"""
-    # Mock dependencies
-    adapter.prompt_builder.chain_of_thought_prompt = MagicMock(return_value=cot_prompt)
-    adapter.has_structured_output = MagicMock(return_value=has_structured_output)
-
+    """Test that input formatting is handled correctly based on formatter configuration"""
+    # Mock the model provider to return our formatter ID and parser
     provider = MagicMock()
-    provider.reasoning_capable = reasoning_capable
+    provider.formatter = formatter_id
+    provider.parser = "test_parser"
+    provider.reasoning_capable = False
     adapter.model_provider = MagicMock(return_value=provider)
 
-    # Test
-    result = adapter.run_strategy()
-    assert result == expected
+    # Mock the formatter factory and parser factory
+    with (
+        patch(
+            "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id"
+        ) as mock_factory,
+        patch(
+            "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id"
+        ) as mock_parser_factory,
+    ):
+        mock_factory.return_value = mock_formatter
+        mock_parser_factory.return_value = mock_parser
+
+        # Mock the _run method to capture the input
+        captured_input = None
+
+        async def mock_run(input):
+            nonlocal captured_input
+            captured_input = input
+            return RunOutput(output="test output", intermediate_outputs={}), None
+
+        adapter._run = mock_run
+
+        # Run the adapter
+        original_input = {"original": "input"}
+        await adapter.invoke_returning_run_output(original_input)
+
+        # Verify formatter was called correctly
+        assert captured_input == expected_input
+        assert mock_factory.call_count == (1 if formatter_id else 0)
+        assert mock_formatter.format_input.call_count == expected_calls
+
+        # Verify original input was preserved in the run
+        if formatter_id:
+            mock_formatter.format_input.assert_called_once_with(original_input)
+
+
+async def test_properties_for_task_output_includes_all_run_config_properties(adapter):
+    """Test that all properties from RunConfigProperties are saved in task output properties"""
+    # Get all field names from RunConfigProperties
+    run_config_properties_fields = set(RunConfigProperties.model_fields.keys())
+
+    # Get the properties saved by the adapter
+    saved_properties = adapter._properties_for_task_output()
+    saved_property_keys = set(saved_properties.keys())
+
+    # Check which RunConfigProperties fields are missing from saved properties
+    # Note: model_provider_name becomes model_provider in saved properties
+    expected_mappings = {
+        "model_name": "model_name",
+        "model_provider_name": "model_provider",
+        "prompt_id": "prompt_id",
+        "temperature": "temperature",
+        "top_p": "top_p",
+        "structured_output_mode": "structured_output_mode",
+    }
+
+    missing_properties = []
+    for field_name in run_config_properties_fields:
+        expected_key = expected_mappings.get(field_name, field_name)
+        if expected_key not in saved_property_keys:
+            missing_properties.append(
+                f"RunConfigProperties.{field_name} -> {expected_key}"
+            )
+
+    assert not missing_properties, (
+        f"The following RunConfigProperties fields are not saved by _properties_for_task_output: {missing_properties}. Please update the method to include them."
+    )
+
+
+async def test_properties_for_task_output_catches_missing_new_property(adapter):
+    """Test that demonstrates our test will catch when new properties are added to RunConfigProperties but not to _properties_for_task_output"""
+    # Simulate what happens if a new property was added to RunConfigProperties
+    # We'll mock the model_fields to include a fake new property
+    original_fields = RunConfigProperties.model_fields.copy()
+
+    # Create a mock field to simulate a new property being added
+    from pydantic.fields import FieldInfo
+
+    mock_field = FieldInfo(annotation=str, default="default_value")
+
+    try:
+        # Add a fake new field to simulate someone adding a property
+        RunConfigProperties.model_fields["new_fake_property"] = mock_field
+
+        # Get all field names from RunConfigProperties (now includes our fake property)
+        run_config_properties_fields = set(RunConfigProperties.model_fields.keys())
+
+        # Get the properties saved by the adapter (won't include our fake property)
+        saved_properties = adapter._properties_for_task_output()
+        saved_property_keys = set(saved_properties.keys())
+
+        # The mappings don't include our fake property
+        expected_mappings = {
+            "model_name": "model_name",
+            "model_provider_name": "model_provider",
+            "prompt_id": "prompt_id",
+            "temperature": "temperature",
+            "top_p": "top_p",
+            "structured_output_mode": "structured_output_mode",
+        }
+
+        missing_properties = []
+        for field_name in run_config_properties_fields:
+            expected_key = expected_mappings.get(field_name, field_name)
+            if expected_key not in saved_property_keys:
+                missing_properties.append(
+                    f"RunConfigProperties.{field_name} -> {expected_key}"
+                )
+
+        # This should find our missing fake property
+        assert missing_properties == [
+            "RunConfigProperties.new_fake_property -> new_fake_property"
+        ], f"Expected to find missing fake property, but got: {missing_properties}"
+
+    finally:
+        # Restore the original fields
+        RunConfigProperties.model_fields.clear()
+        RunConfigProperties.model_fields.update(original_fields)
+
+
+@pytest.mark.parametrize(
+    "cot_prompt,tuned_strategy,reasoning_capable,expected_formatter_class",
+    [
+        # No COT prompt -> always single turn
+        (None, None, False, "SingleTurnFormatter"),
+        (None, ChatStrategy.two_message_cot, False, "SingleTurnFormatter"),
+        (None, ChatStrategy.single_turn_r1_thinking, True, "SingleTurnFormatter"),
+        # With COT prompt:
+        # - Tuned strategy takes precedence (except single turn)
+        (
+            "think step by step",
+            ChatStrategy.two_message_cot,
+            False,
+            "TwoMessageCotFormatter",
+        ),
+        (
+            "think step by step",
+            ChatStrategy.single_turn_r1_thinking,
+            False,
+            "SingleTurnR1ThinkingFormatter",
+        ),
+        # - Tuned single turn is ignored when COT exists
+        (
+            "think step by step",
+            ChatStrategy.single_turn,
+            True,
+            "SingleTurnR1ThinkingFormatter",
+        ),
+        # - Reasoning capable -> single turn R1 thinking
+        ("think step by step", None, True, "SingleTurnR1ThinkingFormatter"),
+        # - Not reasoning capable -> two message COT
+        ("think step by step", None, False, "TwoMessageCotFormatter"),
+    ],
+)
+def test_build_chat_formatter(
+    adapter,
+    cot_prompt,
+    tuned_strategy,
+    reasoning_capable,
+    expected_formatter_class,
+):
+    """Test chat formatter strategy selection based on COT prompt, tuned strategy, and model capabilities"""
+    # Mock the prompt builder
+    mock_prompt_builder = MagicMock()
+    mock_prompt_builder.chain_of_thought_prompt.return_value = cot_prompt
+    mock_prompt_builder.build_prompt.return_value = "system message"
+    adapter.prompt_builder = mock_prompt_builder
+
+    # Mock the model provider
+    mock_provider = MagicMock()
+    mock_provider.tuned_chat_strategy = tuned_strategy
+    mock_provider.reasoning_capable = reasoning_capable
+    adapter.model_provider = MagicMock(return_value=mock_provider)
+
+    # Get the formatter
+    formatter = adapter.build_chat_formatter("test input")
+
+    # Verify the formatter type
+    assert formatter.__class__.__name__ == expected_formatter_class
+
+    # Verify the formatter was created with correct parameters
+    assert formatter.system_message == "system message"
+    assert formatter.user_input == "test input"
+    # Only check thinking_instructions for formatters that use it
+    if expected_formatter_class == "TwoMessageCotFormatter":
+        if cot_prompt:
+            assert formatter.thinking_instructions == cot_prompt
+        else:
+            assert formatter.thinking_instructions is None
+    # For other formatters, don't assert thinking_instructions
+
+    # Verify prompt builder was called correctly
+    mock_prompt_builder.build_prompt.assert_called_once()
+    mock_prompt_builder.chain_of_thought_prompt.assert_called_once()
+
+
+@pytest.mark.parametrize(
+    "initial_mode,expected_mode",
+    [
+        (
+            StructuredOutputMode.json_schema,
+            StructuredOutputMode.json_schema,
+        ),  # Should not change
+        (
+            StructuredOutputMode.unknown,
+            StructuredOutputMode.json_mode,
+        ),  # Should update to default
+    ],
+)
+async def test_update_run_config_unknown_structured_output_mode(
+    base_task, initial_mode, expected_mode
+):
+    """Test that unknown structured output mode is updated to the default for the model provider"""
+    # Create a run config with the initial mode
+    run_config = RunConfig(
+        task=base_task,
+        model_name="test_model",
+        model_provider_name="openai",
+        prompt_id="simple_prompt_builder",
+        structured_output_mode=initial_mode,
+        temperature=0.7,  # Add some other properties to verify they're preserved
+        top_p=0.9,
+    )
+
+    # Mock the default mode lookup
+    with patch(
+        "kiln_ai.adapters.model_adapters.base_adapter.default_structured_output_mode_for_model_provider"
+    ) as mock_default:
+        mock_default.return_value = StructuredOutputMode.json_mode
+
+        # Create the adapter
+        adapter = MockAdapter(run_config=run_config)
+
+        # Verify the mode was updated correctly
+        assert adapter.run_config.structured_output_mode == expected_mode
+
+        # Verify other properties were preserved
+        assert adapter.run_config.temperature == 0.7
+        assert adapter.run_config.top_p == 0.9
+
+        # Verify the default mode lookup was only called when needed
+        if initial_mode == StructuredOutputMode.unknown:
+            mock_default.assert_called_once_with("test_model", "openai")
+        else:
+            mock_default.assert_not_called()
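
The `test_build_chat_formatter` cases above pin down an ordering for how the chat strategy is chosen. The sketch below restates that ordering as standalone logic; it is not the kiln_ai implementation, just the decision table the parametrized cases imply (the `pick_chat_strategy` name is invented for illustration).

```python
# Decision logic implied by the parametrized cases above (illustrative, not the real code).
from kiln_ai.datamodel.datamodel_enums import ChatStrategy


def pick_chat_strategy(
    cot_prompt: str | None,
    tuned_strategy: ChatStrategy | None,
    reasoning_capable: bool,
) -> ChatStrategy:
    if cot_prompt is None:
        # No chain-of-thought prompt: always a plain single-turn exchange.
        return ChatStrategy.single_turn
    if tuned_strategy is not None and tuned_strategy != ChatStrategy.single_turn:
        # A tuned strategy wins, unless it is plain single_turn while a COT prompt exists.
        return tuned_strategy
    if reasoning_capable:
        # Reasoning-capable models keep the thinking inline (R1-style single turn).
        return ChatStrategy.single_turn_r1_thinking
    # Everyone else falls back to the two-message COT flow.
    return ChatStrategy.two_message_cot
```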
kiln_ai/adapters/model_adapters/test_litellm_adapter.py

@@ -1,6 +1,7 @@
 import json
 from unittest.mock import Mock, patch
 
+import litellm
 import pytest
 
 from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
@@ -9,7 +10,8 @@ from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
 from kiln_ai.adapters.model_adapters.litellm_config import (
     LiteLlmConfig,
 )
-from kiln_ai.datamodel import Project, Task
+from kiln_ai.datamodel import Project, Task, Usage
+from kiln_ai.datamodel.task import RunConfigProperties
 
 
 @pytest.fixture
@@ -40,8 +42,12 @@ def mock_task(tmp_path)
 def config():
     return LiteLlmConfig(
         base_url="https://api.test.com",
-        model_name="test-model",
-        provider_name="openrouter",
+        run_config_properties=RunConfigProperties(
+            model_name="test-model",
+            model_provider_name="openrouter",
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
         default_headers={"X-Test": "test"},
         additional_body_options={"api_key": "test_key"},
     )
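
As the fixture shows, `LiteLlmConfig` now nests the model, provider, prompt, and structured output settings inside `RunConfigProperties` instead of taking `model_name`/`provider_name` directly. A hedged construction sketch (the base URL and API key are placeholders, not values from this release):

```python
# New-style LiteLlmConfig, mirroring the updated fixture; values are placeholders.
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.task import RunConfigProperties

config = LiteLlmConfig(
    base_url="https://api.example.com/v1",  # any OpenAI-compatible endpoint
    run_config_properties=RunConfigProperties(
        model_name="test-model",
        model_provider_name="openrouter",
        prompt_id="simple_prompt_builder",
        structured_output_mode="json_schema",
    ),
    default_headers={"X-Test": "test"},
    additional_body_options={"api_key": "YOUR_API_KEY"},
)
```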
@@ -51,7 +57,6 @@ def test_initialization(config, mock_task)
     adapter = LiteLlmAdapter(
         config=config,
         kiln_task=mock_task,
-        prompt_id="simple_prompt_builder",
         base_adapter_config=AdapterConfig(default_tags=["test-tag"]),
     )
 
@@ -59,8 +64,11 @@
     assert adapter.run_config.task == mock_task
     assert adapter.run_config.prompt_id == "simple_prompt_builder"
    assert adapter.base_adapter_config.default_tags == ["test-tag"]
-    assert adapter.run_config.model_name == config.model_name
-    assert adapter.run_config.model_provider_name == config.provider_name
+    assert adapter.run_config.model_name == config.run_config_properties.model_name
+    assert (
+        adapter.run_config.model_provider_name
+        == config.run_config_properties.model_provider_name
+    )
     assert adapter.config.additional_body_options["api_key"] == "test_key"
     assert adapter._api_base == config.base_url
     assert adapter._headers == config.default_headers
@@ -71,8 +79,11 @@ def test_adapter_info(config, mock_task)
 
     assert adapter.adapter_name() == "kiln_openai_compatible_adapter"
 
-    assert adapter.run_config.model_name == config.model_name
-    assert adapter.run_config.model_provider_name == config.provider_name
+    assert adapter.run_config.model_name == config.run_config_properties.model_name
+    assert (
+        adapter.run_config.model_provider_name
+        == config.run_config_properties.model_provider_name
+    )
     assert adapter.run_config.prompt_id == "simple_prompt_builder"
 
 
@@ -95,14 +106,12 @@ async def test_response_format_options_unstructured(config, mock_task)
 )
 @pytest.mark.asyncio
 async def test_response_format_options_json_mode(config, mock_task, mode):
+    config.run_config_properties.structured_output_mode = mode
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
 
     with (
         patch.object(adapter, "has_structured_output", return_value=True),
-        patch.object(adapter, "model_provider") as mock_provider,
     ):
-        mock_provider.return_value.structured_output_mode = mode
-
         options = await adapter.response_format_options()
         assert options == {"response_format": {"type": "json_object"}}
 
@@ -116,14 +125,12 @@ async def test_response_format_options_json_mode(config, mock_task, mode)
 )
 @pytest.mark.asyncio
 async def test_response_format_options_function_calling(config, mock_task, mode):
+    config.run_config_properties.structured_output_mode = mode
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
 
     with (
         patch.object(adapter, "has_structured_output", return_value=True),
-        patch.object(adapter, "model_provider") as mock_provider,
     ):
-        mock_provider.return_value.structured_output_mode = mode
-
         options = await adapter.response_format_options()
         assert "tools" in options
         # full tool structure validated below
@@ -138,30 +145,26 @@
 )
 @pytest.mark.asyncio
 async def test_response_format_options_json_instructions(config, mock_task, mode):
+    config.run_config_properties.structured_output_mode = mode
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
 
     with (
         patch.object(adapter, "has_structured_output", return_value=True),
-        patch.object(adapter, "model_provider") as mock_provider,
     ):
-        mock_provider.return_value.structured_output_mode = (
-            StructuredOutputMode.json_instructions
-        )
         options = await adapter.response_format_options()
         assert options == {}
 
 
 @pytest.mark.asyncio
 async def test_response_format_options_json_schema(config, mock_task):
+    config.run_config_properties.structured_output_mode = (
+        StructuredOutputMode.json_schema
+    )
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
 
     with (
         patch.object(adapter, "has_structured_output", return_value=True),
-        patch.object(adapter, "model_provider") as mock_provider,
     ):
-        mock_provider.return_value.structured_output_mode = (
-            StructuredOutputMode.json_schema
-        )
         options = await adapter.response_format_options()
         assert options == {
             "response_format": {
@@ -349,6 +352,32 @@ def test_litellm_model_id_unknown_provider(config, mock_task)
         adapter.litellm_model_id()
 
 
+@pytest.mark.asyncio
+async def test_build_completion_kwargs_custom_temperature_top_p(config, mock_task):
+    """Test build_completion_kwargs with custom temperature and top_p values"""
+    # Create config with custom temperature and top_p
+    config.run_config_properties.temperature = 0.7
+    config.run_config_properties.top_p = 0.9
+
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+    mock_provider = Mock()
+    messages = [{"role": "user", "content": "Hello"}]
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch.object(adapter, "litellm_model_id", return_value="openai/test-model"),
+        patch.object(adapter, "build_extra_body", return_value={}),
+        patch.object(adapter, "response_format_options", return_value={}),
+    ):
+        kwargs = await adapter.build_completion_kwargs(mock_provider, messages, None)
+
+    # Verify custom temperature and top_p are passed through
+    assert kwargs["temperature"] == 0.7
+    assert kwargs["top_p"] == 0.9
+    # Verify drop_params is set correctly
+    assert kwargs["drop_params"] is True
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "top_logprobs,response_format,extra_body",
@@ -390,6 +419,13 @@ async def test_build_completion_kwargs(
     assert kwargs["messages"] == messages
     assert kwargs["api_base"] == config.base_url
 
+    # Verify temperature and top_p are included with default values
+    assert kwargs["temperature"] == 1.0  # Default from RunConfigProperties
+    assert kwargs["top_p"] == 1.0  # Default from RunConfigProperties
+
+    # Verify drop_params is set correctly
+    assert kwargs["drop_params"] is True
+
     # Verify optional parameters
     if top_logprobs is not None:
         assert kwargs["logprobs"] is True
@@ -405,3 +441,66 @@
     # Verify extra body is included
     for key, value in extra_body.items():
         assert kwargs[key] == value
+
+
+@pytest.mark.parametrize(
+    "litellm_usage,cost,expected_usage",
+    [
+        # No usage data
+        (None, None, None),
+        # Only cost
+        (None, 0.5, Usage(cost=0.5)),
+        # Only token counts
+        (
+            litellm.types.utils.Usage(
+                prompt_tokens=10,
+                completion_tokens=20,
+                total_tokens=30,
+            ),
+            None,
+            Usage(input_tokens=10, output_tokens=20, total_tokens=30),
+        ),
+        # Both cost and token counts
+        (
+            litellm.types.utils.Usage(
+                prompt_tokens=10,
+                completion_tokens=20,
+                total_tokens=30,
+            ),
+            0.5,
+            Usage(input_tokens=10, output_tokens=20, total_tokens=30, cost=0.5),
+        ),
+        # Invalid usage type (should be ignored)
+        ({"prompt_tokens": 10}, None, None),
+        # Invalid cost type (should be ignored)
+        (None, "0.5", None),
+    ],
+)
+def test_usage_from_response(config, mock_task, litellm_usage, cost, expected_usage):
+    """Test usage_from_response with various combinations of usage data and cost"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Create a mock response
+    response = Mock(spec=litellm.types.utils.ModelResponse)
+    response.get.return_value = litellm_usage
+    response._hidden_params = {"response_cost": cost}
+
+    # Call the method
+    result = adapter.usage_from_response(response)
+
+    # Verify the result
+    if expected_usage is None:
+        if result is not None:
+            assert result.input_tokens is None
+            assert result.output_tokens is None
+            assert result.total_tokens is None
+            assert result.cost is None
+    else:
+        assert result is not None
+        assert result.input_tokens == expected_usage.input_tokens
+        assert result.output_tokens == expected_usage.output_tokens
+        assert result.total_tokens == expected_usage.total_tokens
+        assert result.cost == expected_usage.cost
+
+    # Verify the response was queried correctly
+    response.get.assert_called_once_with("usage", None)
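
The new `test_usage_from_response` exercises the `Usage` record that the adapter now extracts from LiteLLM responses. A minimal hedged example of the datamodel side, using only the field names the tests assert on:

```python
# Usage fields as exercised by the tests above: token counts plus an optional cost.
from kiln_ai.datamodel import Usage

usage = Usage(input_tokens=10, output_tokens=20, total_tokens=30, cost=0.5)
print(usage.total_tokens, usage.cost)  # 30 0.5
```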