kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (88)
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +81 -10
  3. kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/eval/base_eval.py +164 -0
  6. kiln_ai/adapters/eval/eval_runner.py +267 -0
  7. kiln_ai/adapters/eval/g_eval.py +367 -0
  8. kiln_ai/adapters/eval/registry.py +16 -0
  9. kiln_ai/adapters/eval/test_base_eval.py +324 -0
  10. kiln_ai/adapters/eval/test_eval_runner.py +640 -0
  11. kiln_ai/adapters/eval/test_g_eval.py +497 -0
  12. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  15. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  16. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  17. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  18. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
  19. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
  20. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  21. kiln_ai/adapters/ml_model_list.py +434 -93
  22. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  23. kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
  24. kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
  25. kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
  26. kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
  27. kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
  28. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
  29. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
  30. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
  31. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
  32. kiln_ai/adapters/ollama_tools.py +0 -1
  33. kiln_ai/adapters/parsers/__init__.py +10 -0
  34. kiln_ai/adapters/parsers/base_parser.py +12 -0
  35. kiln_ai/adapters/parsers/json_parser.py +37 -0
  36. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  37. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  38. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  39. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  40. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  41. kiln_ai/adapters/prompt_builders.py +193 -49
  42. kiln_ai/adapters/provider_tools.py +91 -36
  43. kiln_ai/adapters/repair/repair_task.py +18 -19
  44. kiln_ai/adapters/repair/test_repair_task.py +7 -7
  45. kiln_ai/adapters/run_output.py +11 -0
  46. kiln_ai/adapters/test_adapter_registry.py +177 -0
  47. kiln_ai/adapters/test_generate_docs.py +69 -0
  48. kiln_ai/adapters/test_ollama_tools.py +0 -1
  49. kiln_ai/adapters/test_prompt_adaptors.py +25 -18
  50. kiln_ai/adapters/test_prompt_builders.py +265 -44
  51. kiln_ai/adapters/test_provider_tools.py +268 -46
  52. kiln_ai/datamodel/__init__.py +51 -772
  53. kiln_ai/datamodel/basemodel.py +31 -11
  54. kiln_ai/datamodel/datamodel_enums.py +58 -0
  55. kiln_ai/datamodel/dataset_filters.py +114 -0
  56. kiln_ai/datamodel/dataset_split.py +170 -0
  57. kiln_ai/datamodel/eval.py +298 -0
  58. kiln_ai/datamodel/finetune.py +105 -0
  59. kiln_ai/datamodel/json_schema.py +14 -3
  60. kiln_ai/datamodel/model_cache.py +8 -3
  61. kiln_ai/datamodel/project.py +23 -0
  62. kiln_ai/datamodel/prompt.py +37 -0
  63. kiln_ai/datamodel/prompt_id.py +83 -0
  64. kiln_ai/datamodel/strict_mode.py +24 -0
  65. kiln_ai/datamodel/task.py +181 -0
  66. kiln_ai/datamodel/task_output.py +321 -0
  67. kiln_ai/datamodel/task_run.py +164 -0
  68. kiln_ai/datamodel/test_basemodel.py +80 -2
  69. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  70. kiln_ai/datamodel/test_dataset_split.py +127 -6
  71. kiln_ai/datamodel/test_datasource.py +3 -2
  72. kiln_ai/datamodel/test_eval_model.py +635 -0
  73. kiln_ai/datamodel/test_example_models.py +34 -17
  74. kiln_ai/datamodel/test_json_schema.py +23 -0
  75. kiln_ai/datamodel/test_model_cache.py +24 -0
  76. kiln_ai/datamodel/test_model_perf.py +125 -0
  77. kiln_ai/datamodel/test_models.py +131 -2
  78. kiln_ai/datamodel/test_prompt_id.py +129 -0
  79. kiln_ai/datamodel/test_task.py +159 -0
  80. kiln_ai/utils/config.py +6 -1
  81. kiln_ai/utils/exhaustive_error.py +6 -0
  82. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
  83. kiln_ai-0.12.0.dist-info/RECORD +100 -0
  84. kiln_ai/adapters/base_adapter.py +0 -191
  85. kiln_ai/adapters/langchain_adapters.py +0 -256
  86. kiln_ai-0.8.1.dist-info/RECORD +0 -58
  87. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
  88. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/provider_tools.py

@@ -1,20 +1,24 @@
 from dataclasses import dataclass
-from typing import Dict, List, NoReturn
+from typing import Dict, List

 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,
     ModelName,
     ModelProviderName,
+    StructuredOutputMode,
     built_in_models,
 )
+from kiln_ai.adapters.model_adapters.openai_compatible_config import (
+    OpenAICompatibleConfig,
+)
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
 from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.registry import project_from_id
-
-from ..utils.config import Config
+from kiln_ai.utils.config import Config
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error


 async def provider_enabled(provider_name: ModelProviderName) -> bool:
@@ -61,7 +65,7 @@ def check_provider_warnings(provider_name: ModelProviderName):
         raise ValueError(warning_check.message)


-async def builtin_model_from(
+def builtin_model_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider | None:
     """
@@ -102,7 +106,47 @@ async def builtin_model_from(
     return provider


-async def kiln_model_provider_from(
+def core_provider(model_id: str, provider_name: ModelProviderName) -> ModelProviderName:
+    """
+    Get the provider that should be run.
+
+    Some provider IDs are wrappers (fine-tunes, custom models). This maps these to runnable providers (openai, ollama, etc)
+    """
+
+    # Custom models map to the underlying provider
+    if provider_name is ModelProviderName.kiln_custom_registry:
+        provider_name, _ = parse_custom_model_id(model_id)
+        return provider_name
+
+    # Fine-tune provider maps to an underlying provider
+    if provider_name is ModelProviderName.kiln_fine_tune:
+        finetune = finetune_from_id(model_id)
+        if finetune.provider not in ModelProviderName.__members__:
+            raise ValueError(
+                f"Finetune {model_id} has no underlying provider {finetune.provider}"
+            )
+        return ModelProviderName(finetune.provider)
+
+    return provider_name
+
+
+def parse_custom_model_id(
+    model_id: str,
+) -> tuple[ModelProviderName, str]:
+    if "::" not in model_id:
+        raise ValueError(f"Invalid custom model ID: {model_id}")
+
+    # For custom registry, get the provider name and model name from the model id
+    provider_name = model_id.split("::", 1)[0]
+    model_name = model_id.split("::", 1)[1]
+
+    if provider_name not in ModelProviderName.__members__:
+        raise ValueError(f"Invalid provider name: {provider_name}")
+
+    return ModelProviderName(provider_name), model_name
+
+
+def kiln_model_provider_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider:
     if provider_name == ModelProviderName.kiln_fine_tune:
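parse_custom_model_id splits on the first "::" only, so the model portion may itself contain colons (Ollama tags such as llama3.1:8b). A minimal standalone sketch of the convention, with a hypothetical two-member enum standing in for the full ModelProviderName:

from enum import Enum


class ProviderName(str, Enum):
    # Hypothetical subset of Kiln's ModelProviderName, for illustration only.
    openai = "openai"
    ollama = "ollama"


def parse_id(model_id: str) -> tuple[ProviderName, str]:
    # Split on the first "::" only; the model name may contain more colons.
    if "::" not in model_id:
        raise ValueError(f"Invalid custom model ID: {model_id}")
    provider_name, model_name = model_id.split("::", 1)
    if provider_name not in ProviderName.__members__:
        raise ValueError(f"Invalid provider name: {provider_name}")
    return ProviderName(provider_name), model_name


print(parse_id("ollama::llama3.1:8b"))  # (ProviderName.ollama, 'llama3.1:8b')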
@@ -111,14 +155,13 @@ async def kiln_model_provider_from(
     if provider_name == ModelProviderName.openai_compatible:
         return openai_compatible_provider_model(name)

-    built_in_model = await builtin_model_from(name, provider_name)
+    built_in_model = builtin_model_from(name, provider_name)
     if built_in_model:
         return built_in_model

     # For custom registry, get the provider name and model name from the model id
     if provider_name == ModelProviderName.kiln_custom_registry:
-        provider_name = name.split("::", 1)[0]
-        name = name.split("::", 1)[1]
+        provider_name, name = parse_custom_model_id(name)

     # Custom/untested model. Set untested, and build a ModelProvider at runtime
     if provider_name is None:
@@ -136,12 +179,9 @@ async def kiln_model_provider_from(
     )


-finetune_cache: dict[str, KilnModelProvider] = {}
-
-
-def openai_compatible_provider_model(
+def openai_compatible_config(
     model_id: str,
-) -> KilnModelProvider:
+) -> OpenAICompatibleConfig:
     try:
         openai_provider_name, model_id = model_id.split("::")
     except Exception:
@@ -165,12 +205,21 @@ def openai_compatible_provider_model(
             f"OpenAI compatible provider {openai_provider_name} has no base URL"
         )

+    return OpenAICompatibleConfig(
+        api_key=api_key,
+        model_name=model_id,
+        provider_name=ModelProviderName.openai_compatible,
+        base_url=base_url,
+    )
+
+
+def openai_compatible_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
     return KilnModelProvider(
         name=ModelProviderName.openai_compatible,
         provider_options={
             "model": model_id,
-            "api_key": api_key,
-            "openai_api_base": base_url,
         },
         supports_structured_output=False,
         supports_data_gen=False,
@@ -178,9 +227,10 @@ def openai_compatible_provider_model(
     )


-def finetune_provider_model(
-    model_id: str,
-) -> KilnModelProvider:
+finetune_cache: dict[str, Finetune] = {}
+
+
+def finetune_from_id(model_id: str) -> Finetune:
     if model_id in finetune_cache:
         return finetune_cache[model_id]

@@ -202,6 +252,15 @@ def finetune_provider_model(
             f"Fine tune {fine_tune_id} not completed. Refresh it's status in the fine-tune tab."
         )

+    finetune_cache[model_id] = fine_tune
+    return fine_tune
+
+
+def finetune_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
+    fine_tune = finetune_from_id(model_id)
+
     provider = ModelProviderName[fine_tune.provider]
     model_provider = KilnModelProvider(
         name=provider,
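Note that finetune_cache now memoizes Finetune records keyed by model ID, rather than built KilnModelProvider objects, so core_provider and finetune_provider_model share one cached lookup. A generic sketch of the same memo-cache pattern (names here are illustrative, not Kiln's API):

from typing import Callable, Dict

_cache: Dict[str, str] = {}


def cached_lookup(model_id: str, loader: Callable[[str], str]) -> str:
    # Load once per model ID; later calls return the cached record.
    if model_id in _cache:
        return _cache[model_id]
    value = loader(model_id)
    _cache[model_id] = value
    return value


print(cached_lookup("proj::task::tune", lambda mid: f"loaded:{mid}"))  # loads
print(cached_lookup("proj::task::tune", lambda mid: f"loaded:{mid}"))  # cached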
@@ -210,18 +269,18 @@ def finetune_provider_model(
         },
     )

-    # TODO: Don't love this abstraction/logic.
-    if fine_tune.provider == ModelProviderName.fireworks_ai:
-        # Fireworks finetunes are trained with json, not tool calling (which is LC default format)
-        model_provider.adapter_options = {
-            "langchain": {
-                "with_structured_output_options": {
-                    "method": "json_mode",
-                }
-            }
-        }
-
-    finetune_cache[model_id] = model_provider
+    if fine_tune.structured_output_mode is not None:
+        # If we know the model was trained with specific output mode, set it
+        model_provider.structured_output_mode = fine_tune.structured_output_mode
+    else:
+        # Some early adopters won't have structured_output_mode set on their fine-tunes.
+        # We know that OpenAI uses json_schema, and Fireworks (only other provider) use json_mode.
+        # This can be removed in the future
+        if provider == ModelProviderName.openai:
+            model_provider.structured_output_mode = StructuredOutputMode.json_schema
+        else:
+            model_provider.structured_output_mode = StructuredOutputMode.json_mode
+
     return model_provider


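This replaces the old LangChain-specific adapter_options block with a first-class structured_output_mode, and encodes a migration: fine-tunes that never recorded a mode get a provider-based default. A standalone sketch of that decision, assuming only the two StructuredOutputMode values used in the diff:

from enum import Enum
from typing import Optional


class StructuredOutputMode(str, Enum):
    json_schema = "json_schema"
    json_mode = "json_mode"


def output_mode_for_finetune(
    provider: str, recorded_mode: Optional[StructuredOutputMode]
) -> StructuredOutputMode:
    # Prefer the mode recorded on the fine-tune at training time.
    if recorded_mode is not None:
        return recorded_mode
    # Legacy records: OpenAI tunes were trained with json_schema;
    # Fireworks (the only other provider at the time) used json_mode.
    if provider == "openai":
        return StructuredOutputMode.json_schema
    return StructuredOutputMode.json_mode


assert output_mode_for_finetune("openai", None) is StructuredOutputMode.json_schema
assert output_mode_for_finetune("fireworks_ai", None) is StructuredOutputMode.json_mode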
@@ -274,7 +333,7 @@ def provider_name_from_id(id: str) -> str:
             return "OpenAI Compatible"
         case _:
             # triggers pyright warning if I miss a case
-            raise_exhaustive_error(enum_id)
+            raise_exhaustive_enum_error(enum_id)

     return "Unknown provider: " + id

@@ -316,16 +375,12 @@ def provider_options_for_custom_model(
             )
         case _:
             # triggers pyright warning if I miss a case
-            raise_exhaustive_error(enum_id)
+            raise_exhaustive_enum_error(enum_id)

     # Won't reach this, type checking will catch missed values
     return {"model": model_name}


-def raise_exhaustive_error(value: NoReturn) -> NoReturn:
-    raise ValueError(f"Unhandled enum value: {value}")
-
-
 @dataclass
 class ModelProviderWarning:
     required_config_keys: List[str]
kiln_ai/adapters/repair/repair_task.py

@@ -3,7 +3,11 @@ from typing import Type

 from pydantic import BaseModel, Field

-from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_registry
+from kiln_ai.adapters.prompt_builders import (
+    BasePromptBuilder,
+    SavedPromptBuilder,
+    prompt_builder_from_id,
+)
 from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun


@@ -42,24 +46,19 @@ feedback describing what should be improved. Your job is to understand the evalu

     @classmethod
     def _original_prompt(cls, run: TaskRun, task: Task) -> str:
-        prompt_builder_class: Type[BasePromptBuilder] | None = None
-        prompt_builder_name = (
-            run.output.source.properties.get("prompt_builder_name", None)
-            if run.output.source
-            else None
-        )
-        if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
-            prompt_builder_class = prompt_builder_registry.get(
-                prompt_builder_name, None
-            )
-        if prompt_builder_class is None:
-            raise ValueError(f"No prompt builder found for name: {prompt_builder_name}")
-        prompt_builder = prompt_builder_class(task=task)
-        if not isinstance(prompt_builder, BasePromptBuilder):
-            raise ValueError(
-                f"Prompt builder {prompt_builder_name} is not a valid prompt builder"
-            )
-        return prompt_builder.build_prompt()
+        if run.output.source is None or run.output.source.properties is None:
+            raise ValueError("No source properties found")
+
+        # Get the prompt builder id. Need the second check because we used to store this in a prompt_builder_name field, so loading legacy runs will need this.
+        prompt_id = run.output.source.properties.get(
+            "prompt_id"
+        ) or run.output.source.properties.get("prompt_builder_name", None)
+        if prompt_id is not None and isinstance(prompt_id, str):
+            prompt_builder = prompt_builder_from_id(prompt_id, task)
+            if isinstance(prompt_builder, BasePromptBuilder):
+                return prompt_builder.build_prompt(include_json_instructions=False)
+
+        raise ValueError(f"Prompt builder '{prompt_id}' is not a valid prompt builder")

     @classmethod
     def build_repair_task_input(
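The rename from prompt_builder_name to prompt_id is bridged by a two-key read, so legacy runs still resolve. The same pattern in isolation, with hypothetical properties dicts:

# dict.get plus `or` prefers the new key and falls back to the legacy one
# (assuming neither key ever stores an empty string).
legacy_run_properties = {"prompt_builder_name": "simple_prompt_builder"}
new_run_properties = {"prompt_id": "simple_prompt_builder"}

for properties in (legacy_run_properties, new_run_properties):
    prompt_id = properties.get("prompt_id") or properties.get("prompt_builder_name")
    assert prompt_id == "simple_prompt_builder"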
kiln_ai/adapters/repair/test_repair_task.py

@@ -6,8 +6,8 @@ import pytest
 from pydantic import ValidationError

 from kiln_ai.adapters.adapter_registry import adapter_for_task
-from kiln_ai.adapters.base_adapter import RunOutput
-from kiln_ai.adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.model_adapters.base_adapter import RunOutput
+from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
 from kiln_ai.adapters.repair.repair_task import (
     RepairTaskInput,
     RepairTaskRun,
@@ -95,7 +95,7 @@ def sample_task_run(sample_task):
                 "model_name": "gpt_4o",
                 "model_provider": "openai",
                 "adapter_name": "langchain_adapter",
-                "prompt_builder_name": "simple_prompt_builder",
+                "prompt_id": "simple_prompt_builder",
             },
         ),
     ),
@@ -201,7 +201,7 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
         "adapter_name": "kiln_langchain_adapter",
         "model_name": "llama_3_1_8b",
         "model_provider": "groq",
-        "prompt_builder_name": "simple_prompt_builder",
+        "prompt_id": "simple_prompt_builder",
     }


@@ -223,7 +223,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     )

     adapter = adapter_for_task(
-        repair_task, model_name="llama_3_1_8b", provider="groq"
+        repair_task, model_name="llama_3_1_8b", provider="ollama"
     )

     run = await adapter.invoke(repair_task_input.model_dump())
@@ -237,8 +237,8 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     assert run.output.source.properties == {
         "adapter_name": "kiln_langchain_adapter",
         "model_name": "llama_3_1_8b",
-        "model_provider": "groq",
-        "prompt_builder_name": "simple_prompt_builder",
+        "model_provider": "ollama",
+        "prompt_id": "simple_prompt_builder",
     }
     assert run.input_source.type == DataSourceType.human
     assert "created_by" in run.input_source.properties
kiln_ai/adapters/run_output.py (new file)

@@ -0,0 +1,11 @@
+from dataclasses import dataclass
+from typing import Dict
+
+from openai.types.chat.chat_completion import ChoiceLogprobs
+
+
+@dataclass
+class RunOutput:
+    output: Dict | str
+    intermediate_outputs: Dict[str, str] | None
+    output_logprobs: ChoiceLogprobs | None = None
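RunOutput is a plain dataclass, so constructing one is direct. A hypothetical example (the field values are invented; the import path assumes the new kiln_ai/adapters/run_output.py module shown above):

from kiln_ai.adapters.run_output import RunOutput

run_output = RunOutput(
    output={"answer": 42},                      # parsed structured output, or a raw str
    intermediate_outputs={"reasoning": "..."},  # optional intermediate text; may be None
)
assert run_output.output_logprobs is None       # defaults to None when logprobs aren't requested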
kiln_ai/adapters/test_adapter_registry.py (new file)

@@ -0,0 +1,177 @@
+from unittest.mock import patch
+
+import pytest
+
+from kiln_ai import datamodel
+from kiln_ai.adapters.adapter_registry import adapter_for_task
+from kiln_ai.adapters.ml_model_list import ModelProviderName
+from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
+from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.model_adapters.openai_model_adapter import OpenAICompatibleAdapter
+from kiln_ai.adapters.prompt_builders import BasePromptBuilder
+from kiln_ai.adapters.provider_tools import kiln_model_provider_from
+
+
+@pytest.fixture
+def mock_config():
+    with patch("kiln_ai.adapters.adapter_registry.Config") as mock:
+        mock.shared.return_value.open_ai_api_key = "test-openai-key"
+        mock.shared.return_value.open_router_api_key = "test-openrouter-key"
+        yield mock
+
+
+@pytest.fixture
+def basic_task():
+    return datamodel.Task(
+        task_id="test-task",
+        task_type="test",
+        input_text="test input",
+        name="test-task",
+        instruction="test-task",
+    )
+
+
+@pytest.fixture
+def mock_finetune_from_id():
+    with patch("kiln_ai.adapters.provider_tools.finetune_from_id") as mock:
+        mock.return_value.provider = ModelProviderName.openai
+        mock.return_value.fine_tune_model_id = "test-model"
+        yield mock
+
+
+def test_openai_adapter_creation(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task, model_name="gpt-4", provider=ModelProviderName.openai
+    )
+
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    assert adapter.config.model_name == "gpt-4"
+    assert adapter.config.api_key == "test-openai-key"
+    assert adapter.config.provider_name == ModelProviderName.openai
+    assert adapter.config.base_url is None  # OpenAI url is default
+    assert adapter.config.default_headers is None
+
+
+def test_openrouter_adapter_creation(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="anthropic/claude-3-opus",
+        provider=ModelProviderName.openrouter,
+    )
+
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    assert adapter.config.model_name == "anthropic/claude-3-opus"
+    assert adapter.config.api_key == "test-openrouter-key"
+    assert adapter.config.provider_name == ModelProviderName.openrouter
+    assert adapter.config.base_url == "https://openrouter.ai/api/v1"
+    assert adapter.config.default_headers == {
+        "HTTP-Referer": "https://getkiln.ai/openrouter",
+        "X-Title": "KilnAI",
+    }
+
+
+@pytest.mark.parametrize(
+    "provider",
+    [
+        ModelProviderName.groq,
+        ModelProviderName.amazon_bedrock,
+        ModelProviderName.ollama,
+        ModelProviderName.fireworks_ai,
+    ],
+)
+def test_langchain_adapter_creation(mock_config, basic_task, provider):
+    adapter = adapter_for_task(
+        kiln_task=basic_task, model_name="test-model", provider=provider
+    )
+
+    assert isinstance(adapter, LangchainAdapter)
+    assert adapter.run_config.model_name == "test-model"
+
+
+# TODO should run for all cases
+def test_custom_prompt_builder(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="gpt-4",
+        provider=ModelProviderName.openai,
+        prompt_id="simple_chain_of_thought_prompt_builder",
+    )
+
+    assert adapter.run_config.prompt_id == "simple_chain_of_thought_prompt_builder"
+
+
+# TODO should run for all cases
+def test_tags_passed_through(mock_config, basic_task):
+    tags = ["test-tag-1", "test-tag-2"]
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="gpt-4",
+        provider=ModelProviderName.openai,
+        base_adapter_config=AdapterConfig(
+            default_tags=tags,
+        ),
+    )
+
+    assert adapter.base_adapter_config.default_tags == tags
+
+
+def test_invalid_provider(mock_config, basic_task):
+    with pytest.raises(ValueError, match="Unhandled enum value"):
+        adapter_for_task(
+            kiln_task=basic_task, model_name="test-model", provider="invalid"
+        )
+
+
+@patch("kiln_ai.adapters.adapter_registry.openai_compatible_config")
+def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_task):
+    mock_compatible_config.return_value.model_name = "test-model"
+    mock_compatible_config.return_value.api_key = "test-key"
+    mock_compatible_config.return_value.base_url = "https://test.com/v1"
+    mock_compatible_config.return_value.provider_name = "CustomProvider99"
+
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="provider::test-model",
+        provider=ModelProviderName.openai_compatible,
+    )
+
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    mock_compatible_config.assert_called_once_with("provider::test-model")
+    assert adapter.config.model_name == "test-model"
+    assert adapter.config.api_key == "test-key"
+    assert adapter.config.base_url == "https://test.com/v1"
+    assert adapter.config.provider_name == "CustomProvider99"
+
+
+def test_custom_openai_compatible_provider(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="openai::test-model",
+        provider=ModelProviderName.kiln_custom_registry,
+    )
+
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    assert adapter.config.model_name == "openai::test-model"
+    assert adapter.config.api_key == "test-openai-key"
+    assert adapter.config.base_url is None  # openai is none
+    assert adapter.config.provider_name == ModelProviderName.kiln_custom_registry
+
+
+async def test_fine_tune_provider(mock_config, basic_task, mock_finetune_from_id):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="proj::task::tune",
+        provider=ModelProviderName.kiln_fine_tune,
+    )
+
+    mock_finetune_from_id.assert_called_once_with("proj::task::tune")
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    assert adapter.config.provider_name == ModelProviderName.kiln_fine_tune
+    # Kiln model name here, but the underlying openai model id below
+    assert adapter.config.model_name == "proj::task::tune"
+
+    provider = kiln_model_provider_from(
+        "proj::task::tune", provider_name=ModelProviderName.kiln_fine_tune
+    )
+    # The actual model name from the fine tune object
+    assert provider.provider_options["model"] == "test-model"
kiln_ai/adapters/test_generate_docs.py (new file)

@@ -0,0 +1,69 @@
+import logging
+from typing import List
+
+import pytest
+
+from libs.core.kiln_ai.adapters.ml_model_list import KilnModelProvider, built_in_models
+from libs.core.kiln_ai.adapters.provider_tools import provider_name_from_id
+
+logger = logging.getLogger(__name__)
+
+
+def _all_providers_support(providers: List[KilnModelProvider], attribute: str) -> bool:
+    """Check if all providers support a given feature"""
+    return all(getattr(provider, attribute) for provider in providers)
+
+
+def _any_providers_support(providers: List[KilnModelProvider], attribute: str) -> bool:
+    """Check if any providers support a given feature"""
+    return any(getattr(provider, attribute) for provider in providers)
+
+
+def _get_support_status(providers: List[KilnModelProvider], attribute: str) -> str:
+    """Get the support status for a feature"""
+    if _all_providers_support(providers, attribute):
+        return "✅︎"
+    elif _any_providers_support(providers, attribute):
+        return "✅︎ (some providers)"
+    return ""
+
+
+def _has_finetune_support(providers: List[KilnModelProvider]) -> str:
+    """Check if any provider supports fine-tuning"""
+    return "✅︎" if any(p.provider_finetune_id for p in providers) else ""
+
+
+@pytest.mark.paid(reason="Marking as paid so it isn't run by default")
+def test_generate_model_table():
+    """Generate a markdown table of all models and their capabilities"""
+
+    # Table header
+    table = [
+        "| Model Name | Providers | Structured Output | Reasoning | Synthetic Data | API Fine-Tuneable |",
+        "|------------|-----------|-------------------|-----------|----------------|-------------------|",
+    ]
+
+    for model in built_in_models:
+        provider_names = ", ".join(
+            sorted(provider_name_from_id(p.name.value) for p in model.providers)
+        )
+        structured_output = _get_support_status(
+            model.providers, "supports_structured_output"
+        )
+        reasoning = _get_support_status(model.providers, "reasoning_capable")
+        data_gen = _get_support_status(model.providers, "supports_data_gen")
+        finetune = _has_finetune_support(model.providers)
+
+        row = f"| {model.friendly_name} | {provider_names} | {structured_output} | {reasoning} | {data_gen} | {finetune} |"
+        table.append(row)
+
+    # Print the table (useful for documentation)
+    logger.info("\nModel Capability Matrix:\n")
+    logger.info("\n".join(table))
+
+    # Basic assertions to ensure the table is well-formed
+    assert len(table) > 2, "Table should have header and at least one row"
+    assert all("|" in row for row in table), "All rows should be properly formatted"
+    assert len(table[0].split("|")) == len(table[1].split("|")), (
+        "Header and separator should have same number of columns"
+    )
kiln_ai/adapters/test_ollama_tools.py

@@ -10,7 +10,6 @@ from kiln_ai.adapters.ollama_tools import (
 def test_parse_ollama_tags_no_models():
     json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"},{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
     tags = json.loads(json_response)
-    print(json.dumps(tags, indent=2))
     conn = parse_ollama_tags(tags)
     assert "phi3.5:latest" in conn.supported_models
     assert "gemma2:2b" in conn.supported_models