kiln-ai 0.11.1__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (80)
  1. kiln_ai/adapters/__init__.py +4 -0
  2. kiln_ai/adapters/adapter_registry.py +163 -39
  3. kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
  4. kiln_ai/adapters/eval/__init__.py +28 -0
  5. kiln_ai/adapters/eval/base_eval.py +164 -0
  6. kiln_ai/adapters/eval/eval_runner.py +270 -0
  7. kiln_ai/adapters/eval/g_eval.py +368 -0
  8. kiln_ai/adapters/eval/registry.py +16 -0
  9. kiln_ai/adapters/eval/test_base_eval.py +325 -0
  10. kiln_ai/adapters/eval/test_eval_runner.py +641 -0
  11. kiln_ai/adapters/eval/test_g_eval.py +498 -0
  12. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
  14. kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
  15. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
  16. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
  17. kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
  18. kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
  19. kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
  20. kiln_ai/adapters/ml_model_list.py +758 -163
  21. kiln_ai/adapters/model_adapters/__init__.py +2 -4
  22. kiln_ai/adapters/model_adapters/base_adapter.py +61 -43
  23. kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
  24. kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
  25. kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
  26. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
  27. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
  28. kiln_ai/adapters/model_adapters/test_structured_output.py +59 -35
  29. kiln_ai/adapters/ollama_tools.py +3 -3
  30. kiln_ai/adapters/parsers/r1_parser.py +19 -14
  31. kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
  32. kiln_ai/adapters/prompt_builders.py +80 -42
  33. kiln_ai/adapters/provider_tools.py +50 -58
  34. kiln_ai/adapters/repair/repair_task.py +9 -21
  35. kiln_ai/adapters/repair/test_repair_task.py +6 -6
  36. kiln_ai/adapters/run_output.py +3 -0
  37. kiln_ai/adapters/test_adapter_registry.py +26 -29
  38. kiln_ai/adapters/test_generate_docs.py +4 -4
  39. kiln_ai/adapters/test_ollama_tools.py +0 -1
  40. kiln_ai/adapters/test_prompt_adaptors.py +47 -33
  41. kiln_ai/adapters/test_prompt_builders.py +91 -31
  42. kiln_ai/adapters/test_provider_tools.py +26 -81
  43. kiln_ai/datamodel/__init__.py +50 -952
  44. kiln_ai/datamodel/basemodel.py +2 -0
  45. kiln_ai/datamodel/datamodel_enums.py +60 -0
  46. kiln_ai/datamodel/dataset_filters.py +114 -0
  47. kiln_ai/datamodel/dataset_split.py +170 -0
  48. kiln_ai/datamodel/eval.py +298 -0
  49. kiln_ai/datamodel/finetune.py +105 -0
  50. kiln_ai/datamodel/json_schema.py +7 -1
  51. kiln_ai/datamodel/project.py +23 -0
  52. kiln_ai/datamodel/prompt.py +37 -0
  53. kiln_ai/datamodel/prompt_id.py +83 -0
  54. kiln_ai/datamodel/strict_mode.py +24 -0
  55. kiln_ai/datamodel/task.py +181 -0
  56. kiln_ai/datamodel/task_output.py +328 -0
  57. kiln_ai/datamodel/task_run.py +164 -0
  58. kiln_ai/datamodel/test_basemodel.py +19 -11
  59. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  60. kiln_ai/datamodel/test_dataset_split.py +32 -8
  61. kiln_ai/datamodel/test_datasource.py +22 -2
  62. kiln_ai/datamodel/test_eval_model.py +635 -0
  63. kiln_ai/datamodel/test_example_models.py +9 -13
  64. kiln_ai/datamodel/test_json_schema.py +23 -0
  65. kiln_ai/datamodel/test_models.py +2 -2
  66. kiln_ai/datamodel/test_prompt_id.py +129 -0
  67. kiln_ai/datamodel/test_task.py +159 -0
  68. kiln_ai/utils/config.py +43 -1
  69. kiln_ai/utils/dataset_import.py +232 -0
  70. kiln_ai/utils/test_dataset_import.py +596 -0
  71. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +86 -6
  72. kiln_ai-0.13.0.dist-info/RECORD +103 -0
  73. kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -302
  74. kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -11
  75. kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -246
  76. kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -350
  77. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -225
  78. kiln_ai-0.11.1.dist-info/RECORD +0 -76
  79. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
  80. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/adapters/model_adapters/test_structured_output.py

@@ -2,8 +2,6 @@ import json
 from pathlib import Path
 from typing import Dict

-import jsonschema
-import jsonschema.exceptions
 import pytest

 import kiln_ai.datamodel as datamodel
@@ -12,16 +10,13 @@ from kiln_ai.adapters.ml_model_list import (
     built_in_models,
 )
 from kiln_ai.adapters.model_adapters.base_adapter import (
-    AdapterInfo,
     BaseAdapter,
     RunOutput,
 )
 from kiln_ai.adapters.ollama_tools import ollama_online
-from kiln_ai.adapters.prompt_builders import (
-    BasePromptBuilder,
-    SimpleChainOfThoughtPromptBuilder,
-)
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
+from kiln_ai.datamodel import PromptId
+from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


@@ -39,9 +34,9 @@ async def test_structured_output_gpt_4o_mini(tmp_path):
     await run_structured_output_test(tmp_path, "gpt_4o_mini", "openai")


-@pytest.mark.parametrize("model_name", ["llama_3_1_8b"])
+@pytest.mark.parametrize("model_name", ["llama_3_1_8b", "gemma_2_2b"])
 @pytest.mark.ollama
-async def test_structured_output_ollama_llama(tmp_path, model_name):
+async def test_structured_output_ollama(tmp_path, model_name):
     if not await ollama_online():
         pytest.skip("Ollama API not running. Expect it running on localhost:11434")
     await run_structured_output_test(tmp_path, model_name, "ollama")
@@ -49,19 +44,21 @@ async def test_structured_output_ollama_llama(tmp_path, model_name):

 class MockAdapter(BaseAdapter):
     def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None):
-        super().__init__(kiln_task, model_name="phi_3_5", model_provider_name="ollama")
+        super().__init__(
+            run_config=RunConfig(
+                task=kiln_task,
+                model_name="phi_3_5",
+                model_provider_name="ollama",
+                prompt_id="simple_chain_of_thought_prompt_builder",
+            ),
+        )
         self.response = response

     async def _run(self, input: str) -> RunOutput:
         return RunOutput(output=self.response, intermediate_outputs=None)

-    def adapter_info(self) -> AdapterInfo:
-        return AdapterInfo(
-            adapter_name="mock_adapter",
-            model_name="mock_model",
-            model_provider="mock_provider",
-            prompt_builder_name="mock_prompt_builder",
-        )
+    def adapter_name(self) -> str:
+        return "mock_adapter"


 async def test_mock_unstructred_response(tmp_path):
@@ -69,7 +66,8 @@ async def test_mock_unstructred_response(tmp_path):

     # don't error on valid response
     adapter = MockAdapter(task, response={"setup": "asdf", "punchline": "asdf"})
-    answer = await adapter.invoke_returning_raw("You are a mock, send me the response!")
+    run = await adapter.invoke("You are a mock, send me the response!")
+    answer = json.loads(run.output.output)
     assert answer["setup"] == "asdf"
     assert answer["punchline"] == "asdf"

@@ -79,9 +77,12 @@ async def test_mock_unstructred_response(tmp_path):
     answer = await adapter.invoke("You are a mock, send me the response!")

     adapter = MockAdapter(task, response="string instead of dict")
-    with pytest.raises(RuntimeError):
+    with pytest.raises(
+        ValueError,
+        match="This task requires JSON output but the model didn't return valid JSON",
+    ):
         # Not a structed response so should error
-        answer = await adapter.invoke("You are a mock, send me the response!")
+        run = await adapter.invoke("You are a mock, send me the response!")

     # Should error, expecting a string, not a dict
     project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
@@ -146,7 +147,8 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     task = build_structured_output_test_task(tmp_path)
     a = adapter_for_task(task, model_name=model_name, provider=provider)
     try:
-        parsed = await a.invoke_returning_raw("Cows") # a joke about cows
+        run = await a.invoke("Cows") # a joke about cows
+        parsed = json.loads(run.output.output)
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
@@ -165,6 +167,12 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     assert rating >= 0
     assert rating <= 10

+    # Check reasoning models
+    assert a._model_provider is not None
+    if a._model_provider.reasoning_capable:
+        assert "reasoning" in run.intermediate_outputs
+        assert isinstance(run.intermediate_outputs["reasoning"], str)
+

 def build_structured_input_test_task(tmp_path: Path):
     project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
@@ -204,20 +212,27 @@ async def run_structured_input_task(
     task: datamodel.Task,
     model_name: str,
     provider: str,
-    pb: BasePromptBuilder | None = None,
+    prompt_id: PromptId | None = None,
 ):
     a = adapter_for_task(
-        task, model_name=model_name, provider=provider, prompt_builder=pb
+        task,
+        model_name=model_name,
+        provider=provider,
+        prompt_id=prompt_id,
     )
     with pytest.raises(ValueError):
         # not structured input in dictionary
         await a.invoke("a=1, b=2, c=3")
-    with pytest.raises(jsonschema.exceptions.ValidationError):
+    with pytest.raises(
+        ValueError,
+        match="This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema.",
+    ):
         # invalid structured input
         await a.invoke({"a": 1, "b": 2, "d": 3})

     try:
-        response = await a.invoke_returning_raw({"a": 2, "b": 2, "c": 2})
+        run = await a.invoke({"a": 2, "b": 2, "c": 2})
+        response = run.output.output
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
@@ -229,13 +244,20 @@ async def run_structured_input_task(
         assert "[[equilateral]]" in response
     else:
         assert response["is_equilateral"] is True
-    adapter_info = a.adapter_info()
+
     expected_pb_name = "simple_prompt_builder"
-    if pb is not None:
-        expected_pb_name = pb.__class__.prompt_builder_name()
-    assert adapter_info.prompt_builder_name == expected_pb_name
-    assert adapter_info.model_name == model_name
-    assert adapter_info.model_provider == provider
+    if prompt_id is not None:
+        expected_pb_name = prompt_id
+    assert a.run_config.prompt_id == expected_pb_name
+
+    assert a.run_config.model_name == model_name
+    assert a.run_config.model_provider_name == provider
+
+    # Check reasoning models
+    assert a._model_provider is not None
+    if a._model_provider.reasoning_capable:
+        assert "reasoning" in run.intermediate_outputs
+        assert isinstance(run.intermediate_outputs["reasoning"], str)


 @pytest.mark.paid
@@ -257,8 +279,9 @@ async def test_all_built_in_models_structured_input(
 @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
 async def test_structured_input_cot_prompt_builder(tmp_path, model_name, provider_name):
     task = build_structured_input_test_task(tmp_path)
-    pb = SimpleChainOfThoughtPromptBuilder(task)
-    await run_structured_input_task(task, model_name, provider_name, pb)
+    await run_structured_input_task(
+        task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+    )


 @pytest.mark.paid
@@ -302,5 +325,6 @@ When asked for a final result, this is the format (for an equilateral example):
 """
     task.output_json_schema = json.dumps(triangle_schema)
     task.save_to_file()
-    pb = SimpleChainOfThoughtPromptBuilder(task)
-    await run_structured_input_task(task, model_name, provider_name, pb)
+    await run_structured_input_task(
+        task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+    )
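
The test rewrite above tracks the adapter API change in this release: invoke_returning_raw is gone, invoke returns a run object, and adapter_info() is replaced by the adapter's run_config. A rough sketch of the new calling pattern, assuming adapter_for_task still lives in kiln_ai.adapters.adapter_registry and treating the task and the model/provider/prompt values below as placeholders:

import json

from kiln_ai.adapters.adapter_registry import adapter_for_task


async def joke_about(task, topic: str) -> dict:
    # Placeholder model/provider/prompt values; any built-in combination works.
    adapter = adapter_for_task(
        task,
        model_name="gpt_4o_mini",
        provider="openai",
        prompt_id="simple_prompt_builder",
    )
    run = await adapter.invoke(topic)  # returns a saved run, not the raw model output
    assert adapter.run_config.model_name == "gpt_4o_mini"  # replaces adapter_info()
    return json.loads(run.output.output)  # structured output arrives as a JSON string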

kiln_ai/adapters/ollama_tools.py

@@ -1,4 +1,3 @@
-import os
 from typing import Any, List

 import httpx
@@ -39,6 +38,7 @@ async def ollama_online() -> bool:

 class OllamaConnection(BaseModel):
     message: str
+    version: str | None = None
     supported_models: List[str]
     untested_models: List[str] = Field(default_factory=list)

@@ -50,7 +50,7 @@ class OllamaConnection(BaseModel):
 def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
     # Build a list of models we support for Ollama from the built-in model list
     supported_ollama_models = [
-        provider.provider_options["model"]
+        provider.model_id
         for model in built_in_models
         for provider in model.providers
         if provider.name == ModelProviderName.ollama
@@ -61,7 +61,7 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
             alias
             for model in built_in_models
             for provider in model.providers
-            for alias in provider.provider_options.get("model_aliases", [])
+            for alias in provider.ollama_model_aliases or []
         ]
     )


kiln_ai/adapters/parsers/r1_parser.py

@@ -20,21 +20,33 @@ class R1ThinkingParser(BaseParser):
         Raises:
             ValueError: If response format is invalid (missing tags, multiple tags, or no content after closing tag)
         """
+
+        # The upstream providers (litellm, openrouter, fireworks) all keep changing their response formats, sometimes adding reasoning parsing where it didn't previously exist.
+        # If they do it already, great just return. If not we parse it ourselves. Not ideal, but better than upstream changes breaking the app.
+        if (
+            original_output.intermediate_outputs is not None
+            and "reasoning" in original_output.intermediate_outputs
+        ):
+            return original_output
+
         # This parser only works for strings
         if not isinstance(original_output.output, str):
             raise ValueError("Response must be a string for R1 parser")

         # Strip whitespace and validate basic structure
         cleaned_response = original_output.output.strip()
-        if not cleaned_response.startswith(self.START_TAG):
-            raise ValueError("Response must start with <think> tag")

         # Find the thinking tags
-        think_start = cleaned_response.find(self.START_TAG)
         think_end = cleaned_response.find(self.END_TAG)
+        if think_end == -1:
+            raise ValueError("Missing </think> tag")

-        if think_start == -1 or think_end == -1:
-            raise ValueError("Missing thinking tags")
+        think_tag_start = cleaned_response.find(self.START_TAG)
+        if think_tag_start == -1:
+            # We allow no start <think>, thinking starts on first char. QwQ does this.
+            think_start = 0
+        else:
+            think_start = think_tag_start + len(self.START_TAG)

         # Check for multiple tags
         if (
@@ -44,9 +56,7 @@ class R1ThinkingParser(BaseParser):
             raise ValueError("Multiple thinking tags found")

         # Extract thinking content
-        thinking_content = cleaned_response[
-            think_start + len(self.START_TAG) : think_end
-        ].strip()
+        thinking_content = cleaned_response[think_start:think_end].strip()

         # Extract result (everything after </think>)
         result = cleaned_response[think_end + len(self.END_TAG) :].strip()
@@ -54,16 +64,11 @@ class R1ThinkingParser(BaseParser):
         if not result or len(result) == 0:
             raise ValueError("No content found after </think> tag")

-        # Parse JSON if needed
-        output = result
-        if self.structured_output:
-            output = parse_json_string(result)
-
         # Add thinking content to intermediate outputs if it exists
         intermediate_outputs = original_output.intermediate_outputs or {}
         intermediate_outputs["reasoning"] = thinking_content

         return RunOutput(
-            output=output,
+            output=result,
             intermediate_outputs=intermediate_outputs,
         )

kiln_ai/adapters/parsers/test_r1_parser.py

@@ -19,6 +19,16 @@ def test_valid_response(parser):
     assert parsed.output == "This is the result"


+def test_already_parsed_response(parser):
+    response = RunOutput(
+        output="This is the result",
+        intermediate_outputs={"reasoning": "This is thinking content"},
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs["reasoning"] == "This is thinking content"
+    assert parsed.output == "This is the result"
+
+
 def test_response_with_whitespace(parser):
     response = RunOutput(
         output="""
@@ -37,14 +47,16 @@ def test_response_with_whitespace(parser):


 def test_missing_start_tag(parser):
-    with pytest.raises(ValueError, match="Response must start with <think> tag"):
-        parser.parse_output(
-            RunOutput(output="Some content</think>result", intermediate_outputs=None)
-        )
+    parsed = parser.parse_output(
+        RunOutput(output="Some content</think>result", intermediate_outputs=None)
+    )
+
+    assert parsed.intermediate_outputs["reasoning"] == "Some content"
+    assert parsed.output == "result"


 def test_missing_end_tag(parser):
-    with pytest.raises(ValueError, match="Missing thinking tags"):
+    with pytest.raises(ValueError, match="Missing </think> tag"):
         parser.parse_output(
             RunOutput(output="<think>Some content", intermediate_outputs=None)
         )
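
Taken together, the parser changes leave only two hard failures: a missing closing </think> tag and an empty result after it. Reasoning already extracted upstream passes through unchanged, and a missing opening <think> tag is tolerated, with thinking assumed to start at the first character (the QwQ case). A small illustrative test sketch, assuming the same parser fixture and RunOutput import used in the tests above:

import pytest


def test_r1_parser_tolerances(parser):
    # QwQ-style output with no opening <think>: the leading text becomes reasoning.
    parsed = parser.parse_output(
        RunOutput(output="weigh the options</think>42", intermediate_outputs=None)
    )
    assert parsed.intermediate_outputs["reasoning"] == "weigh the options"
    assert parsed.output == "42"

    # A closing tag with nothing after it is still rejected.
    with pytest.raises(ValueError, match="No content found after </think> tag"):
        parser.parse_output(
            RunOutput(output="<think>only thinking</think>", intermediate_outputs=None)
        )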

kiln_ai/adapters/prompt_builders.py

@@ -2,8 +2,8 @@ import json
 from abc import ABCMeta, abstractmethod
 from typing import Dict

-from kiln_ai.datamodel import Task, TaskRun
-from kiln_ai.utils.formatting import snake_case
+from kiln_ai.datamodel import PromptGenerators, PromptId, Task, TaskRun
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error


 class BasePromptBuilder(metaclass=ABCMeta):
@@ -53,17 +53,6 @@ class BasePromptBuilder(metaclass=ABCMeta):
         """
         pass

-    @classmethod
-    def prompt_builder_name(cls) -> str:
-        """Returns the name of the prompt builder, to be used for persisting into the datastore.
-
-        Default implementation gets the name of the prompt builder in snake case. If you change the class name, you should override this so prior saved data is compatible.
-
-        Returns:
-            str: The prompt builder name in snake_case format.
-        """
-        return snake_case(cls.__name__)
-
     def build_user_message(self, input: Dict | str) -> str:
         """Build a user message from the input.

@@ -300,6 +289,57 @@ class SavedPromptBuilder(BasePromptBuilder):
         return self.prompt_model.chain_of_thought_instructions


+class TaskRunConfigPromptBuilder(BasePromptBuilder):
+    """A prompt builder that looks up a static prompt in a task run config."""
+
+    def __init__(self, task: Task, run_config_prompt_id: str):
+        parts = run_config_prompt_id.split("::")
+        if len(parts) != 4:
+            raise ValueError(
+                f"Invalid task run config prompt ID: {run_config_prompt_id}. Expected format: 'task_run_config::[project_id]::[task_id]::[run_config_id]'."
+            )
+
+        task_id = parts[2]
+        if task_id != task.id:
+            raise ValueError(
+                f"Task run config prompt ID: {run_config_prompt_id}. Task ID mismatch. Expected: {task.id}, got: {task_id}."
+            )
+
+        run_config_id = parts[3]
+        run_config = next(
+            (
+                run_config
+                for run_config in task.run_configs(readonly=True)
+                if run_config.id == run_config_id
+            ),
+            None,
+        )
+        if not run_config:
+            raise ValueError(
+                f"Task run config ID not found: {run_config_id} for prompt id {run_config_prompt_id}"
+            )
+        if run_config.prompt is None:
+            raise ValueError(
+                f"Task run config ID {run_config_id} does not have a stored prompt. Used as prompt id {run_config_prompt_id}"
+            )
+
+        # Load the prompt from the model
+        self.prompt = run_config.prompt.prompt
+        self.cot_prompt = run_config.prompt.chain_of_thought_instructions
+        self.id = run_config_prompt_id
+
+        super().__init__(task)
+
+    def prompt_id(self) -> str | None:
+        return self.id
+
+    def build_base_prompt(self) -> str:
+        return self.prompt
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return self.cot_prompt
+
+
 class FineTunePromptBuilder(BasePromptBuilder):
     """A prompt builder that looks up a fine-tune prompt."""

@@ -337,25 +377,12 @@ class FineTunePromptBuilder(BasePromptBuilder):
         return self.fine_tune_model.thinking_instructions


-# TODO P2: we end up with 2 IDs for these: the keys here (ui_name) and the prompt_builder_name from the class
-# We end up maintaining this in _prompt_generators as well.
-prompt_builder_registry = {
-    "simple_prompt_builder": SimplePromptBuilder,
-    "multi_shot_prompt_builder": MultiShotPromptBuilder,
-    "few_shot_prompt_builder": FewShotPromptBuilder,
-    "repairs_prompt_builder": RepairsPromptBuilder,
-    "simple_chain_of_thought_prompt_builder": SimpleChainOfThoughtPromptBuilder,
-    "few_shot_chain_of_thought_prompt_builder": FewShotChainOfThoughtPromptBuilder,
-    "multi_shot_chain_of_thought_prompt_builder": MultiShotChainOfThoughtPromptBuilder,
-}
-
-
 # Our UI has some names that are not the same as the class names, which also hint parameters.
-def prompt_builder_from_ui_name(ui_name: str, task: Task) -> BasePromptBuilder:
+def prompt_builder_from_id(prompt_id: PromptId, task: Task) -> BasePromptBuilder:
     """Convert a name used in the UI to the corresponding prompt builder class.

     Args:
-        ui_name (str): The UI name for the prompt builder type.
+        prompt_id (PromptId): The prompt ID.

     Returns:
         type[BasePromptBuilder]: The corresponding prompt builder class.
@@ -365,29 +392,40 @@ def prompt_builder_from_ui_name(ui_name: str, task: Task) -> BasePromptBuilder:
     """

     # Saved prompts are prefixed with "id::"
-    if ui_name.startswith("id::"):
-        prompt_id = ui_name[4:]
+    if prompt_id.startswith("id::"):
+        prompt_id = prompt_id[4:]
         return SavedPromptBuilder(task, prompt_id)

+    # Task run config prompts are prefixed with "task_run_config::"
+    # task_run_config::[project_id]::[task_id]::[run_config_id]
+    if prompt_id.startswith("task_run_config::"):
+        return TaskRunConfigPromptBuilder(task, prompt_id)
+
     # Fine-tune prompts are prefixed with "fine_tune_prompt::"
-    if ui_name.startswith("fine_tune_prompt::"):
-        fine_tune_id = ui_name[18:]
-        return FineTunePromptBuilder(task, fine_tune_id)
+    if prompt_id.startswith("fine_tune_prompt::"):
+        prompt_id = prompt_id[18:]
+        return FineTunePromptBuilder(task, prompt_id)
+
+    # Check if the prompt_id matches any enum value
+    if prompt_id not in [member.value for member in PromptGenerators]:
+        raise ValueError(f"Unknown prompt generator: {prompt_id}")
+    typed_prompt_generator = PromptGenerators(prompt_id)

-    match ui_name:
-        case "basic":
+    match typed_prompt_generator:
+        case PromptGenerators.SIMPLE:
             return SimplePromptBuilder(task)
-        case "few_shot":
+        case PromptGenerators.FEW_SHOT:
             return FewShotPromptBuilder(task)
-        case "many_shot":
+        case PromptGenerators.MULTI_SHOT:
             return MultiShotPromptBuilder(task)
-        case "repairs":
+        case PromptGenerators.REPAIRS:
             return RepairsPromptBuilder(task)
-        case "simple_chain_of_thought":
+        case PromptGenerators.SIMPLE_CHAIN_OF_THOUGHT:
             return SimpleChainOfThoughtPromptBuilder(task)
-        case "few_shot_chain_of_thought":
+        case PromptGenerators.FEW_SHOT_CHAIN_OF_THOUGHT:
             return FewShotChainOfThoughtPromptBuilder(task)
-        case "multi_shot_chain_of_thought":
+        case PromptGenerators.MULTI_SHOT_CHAIN_OF_THOUGHT:
             return MultiShotChainOfThoughtPromptBuilder(task)
         case _:
-            raise ValueError(f"Unknown prompt builder: {ui_name}")
+            # Type checking will find missing cases
+            raise_exhaustive_enum_error(typed_prompt_generator)
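
With the registry gone, the prompt ID string itself is the dispatch key: "id::" selects a saved prompt, "task_run_config::[project_id]::[task_id]::[run_config_id]" selects a prompt frozen in a task run config, "fine_tune_prompt::" selects a fine-tune prompt, and anything else must be a PromptGenerators value. A hedged sketch of resolving builders by ID, where task stands for an existing kiln_ai.datamodel.Task and all IDs below are placeholders:

from kiln_ai.adapters.prompt_builders import prompt_builder_from_id

# Generator IDs map straight to builder classes.
builder = prompt_builder_from_id("simple_chain_of_thought_prompt_builder", task)
base_prompt = builder.build_base_prompt()
cot = builder.chain_of_thought_prompt()

# Prefixed IDs look up stored prompts instead.
saved = prompt_builder_from_id("id::prompt_123", task)
frozen = prompt_builder_from_id(
    # The task_id segment must match task.id or a ValueError is raised.
    "task_run_config::project_123::task_456::run_config_789",
    task,
)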

kiln_ai/adapters/provider_tools.py

@@ -9,8 +9,8 @@ from kiln_ai.adapters.ml_model_list import (
     StructuredOutputMode,
     built_in_models,
 )
-from kiln_ai.adapters.model_adapters.openai_compatible_config import (
-    OpenAICompatibleConfig,
+from kiln_ai.adapters.model_adapters.litellm_config import (
+    LiteLlmConfig,
 )
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
@@ -153,7 +153,7 @@ def kiln_model_provider_from(
         return finetune_provider_model(name)

     if provider_name == ModelProviderName.openai_compatible:
-        return openai_compatible_provider_model(name)
+        return lite_llm_provider_model(name)

     built_in_model = builtin_model_from(name, provider_name)
     if built_in_model:
@@ -175,13 +175,13 @@ def kiln_model_provider_from(
         supports_structured_output=False,
         supports_data_gen=False,
         untested_model=True,
-        provider_options=provider_options_for_custom_model(name, provider_name),
+        model_id=name,
     )


-def openai_compatible_config(
+def lite_llm_config(
     model_id: str,
-) -> OpenAICompatibleConfig:
+) -> LiteLlmConfig:
     try:
         openai_provider_name, model_id = model_id.split("::")
     except Exception:
@@ -205,22 +205,23 @@ def openai_compatible_config(
             f"OpenAI compatible provider {openai_provider_name} has no base URL"
         )

-    return OpenAICompatibleConfig(
-        api_key=api_key,
+    return LiteLlmConfig(
+        # OpenAI compatible, with a custom base URL
         model_name=model_id,
         provider_name=ModelProviderName.openai_compatible,
        base_url=base_url,
+        additional_body_options={
+            "api_key": api_key,
+        },
     )


-def openai_compatible_provider_model(
+def lite_llm_provider_model(
     model_id: str,
 ) -> KilnModelProvider:
     return KilnModelProvider(
         name=ModelProviderName.openai_compatible,
-        provider_options={
-            "model": model_id,
-        },
+        model_id=model_id,
         supports_structured_output=False,
         supports_data_gen=False,
         untested_model=True,
@@ -264,9 +265,7 @@ def finetune_provider_model(
     provider = ModelProviderName[fine_tune.provider]
     model_provider = KilnModelProvider(
         name=provider,
-        provider_options={
-            "model": fine_tune.fine_tune_model_id,
-        },
+        model_id=fine_tune.fine_tune_model_id,
     )

     if fine_tune.structured_output_mode is not None:
@@ -331,6 +330,18 @@ def provider_name_from_id(id: str) -> str:
             return "Custom Models"
         case ModelProviderName.openai_compatible:
             return "OpenAI Compatible"
+        case ModelProviderName.azure_openai:
+            return "Azure OpenAI"
+        case ModelProviderName.gemini_api:
+            return "Gemini API"
+        case ModelProviderName.anthropic:
+            return "Anthropic"
+        case ModelProviderName.huggingface:
+            return "Hugging Face"
+        case ModelProviderName.vertex:
+            return "Google Vertex AI"
+        case ModelProviderName.together_ai:
+            return "Together AI"
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_enum_error(enum_id)
@@ -338,49 +349,6 @@ def provider_name_from_id(id: str) -> str:
     return "Unknown provider: " + id


-def provider_options_for_custom_model(
-    model_name: str, provider_name: str
-) -> Dict[str, str]:
-    """
-    Generated model provider options for a custom model. Each has their own format/options.
-    """
-
-    if provider_name not in ModelProviderName.__members__:
-        raise ValueError(f"Invalid provider name: {provider_name}")
-
-    enum_id = ModelProviderName(provider_name)
-    match enum_id:
-        case ModelProviderName.amazon_bedrock:
-            # us-west-2 is the only region consistently supported by Bedrock
-            return {"model": model_name, "region_name": "us-west-2"}
-        case (
-            ModelProviderName.openai
-            | ModelProviderName.ollama
-            | ModelProviderName.fireworks_ai
-            | ModelProviderName.openrouter
-            | ModelProviderName.groq
-        ):
-            return {"model": model_name}
-        case ModelProviderName.kiln_custom_registry:
-            raise ValueError(
-                "Custom models from registry should be parsed into provider/model before calling this."
-            )
-        case ModelProviderName.kiln_fine_tune:
-            raise ValueError(
-                "Fine tuned models should populate provider options via another path"
-            )
-        case ModelProviderName.openai_compatible:
-            raise ValueError(
-                "OpenAI compatible models should populate provider options via another path"
-            )
-        case _:
-            # triggers pyright warning if I miss a case
-            raise_exhaustive_enum_error(enum_id)
-
-    # Won't reach this, type checking will catch missed values
-    return {"model": model_name}
-
-
 @dataclass
 class ModelProviderWarning:
     required_config_keys: List[str]
@@ -408,4 +376,28 @@ provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
         required_config_keys=["fireworks_api_key", "fireworks_account_id"],
         message="Attempted to use Fireworks without an API key and account ID set. \nGet your API key from https://fireworks.ai/account/api-keys and your account ID from https://fireworks.ai/account/profile",
     ),
+    ModelProviderName.anthropic: ModelProviderWarning(
+        required_config_keys=["anthropic_api_key"],
+        message="Attempted to use Anthropic without an API key set. \nGet your API key from https://console.anthropic.com/settings/keys",
+    ),
+    ModelProviderName.gemini_api: ModelProviderWarning(
+        required_config_keys=["gemini_api_key"],
+        message="Attempted to use Gemini without an API key set. \nGet your API key from https://aistudio.google.com/app/apikey",
+    ),
+    ModelProviderName.azure_openai: ModelProviderWarning(
+        required_config_keys=["azure_openai_api_key", "azure_openai_endpoint"],
+        message="Attempted to use Azure OpenAI without an API key and endpoint set. Configure these in settings.",
+    ),
+    ModelProviderName.huggingface: ModelProviderWarning(
+        required_config_keys=["huggingface_api_key"],
+        message="Attempted to use Hugging Face without an API key set. \nGet your API key from https://huggingface.co/settings/tokens",
+    ),
+    ModelProviderName.vertex: ModelProviderWarning(
+        required_config_keys=["vertex_project_id"],
+        message="Attempted to use Vertex without a project ID set. \nGet your project ID from the Vertex AI console.",
+    ),
+    ModelProviderName.together_ai: ModelProviderWarning(
+        required_config_keys=["together_api_key"],
+        message="Attempted to use Together without an API key set. \nGet your API key from https://together.ai/settings/keys",
+    ),
 }