kiln-ai 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (63)
  1. kiln_ai/adapters/adapter_registry.py +12 -13
  2. kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
  3. kiln_ai/adapters/eval/base_eval.py +164 -0
  4. kiln_ai/adapters/eval/eval_runner.py +267 -0
  5. kiln_ai/adapters/eval/g_eval.py +367 -0
  6. kiln_ai/adapters/eval/registry.py +16 -0
  7. kiln_ai/adapters/eval/test_base_eval.py +324 -0
  8. kiln_ai/adapters/eval/test_eval_runner.py +640 -0
  9. kiln_ai/adapters/eval/test_g_eval.py +497 -0
  10. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  11. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
  12. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +1 -1
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
  14. kiln_ai/adapters/ml_model_list.py +141 -29
  15. kiln_ai/adapters/model_adapters/base_adapter.py +50 -35
  16. kiln_ai/adapters/model_adapters/langchain_adapters.py +27 -20
  17. kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -1
  18. kiln_ai/adapters/model_adapters/openai_model_adapter.py +93 -50
  19. kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
  20. kiln_ai/adapters/model_adapters/test_langchain_adapter.py +7 -14
  21. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +55 -64
  22. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
  23. kiln_ai/adapters/model_adapters/test_structured_output.py +36 -30
  24. kiln_ai/adapters/ollama_tools.py +0 -1
  25. kiln_ai/adapters/prompt_builders.py +80 -42
  26. kiln_ai/adapters/repair/repair_task.py +9 -21
  27. kiln_ai/adapters/repair/test_repair_task.py +3 -3
  28. kiln_ai/adapters/run_output.py +3 -0
  29. kiln_ai/adapters/test_adapter_registry.py +10 -10
  30. kiln_ai/adapters/test_generate_docs.py +6 -6
  31. kiln_ai/adapters/test_ollama_tools.py +0 -1
  32. kiln_ai/adapters/test_prompt_adaptors.py +17 -14
  33. kiln_ai/adapters/test_prompt_builders.py +91 -31
  34. kiln_ai/datamodel/__init__.py +50 -952
  35. kiln_ai/datamodel/datamodel_enums.py +58 -0
  36. kiln_ai/datamodel/dataset_filters.py +114 -0
  37. kiln_ai/datamodel/dataset_split.py +170 -0
  38. kiln_ai/datamodel/eval.py +298 -0
  39. kiln_ai/datamodel/finetune.py +105 -0
  40. kiln_ai/datamodel/json_schema.py +6 -0
  41. kiln_ai/datamodel/project.py +23 -0
  42. kiln_ai/datamodel/prompt.py +37 -0
  43. kiln_ai/datamodel/prompt_id.py +83 -0
  44. kiln_ai/datamodel/strict_mode.py +24 -0
  45. kiln_ai/datamodel/task.py +181 -0
  46. kiln_ai/datamodel/task_output.py +321 -0
  47. kiln_ai/datamodel/task_run.py +164 -0
  48. kiln_ai/datamodel/test_basemodel.py +10 -11
  49. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  50. kiln_ai/datamodel/test_dataset_split.py +32 -8
  51. kiln_ai/datamodel/test_datasource.py +3 -2
  52. kiln_ai/datamodel/test_eval_model.py +635 -0
  53. kiln_ai/datamodel/test_example_models.py +9 -13
  54. kiln_ai/datamodel/test_json_schema.py +23 -0
  55. kiln_ai/datamodel/test_models.py +2 -2
  56. kiln_ai/datamodel/test_prompt_id.py +129 -0
  57. kiln_ai/datamodel/test_task.py +159 -0
  58. kiln_ai/utils/config.py +6 -1
  59. {kiln_ai-0.11.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +37 -1
  60. kiln_ai-0.12.0.dist-info/RECORD +100 -0
  61. kiln_ai-0.11.1.dist-info/RECORD +0 -76
  62. {kiln_ai-0.11.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
  63. {kiln_ai-0.11.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -2,8 +2,6 @@ import json
  from pathlib import Path
  from typing import Dict

- import jsonschema
- import jsonschema.exceptions
  import pytest

  import kiln_ai.datamodel as datamodel
@@ -12,16 +10,13 @@ from kiln_ai.adapters.ml_model_list import (
      built_in_models,
  )
  from kiln_ai.adapters.model_adapters.base_adapter import (
-     AdapterInfo,
      BaseAdapter,
      RunOutput,
  )
  from kiln_ai.adapters.ollama_tools import ollama_online
- from kiln_ai.adapters.prompt_builders import (
-     BasePromptBuilder,
-     SimpleChainOfThoughtPromptBuilder,
- )
  from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
+ from kiln_ai.datamodel import PromptId
+ from kiln_ai.datamodel.task import RunConfig
  from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


@@ -39,9 +34,9 @@ async def test_structured_output_gpt_4o_mini(tmp_path):
      await run_structured_output_test(tmp_path, "gpt_4o_mini", "openai")


- @pytest.mark.parametrize("model_name", ["llama_3_1_8b"])
+ @pytest.mark.parametrize("model_name", ["llama_3_1_8b", "gemma_2_2b"])
  @pytest.mark.ollama
- async def test_structured_output_ollama_llama(tmp_path, model_name):
+ async def test_structured_output_ollama(tmp_path, model_name):
      if not await ollama_online():
          pytest.skip("Ollama API not running. Expect it running on localhost:11434")
      await run_structured_output_test(tmp_path, model_name, "ollama")
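The renamed test above now runs against more than one local model and skips cleanly when no Ollama server is available. For readers reusing the pattern outside this suite, here is a minimal, self-contained sketch; `server_online` and the model IDs are illustrative placeholders rather than kiln-ai APIs, and the async test assumes pytest-asyncio (or equivalent) is configured:

```python
import httpx
import pytest


async def server_online(url: str = "http://localhost:11434") -> bool:
    # Illustrative helper: treat any HTTP response from the local server as "up".
    try:
        async with httpx.AsyncClient() as client:
            await client.get(url, timeout=2.0)
        return True
    except httpx.HTTPError:
        return False


@pytest.mark.parametrize("model_name", ["llama_3_1_8b", "gemma_2_2b"])
async def test_against_local_model(model_name):
    # Skip rather than fail when the local inference server is not running.
    if not await server_online():
        pytest.skip("Ollama API not running. Expect it running on localhost:11434")
    # ... invoke the model under test with `model_name` here ...
```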
@@ -49,19 +44,21 @@ async def test_structured_output_ollama_llama(tmp_path, model_name):

  class MockAdapter(BaseAdapter):
      def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None):
-         super().__init__(kiln_task, model_name="phi_3_5", model_provider_name="ollama")
+         super().__init__(
+             run_config=RunConfig(
+                 task=kiln_task,
+                 model_name="phi_3_5",
+                 model_provider_name="ollama",
+                 prompt_id="simple_chain_of_thought_prompt_builder",
+             ),
+         )
          self.response = response

      async def _run(self, input: str) -> RunOutput:
          return RunOutput(output=self.response, intermediate_outputs=None)

-     def adapter_info(self) -> AdapterInfo:
-         return AdapterInfo(
-             adapter_name="mock_adapter",
-             model_name="mock_model",
-             model_provider="mock_provider",
-             prompt_builder_name="mock_prompt_builder",
-         )
+     def adapter_name(self) -> str:
+         return "mock_adapter"


  async def test_mock_unstructred_response(tmp_path):
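The MockAdapter rewrite above reflects the release-wide change from passing model_name and model_provider_name straight into BaseAdapter to bundling them, together with the task and a prompt_id, into a RunConfig. The following standalone sketch shows that shape with stand-in classes (RunConfigSketch and AdapterSketch are illustrative, not kiln_ai.datamodel.task.RunConfig or BaseAdapter):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class RunConfigSketch:
    # Field names mirror the RunConfig usage in the hunk above;
    # the real class lives in kiln_ai.datamodel.task.
    task: Any
    model_name: str
    model_provider_name: str
    prompt_id: str


class AdapterSketch:
    def __init__(self, run_config: RunConfigSketch):
        # Adapters now receive one config object instead of loose keyword
        # arguments, so everything that defines a run travels together.
        self.run_config = run_config

    def adapter_name(self) -> str:
        return "sketch_adapter"


adapter = AdapterSketch(
    RunConfigSketch(
        task=None,
        model_name="phi_3_5",
        model_provider_name="ollama",
        prompt_id="simple_chain_of_thought_prompt_builder",
    )
)
assert adapter.run_config.model_name == "phi_3_5"
```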
@@ -204,15 +201,21 @@ async def run_structured_input_task(
      task: datamodel.Task,
      model_name: str,
      provider: str,
-     pb: BasePromptBuilder | None = None,
+     prompt_id: PromptId | None = None,
  ):
      a = adapter_for_task(
-         task, model_name=model_name, provider=provider, prompt_builder=pb
+         task,
+         model_name=model_name,
+         provider=provider,
+         prompt_id=prompt_id,
      )
      with pytest.raises(ValueError):
          # not structured input in dictionary
          await a.invoke("a=1, b=2, c=3")
-     with pytest.raises(jsonschema.exceptions.ValidationError):
+     with pytest.raises(
+         ValueError,
+         match="This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema.",
+     ):
          # invalid structured input
          await a.invoke({"a": 1, "b": 2, "d": 3})

@@ -229,13 +232,14 @@ async def run_structured_input_task(
          assert "[[equilateral]]" in response
      else:
          assert response["is_equilateral"] is True
-     adapter_info = a.adapter_info()
+
      expected_pb_name = "simple_prompt_builder"
-     if pb is not None:
-         expected_pb_name = pb.__class__.prompt_builder_name()
-     assert adapter_info.prompt_builder_name == expected_pb_name
-     assert adapter_info.model_name == model_name
-     assert adapter_info.model_provider == provider
+     if prompt_id is not None:
+         expected_pb_name = prompt_id
+     assert a.run_config.prompt_id == expected_pb_name
+
+     assert a.run_config.model_name == model_name
+     assert a.run_config.model_provider_name == provider


  @pytest.mark.paid
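The run_structured_input_task hunk above also records a behavior change: a schema mismatch no longer surfaces as jsonschema.exceptions.ValidationError but as a ValueError with a stable message. A hedged sketch of that wrapping pattern follows; validate_output_against_schema is an illustrative name, and only the error message is taken verbatim from the test:

```python
import jsonschema
import jsonschema.exceptions


def validate_output_against_schema(output: dict, schema: dict) -> None:
    # Wrap the library-specific exception so callers only ever see ValueError,
    # which is what the updated test expects.
    try:
        jsonschema.validate(instance=output, schema=schema)
    except jsonschema.exceptions.ValidationError as e:
        raise ValueError(
            "This task requires a specific output schema. "
            "While the model produced JSON, that JSON didn't meet the schema."
        ) from e


triangle_schema = {"type": "object", "required": ["a", "b", "c"]}
try:
    validate_output_against_schema({"a": 1, "b": 2, "d": 3}, triangle_schema)
except ValueError as err:
    print(err)  # schema mismatch reported as a plain ValueError
```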
@@ -257,8 +261,9 @@ async def test_all_built_in_models_structured_input(
  @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
  async def test_structured_input_cot_prompt_builder(tmp_path, model_name, provider_name):
      task = build_structured_input_test_task(tmp_path)
-     pb = SimpleChainOfThoughtPromptBuilder(task)
-     await run_structured_input_task(task, model_name, provider_name, pb)
+     await run_structured_input_task(
+         task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+     )


  @pytest.mark.paid
@@ -302,5 +307,6 @@ When asked for a final result, this is the format (for an equilateral example):
  """
      task.output_json_schema = json.dumps(triangle_schema)
      task.save_to_file()
-     pb = SimpleChainOfThoughtPromptBuilder(task)
-     await run_structured_input_task(task, model_name, provider_name, pb)
+     await run_structured_input_task(
+         task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+     )
@@ -1,4 +1,3 @@
- import os
  from typing import Any, List

  import httpx
@@ -2,8 +2,8 @@ import json
  from abc import ABCMeta, abstractmethod
  from typing import Dict

- from kiln_ai.datamodel import Task, TaskRun
- from kiln_ai.utils.formatting import snake_case
+ from kiln_ai.datamodel import PromptGenerators, PromptId, Task, TaskRun
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error


  class BasePromptBuilder(metaclass=ABCMeta):
@@ -53,17 +53,6 @@ class BasePromptBuilder(metaclass=ABCMeta):
          """
          pass

-     @classmethod
-     def prompt_builder_name(cls) -> str:
-         """Returns the name of the prompt builder, to be used for persisting into the datastore.
-
-         Default implementation gets the name of the prompt builder in snake case. If you change the class name, you should override this so prior saved data is compatible.
-
-         Returns:
-             str: The prompt builder name in snake_case format.
-         """
-         return snake_case(cls.__name__)
-
      def build_user_message(self, input: Dict | str) -> str:
          """Build a user message from the input.

@@ -300,6 +289,57 @@ class SavedPromptBuilder(BasePromptBuilder):
          return self.prompt_model.chain_of_thought_instructions


+ class TaskRunConfigPromptBuilder(BasePromptBuilder):
+     """A prompt builder that looks up a static prompt in a task run config."""
+
+     def __init__(self, task: Task, run_config_prompt_id: str):
+         parts = run_config_prompt_id.split("::")
+         if len(parts) != 4:
+             raise ValueError(
+                 f"Invalid task run config prompt ID: {run_config_prompt_id}. Expected format: 'task_run_config::[project_id]::[task_id]::[run_config_id]'."
+             )
+
+         task_id = parts[2]
+         if task_id != task.id:
+             raise ValueError(
+                 f"Task run config prompt ID: {run_config_prompt_id}. Task ID mismatch. Expected: {task.id}, got: {task_id}."
+             )
+
+         run_config_id = parts[3]
+         run_config = next(
+             (
+                 run_config
+                 for run_config in task.run_configs(readonly=True)
+                 if run_config.id == run_config_id
+             ),
+             None,
+         )
+         if not run_config:
+             raise ValueError(
+                 f"Task run config ID not found: {run_config_id} for prompt id {run_config_prompt_id}"
+             )
+         if run_config.prompt is None:
+             raise ValueError(
+                 f"Task run config ID {run_config_id} does not have a stored prompt. Used as prompt id {run_config_prompt_id}"
+             )
+
+         # Load the prompt from the model
+         self.prompt = run_config.prompt.prompt
+         self.cot_prompt = run_config.prompt.chain_of_thought_instructions
+         self.id = run_config_prompt_id
+
+         super().__init__(task)
+
+     def prompt_id(self) -> str | None:
+         return self.id
+
+     def build_base_prompt(self) -> str:
+         return self.prompt
+
+     def chain_of_thought_prompt(self) -> str | None:
+         return self.cot_prompt
+
+
  class FineTunePromptBuilder(BasePromptBuilder):
      """A prompt builder that looks up a fine-tune prompt."""

@@ -337,25 +377,12 @@ class FineTunePromptBuilder(BasePromptBuilder):
          return self.fine_tune_model.thinking_instructions


- # TODO P2: we end up with 2 IDs for these: the keys here (ui_name) and the prompt_builder_name from the class
- # We end up maintaining this in _prompt_generators as well.
- prompt_builder_registry = {
-     "simple_prompt_builder": SimplePromptBuilder,
-     "multi_shot_prompt_builder": MultiShotPromptBuilder,
-     "few_shot_prompt_builder": FewShotPromptBuilder,
-     "repairs_prompt_builder": RepairsPromptBuilder,
-     "simple_chain_of_thought_prompt_builder": SimpleChainOfThoughtPromptBuilder,
-     "few_shot_chain_of_thought_prompt_builder": FewShotChainOfThoughtPromptBuilder,
-     "multi_shot_chain_of_thought_prompt_builder": MultiShotChainOfThoughtPromptBuilder,
- }
-
-
  # Our UI has some names that are not the same as the class names, which also hint parameters.
- def prompt_builder_from_ui_name(ui_name: str, task: Task) -> BasePromptBuilder:
+ def prompt_builder_from_id(prompt_id: PromptId, task: Task) -> BasePromptBuilder:
      """Convert a name used in the UI to the corresponding prompt builder class.

      Args:
-         ui_name (str): The UI name for the prompt builder type.
+         prompt_id (PromptId): The prompt ID.

      Returns:
          type[BasePromptBuilder]: The corresponding prompt builder class.
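The TaskRunConfigPromptBuilder added above resolves prompts from a four-part ID of the form task_run_config::[project_id]::[task_id]::[run_config_id]. A standalone sketch of just the parsing and validation step, with no kiln-ai imports (ParsedRunConfigPromptId is a stand-in type, not part of the package):

```python
from dataclasses import dataclass


@dataclass
class ParsedRunConfigPromptId:
    project_id: str
    task_id: str
    run_config_id: str


def parse_run_config_prompt_id(prompt_id: str) -> ParsedRunConfigPromptId:
    # Mirrors the validation in TaskRunConfigPromptBuilder.__init__ above:
    # exactly four "::"-separated parts, starting with the fixed prefix.
    parts = prompt_id.split("::")
    if len(parts) != 4 or parts[0] != "task_run_config":
        raise ValueError(
            f"Invalid task run config prompt ID: {prompt_id}. Expected format: "
            "'task_run_config::[project_id]::[task_id]::[run_config_id]'."
        )
    return ParsedRunConfigPromptId(
        project_id=parts[1], task_id=parts[2], run_config_id=parts[3]
    )


parsed = parse_run_config_prompt_id("task_run_config::proj_1::task_9::rc_42")
assert parsed.run_config_id == "rc_42"
```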
@@ -365,29 +392,40 @@ def prompt_builder_from_ui_name(ui_name: str, task: Task) -> BasePromptBuilder:
      """

      # Saved prompts are prefixed with "id::"
-     if ui_name.startswith("id::"):
-         prompt_id = ui_name[4:]
+     if prompt_id.startswith("id::"):
+         prompt_id = prompt_id[4:]
          return SavedPromptBuilder(task, prompt_id)

+     # Task run config prompts are prefixed with "task_run_config::"
+     # task_run_config::[project_id]::[task_id]::[run_config_id]
+     if prompt_id.startswith("task_run_config::"):
+         return TaskRunConfigPromptBuilder(task, prompt_id)
+
      # Fine-tune prompts are prefixed with "fine_tune_prompt::"
-     if ui_name.startswith("fine_tune_prompt::"):
-         fine_tune_id = ui_name[18:]
-         return FineTunePromptBuilder(task, fine_tune_id)
+     if prompt_id.startswith("fine_tune_prompt::"):
+         prompt_id = prompt_id[18:]
+         return FineTunePromptBuilder(task, prompt_id)
+
+     # Check if the prompt_id matches any enum value
+     if prompt_id not in [member.value for member in PromptGenerators]:
+         raise ValueError(f"Unknown prompt generator: {prompt_id}")
+     typed_prompt_generator = PromptGenerators(prompt_id)

-     match ui_name:
-         case "basic":
+     match typed_prompt_generator:
+         case PromptGenerators.SIMPLE:
              return SimplePromptBuilder(task)
-         case "few_shot":
+         case PromptGenerators.FEW_SHOT:
              return FewShotPromptBuilder(task)
-         case "many_shot":
+         case PromptGenerators.MULTI_SHOT:
              return MultiShotPromptBuilder(task)
-         case "repairs":
+         case PromptGenerators.REPAIRS:
              return RepairsPromptBuilder(task)
-         case "simple_chain_of_thought":
+         case PromptGenerators.SIMPLE_CHAIN_OF_THOUGHT:
              return SimpleChainOfThoughtPromptBuilder(task)
-         case "few_shot_chain_of_thought":
+         case PromptGenerators.FEW_SHOT_CHAIN_OF_THOUGHT:
              return FewShotChainOfThoughtPromptBuilder(task)
-         case "multi_shot_chain_of_thought":
+         case PromptGenerators.MULTI_SHOT_CHAIN_OF_THOUGHT:
              return MultiShotChainOfThoughtPromptBuilder(task)
          case _:
-             raise ValueError(f"Unknown prompt builder: {ui_name}")
+             # Type checking will find missing cases
+             raise_exhaustive_enum_error(typed_prompt_generator)
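prompt_builder_from_id above routes on string prefixes first and only then on the PromptGenerators enum, with raise_exhaustive_enum_error keeping the match statement exhaustive under type checking. A simplified standalone sketch of the same dispatch shape, substituting typing.assert_never (Python 3.11+) for the kiln-ai helper and a two-member stand-in enum for PromptGenerators:

```python
from enum import Enum
from typing import assert_never  # Python 3.11+; use typing_extensions on older versions


class Generator(str, Enum):
    # Illustrative subset; the real PromptGenerators enum has more members.
    SIMPLE = "simple_prompt_builder"
    FEW_SHOT = "few_shot_prompt_builder"


def builder_for_id(prompt_id: str) -> str:
    # Prefix-routed IDs are handled before the enum lookup, as in the hunk above.
    if prompt_id.startswith("id::"):
        return f"saved prompt {prompt_id[4:]}"
    if prompt_id.startswith("task_run_config::"):
        return f"task run config prompt {prompt_id}"
    if prompt_id.startswith("fine_tune_prompt::"):
        return f"fine-tune prompt {prompt_id[18:]}"

    if prompt_id not in [member.value for member in Generator]:
        raise ValueError(f"Unknown prompt generator: {prompt_id}")
    generator = Generator(prompt_id)

    match generator:
        case Generator.SIMPLE:
            return "simple builder"
        case Generator.FEW_SHOT:
            return "few-shot builder"
        case _:
            # A type checker flags this branch if a member is left unhandled.
            assert_never(generator)


print(builder_for_id("id::abc123"))
print(builder_for_id("simple_prompt_builder"))
```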
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
  from kiln_ai.adapters.prompt_builders import (
      BasePromptBuilder,
      SavedPromptBuilder,
-     prompt_builder_registry,
+     prompt_builder_from_id,
  )
  from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun

@@ -49,28 +49,16 @@ feedback describing what should be improved. Your job is to understand the evalu
          if run.output.source is None or run.output.source.properties is None:
              raise ValueError("No source properties found")

-         # Try ID first, then builder name
-         prompt_id = run.output.source.properties.get("prompt_id", None)
+         # Get the prompt builder id. Need the second check because we used to store this in a prompt_builder_name field, so loading legacy runs will need this.
+         prompt_id = run.output.source.properties.get(
+             "prompt_id"
+         ) or run.output.source.properties.get("prompt_builder_name", None)
          if prompt_id is not None and isinstance(prompt_id, str):
-             static_prompt_builder = SavedPromptBuilder(task, prompt_id)
-             return static_prompt_builder.build_prompt(include_json_instructions=False)
+             prompt_builder = prompt_builder_from_id(prompt_id, task)
+             if isinstance(prompt_builder, BasePromptBuilder):
+                 return prompt_builder.build_prompt(include_json_instructions=False)

-         prompt_builder_class: Type[BasePromptBuilder] | None = None
-         prompt_builder_name = run.output.source.properties.get(
-             "prompt_builder_name", None
-         )
-         if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
-             prompt_builder_class = prompt_builder_registry.get(
-                 prompt_builder_name, None
-             )
-         if prompt_builder_class is None:
-             raise ValueError(f"No prompt builder found for name: {prompt_builder_name}")
-         prompt_builder = prompt_builder_class(task=task)
-         if not isinstance(prompt_builder, BasePromptBuilder):
-             raise ValueError(
-                 f"Prompt builder {prompt_builder_name} is not a valid prompt builder"
-             )
-         return prompt_builder.build_prompt(include_json_instructions=False)
+         raise ValueError(f"Prompt builder '{prompt_id}' is not a valid prompt builder")

      @classmethod
      def build_repair_task_input(
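The repair-task change above reads the new prompt_id key and falls back to the legacy prompt_builder_name key so runs saved by older versions still resolve a prompt. A tiny generic sketch of that fallback using plain dicts (prompt_id_from_properties is an illustrative helper, not a kiln-ai function):

```python
def prompt_id_from_properties(properties: dict) -> str:
    # Prefer the new key; fall back to the legacy key written by pre-0.12 releases.
    prompt_id = properties.get("prompt_id") or properties.get("prompt_builder_name")
    if not isinstance(prompt_id, str):
        raise ValueError("No prompt id found in source properties")
    return prompt_id


assert prompt_id_from_properties({"prompt_id": "simple_prompt_builder"}) == "simple_prompt_builder"
assert prompt_id_from_properties({"prompt_builder_name": "repairs_prompt_builder"}) == "repairs_prompt_builder"
```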
@@ -95,7 +95,7 @@ def sample_task_run(sample_task):
                  "model_name": "gpt_4o",
                  "model_provider": "openai",
                  "adapter_name": "langchain_adapter",
-                 "prompt_builder_name": "simple_prompt_builder",
+                 "prompt_id": "simple_prompt_builder",
              },
          ),
      ),
@@ -201,7 +201,7 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
          "adapter_name": "kiln_langchain_adapter",
          "model_name": "llama_3_1_8b",
          "model_provider": "groq",
-         "prompt_builder_name": "simple_prompt_builder",
+         "prompt_id": "simple_prompt_builder",
      }


@@ -238,7 +238,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
          "adapter_name": "kiln_langchain_adapter",
          "model_name": "llama_3_1_8b",
          "model_provider": "ollama",
-         "prompt_builder_name": "simple_prompt_builder",
+         "prompt_id": "simple_prompt_builder",
      }
      assert run.input_source.type == DataSourceType.human
      assert "created_by" in run.input_source.properties
@@ -1,8 +1,11 @@
  from dataclasses import dataclass
  from typing import Dict

+ from openai.types.chat.chat_completion import ChoiceLogprobs
+

  @dataclass
  class RunOutput:
      output: Dict | str
      intermediate_outputs: Dict[str, str] | None
+     output_logprobs: ChoiceLogprobs | None = None
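RunOutput gains an optional output_logprobs field defaulting to None, so existing RunOutput(...) call sites keep working while adapters that have logprobs can attach them. A stand-in sketch of why the default preserves old constructors (plain dataclass, Any instead of OpenAI's ChoiceLogprobs type):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class RunOutputSketch:
    output: dict | str
    intermediate_outputs: dict[str, str] | None
    # New trailing field with a default: every existing two-argument call
    # stays valid, and callers that do have logprobs can pass them explicitly.
    output_logprobs: Any | None = None


old_style = RunOutputSketch(output="hello", intermediate_outputs=None)
with_logprobs = RunOutputSketch(
    output="hello", intermediate_outputs=None, output_logprobs={"content": []}
)
assert old_style.output_logprobs is None
```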
@@ -5,6 +5,7 @@ import pytest
  from kiln_ai import datamodel
  from kiln_ai.adapters.adapter_registry import adapter_for_task
  from kiln_ai.adapters.ml_model_list import ModelProviderName
+ from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
  from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
  from kiln_ai.adapters.model_adapters.openai_model_adapter import OpenAICompatibleAdapter
  from kiln_ai.adapters.prompt_builders import BasePromptBuilder
@@ -84,24 +85,19 @@ def test_langchain_adapter_creation(mock_config, basic_task, provider):
      )

      assert isinstance(adapter, LangchainAdapter)
-     assert adapter.model_name == "test-model"
+     assert adapter.run_config.model_name == "test-model"


  # TODO should run for all cases
  def test_custom_prompt_builder(mock_config, basic_task):
-     class TestPromptBuilder(BasePromptBuilder):
-         def build_base_prompt(self, kiln_task) -> str:
-             return "test-prompt"
-
-     prompt_builder = TestPromptBuilder(basic_task)
      adapter = adapter_for_task(
          kiln_task=basic_task,
          model_name="gpt-4",
          provider=ModelProviderName.openai,
-         prompt_builder=prompt_builder,
+         prompt_id="simple_chain_of_thought_prompt_builder",
      )

-     assert adapter.prompt_builder == prompt_builder
+     assert adapter.run_config.prompt_id == "simple_chain_of_thought_prompt_builder"


  # TODO should run for all cases
@@ -111,10 +107,12 @@ def test_tags_passed_through(mock_config, basic_task):
          kiln_task=basic_task,
          model_name="gpt-4",
          provider=ModelProviderName.openai,
-         tags=tags,
+         base_adapter_config=AdapterConfig(
+             default_tags=tags,
+         ),
      )

-     assert adapter.default_tags == tags
+     assert adapter.base_adapter_config.default_tags == tags


  def test_invalid_provider(mock_config, basic_task):
@@ -129,6 +127,7 @@ def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_ta
      mock_compatible_config.return_value.model_name = "test-model"
      mock_compatible_config.return_value.api_key = "test-key"
      mock_compatible_config.return_value.base_url = "https://test.com/v1"
+     mock_compatible_config.return_value.provider_name = "CustomProvider99"

      adapter = adapter_for_task(
          kiln_task=basic_task,
@@ -141,6 +140,7 @@ def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_ta
      assert adapter.config.model_name == "test-model"
      assert adapter.config.api_key == "test-key"
      assert adapter.config.base_url == "https://test.com/v1"
+     assert adapter.config.provider_name == "CustomProvider99"


  def test_custom_openai_compatible_provider(mock_config, basic_task):
@@ -1,13 +1,13 @@
+ import logging
  from typing import List

  import pytest

- from libs.core.kiln_ai.adapters.ml_model_list import (
-     KilnModelProvider,
-     built_in_models,
- )
+ from libs.core.kiln_ai.adapters.ml_model_list import KilnModelProvider, built_in_models
  from libs.core.kiln_ai.adapters.provider_tools import provider_name_from_id

+ logger = logging.getLogger(__name__)
+

  def _all_providers_support(providers: List[KilnModelProvider], attribute: str) -> bool:
      """Check if all providers support a given feature"""
@@ -58,8 +58,8 @@ def test_generate_model_table():
          table.append(row)

      # Print the table (useful for documentation)
-     print("\nModel Capability Matrix:\n")
-     print("\n".join(table))
+     logger.info("\nModel Capability Matrix:\n")
+     logger.info("\n".join(table))

      # Basic assertions to ensure the table is well-formed
      assert len(table) > 2, "Table should have header and at least one row"
@@ -10,7 +10,6 @@ from kiln_ai.adapters.ollama_tools import (
  def test_parse_ollama_tags_no_models():
      json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"},{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
      tags = json.loads(json_response)
-     print(json.dumps(tags, indent=2))
      conn = parse_ollama_tags(tags)
      assert "phi3.5:latest" in conn.supported_models
      assert "gemma2:2b" in conn.supported_models
@@ -13,6 +13,7 @@ from kiln_ai.adapters.prompt_builders import (
      BasePromptBuilder,
      SimpleChainOfThoughtPromptBuilder,
  )
+ from kiln_ai.datamodel import PromptId


  def get_all_models_and_providers():
@@ -132,7 +133,7 @@ async def test_mock_returning_run(tmp_path):
          "adapter_name": "kiln_langchain_adapter",
          "model_name": "custom.langchain:unknown_model",
          "model_provider": "ollama",
-         "prompt_builder_name": "simple_prompt_builder",
+         "prompt_id": "simple_prompt_builder",
      }


@@ -149,8 +150,9 @@ async def test_all_models_providers_plaintext(tmp_path, model_name, provider_nam
  @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
  async def test_cot_prompt_builder(tmp_path, model_name, provider_name):
      task = build_test_task(tmp_path)
-     pb = SimpleChainOfThoughtPromptBuilder(task)
-     await run_simple_task(task, model_name, provider_name, pb)
+     await run_simple_task(
+         task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+     )


  def build_test_task(tmp_path: Path):
@@ -186,20 +188,20 @@ async def run_simple_test(
      tmp_path: Path,
      model_name: str,
      provider: str | None = None,
-     prompt_builder: BasePromptBuilder | None = None,
+     prompt_id: PromptId | None = None,
  ):
      task = build_test_task(tmp_path)
-     return await run_simple_task(task, model_name, provider, prompt_builder)
+     return await run_simple_task(task, model_name, provider, prompt_id)


  async def run_simple_task(
      task: datamodel.Task,
      model_name: str,
      provider: str,
-     prompt_builder: BasePromptBuilder | None = None,
+     prompt_id: PromptId | None = None,
  ) -> datamodel.TaskRun:
      adapter = adapter_for_task(
-         task, model_name=model_name, provider=provider, prompt_builder=prompt_builder
+         task, model_name=model_name, provider=provider, prompt_id=prompt_id
      )

      run = await adapter.invoke(
@@ -212,13 +214,14 @@ async def run_simple_task(
      )
      assert "64" in run.output.output
      source_props = run.output.source.properties
-     assert source_props["adapter_name"] == "kiln_langchain_adapter"
+     assert source_props["adapter_name"] in [
+         "kiln_langchain_adapter",
+         "kiln_openai_compatible_adapter",
+     ]
      assert source_props["model_name"] == model_name
      assert source_props["model_provider"] == provider
-     expected_prompt_builder_name = (
-         prompt_builder.__class__.prompt_builder_name()
-         if prompt_builder
-         else "simple_prompt_builder"
-     )
-     assert source_props["prompt_builder_name"] == expected_prompt_builder_name
+     if prompt_id is None:
+         assert source_props["prompt_id"] == "simple_prompt_builder"
+     else:
+         assert source_props["prompt_id"] == prompt_id
      return run