kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +81 -10
- kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +267 -0
- kiln_ai/adapters/eval/g_eval.py +367 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +324 -0
- kiln_ai/adapters/eval/test_eval_runner.py +640 -0
- kiln_ai/adapters/eval/test_g_eval.py +497 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +434 -93
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
- kiln_ai/adapters/ollama_tools.py +0 -1
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +193 -49
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +18 -19
- kiln_ai/adapters/repair/test_repair_task.py +7 -7
- kiln_ai/adapters/run_output.py +11 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +25 -18
- kiln_ai/adapters/test_prompt_builders.py +265 -44
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +51 -772
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/datamodel_enums.py +58 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +14 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +321 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +80 -2
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +127 -6
- kiln_ai/datamodel/test_datasource.py +3 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +34 -17
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +131 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +6 -1
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
- kiln_ai-0.12.0.dist-info/RECORD +100 -0
- kiln_ai/adapters/base_adapter.py +0 -191
- kiln_ai/adapters/langchain_adapters.py +0 -256
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/provider_tools.py
@@ -1,20 +1,24 @@
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List
 
 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,
     ModelName,
     ModelProviderName,
+    StructuredOutputMode,
     built_in_models,
 )
+from kiln_ai.adapters.model_adapters.openai_compatible_config import (
+    OpenAICompatibleConfig,
+)
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
 from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.registry import project_from_id
-
-from
+from kiln_ai.utils.config import Config
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
 
 async def provider_enabled(provider_name: ModelProviderName) -> bool:
@@ -61,7 +65,7 @@ def check_provider_warnings(provider_name: ModelProviderName):
         raise ValueError(warning_check.message)
 
 
-async def builtin_model_from(
+def builtin_model_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider | None:
     """
@@ -102,7 +106,47 @@ async def builtin_model_from(
     return provider
 
 
-async def kiln_model_provider_from(
+def core_provider(model_id: str, provider_name: ModelProviderName) -> ModelProviderName:
+    """
+    Get the provider that should be run.
+
+    Some provider IDs are wrappers (fine-tunes, custom models). This maps these to runnable providers (openai, ollama, etc)
+    """
+
+    # Custom models map to the underlying provider
+    if provider_name is ModelProviderName.kiln_custom_registry:
+        provider_name, _ = parse_custom_model_id(model_id)
+        return provider_name
+
+    # Fine-tune provider maps to an underlying provider
+    if provider_name is ModelProviderName.kiln_fine_tune:
+        finetune = finetune_from_id(model_id)
+        if finetune.provider not in ModelProviderName.__members__:
+            raise ValueError(
+                f"Finetune {model_id} has no underlying provider {finetune.provider}"
+            )
+        return ModelProviderName(finetune.provider)
+
+    return provider_name
+
+
+def parse_custom_model_id(
+    model_id: str,
+) -> tuple[ModelProviderName, str]:
+    if "::" not in model_id:
+        raise ValueError(f"Invalid custom model ID: {model_id}")
+
+    # For custom registry, get the provider name and model name from the model id
+    provider_name = model_id.split("::", 1)[0]
+    model_name = model_id.split("::", 1)[1]
+
+    if provider_name not in ModelProviderName.__members__:
+        raise ValueError(f"Invalid provider name: {provider_name}")
+
+    return ModelProviderName(provider_name), model_name
+
+
+def kiln_model_provider_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider:
     if provider_name == ModelProviderName.kiln_fine_tune:
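
The new core_provider and parse_custom_model_id helpers above map wrapper provider IDs onto the provider that actually runs. A small usage sketch based on the code in this hunk (the "openai::my-model" ID is illustrative):

from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.provider_tools import core_provider, parse_custom_model_id

# Custom-registry IDs have the form "provider::model"; the provider part must be a
# ModelProviderName member, otherwise a ValueError is raised.
provider, model_name = parse_custom_model_id("openai::my-model")
assert provider == ModelProviderName.openai
assert model_name == "my-model"

# core_provider unwraps the kiln_custom_registry wrapper to the runnable provider.
assert core_provider("openai::my-model", ModelProviderName.kiln_custom_registry) == ModelProviderName.openai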

kiln_ai/adapters/provider_tools.py (continued)
@@ -111,14 +155,13 @@ async def kiln_model_provider_from(
     if provider_name == ModelProviderName.openai_compatible:
         return openai_compatible_provider_model(name)
 
-    built_in_model =
+    built_in_model = builtin_model_from(name, provider_name)
     if built_in_model:
         return built_in_model
 
     # For custom registry, get the provider name and model name from the model id
     if provider_name == ModelProviderName.kiln_custom_registry:
-        provider_name = name
-        name = name.split("::", 1)[1]
+        provider_name, name = parse_custom_model_id(name)
 
     # Custom/untested model. Set untested, and build a ModelProvider at runtime
     if provider_name is None:
@@ -136,12 +179,9 @@ async def kiln_model_provider_from(
     )
 
 
-
-
-
-def openai_compatible_provider_model(
+def openai_compatible_config(
     model_id: str,
-) ->
+) -> OpenAICompatibleConfig:
     try:
         openai_provider_name, model_id = model_id.split("::")
     except Exception:
@@ -165,12 +205,21 @@ def openai_compatible_provider_model(
             f"OpenAI compatible provider {openai_provider_name} has no base URL"
         )
 
+    return OpenAICompatibleConfig(
+        api_key=api_key,
+        model_name=model_id,
+        provider_name=ModelProviderName.openai_compatible,
+        base_url=base_url,
+    )
+
+
+def openai_compatible_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
     return KilnModelProvider(
         name=ModelProviderName.openai_compatible,
         provider_options={
             "model": model_id,
-            "api_key": api_key,
-            "openai_api_base": base_url,
         },
         supports_structured_output=False,
         supports_data_gen=False,
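
With this change, credentials for OpenAI-compatible providers move off KilnModelProvider.provider_options and onto the new OpenAICompatibleConfig returned by openai_compatible_config. A rough sketch of the result, assuming an OpenAI-compatible provider named "my_provider" has already been configured with an API key and base URL (names and values are illustrative):

from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.provider_tools import openai_compatible_config

# The part before "::" selects the configured provider; the part after is the model name.
config = openai_compatible_config("my_provider::some-model")
assert config.model_name == "some-model"
assert config.provider_name == ModelProviderName.openai_compatible
# base_url and api_key are read from the saved provider configuration.
print(config.base_url, config.api_key)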

kiln_ai/adapters/provider_tools.py (continued)
@@ -178,9 +227,10 @@ def openai_compatible_provider_model(
     )
 
 
-
-
-
+finetune_cache: dict[str, Finetune] = {}
+
+
+def finetune_from_id(model_id: str) -> Finetune:
     if model_id in finetune_cache:
         return finetune_cache[model_id]
 
@@ -202,6 +252,15 @@ def finetune_provider_model(
             f"Fine tune {fine_tune_id} not completed. Refresh it's status in the fine-tune tab."
         )
 
+    finetune_cache[model_id] = fine_tune
+    return fine_tune
+
+
+def finetune_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
+    fine_tune = finetune_from_id(model_id)
+
     provider = ModelProviderName[fine_tune.provider]
     model_provider = KilnModelProvider(
         name=provider,
@@ -210,18 +269,18 @@ def finetune_provider_model(
         },
     )
 
-
-
-
-
-
-
-
-
-
-
-
-
+    if fine_tune.structured_output_mode is not None:
+        # If we know the model was trained with specific output mode, set it
+        model_provider.structured_output_mode = fine_tune.structured_output_mode
+    else:
+        # Some early adopters won't have structured_output_mode set on their fine-tunes.
+        # We know that OpenAI uses json_schema, and Fireworks (only other provider) use json_mode.
+        # This can be removed in the future
+        if provider == ModelProviderName.openai:
+            model_provider.structured_output_mode = StructuredOutputMode.json_schema
+        else:
+            model_provider.structured_output_mode = StructuredOutputMode.json_mode
+
     return model_provider
 
 
@@ -274,7 +333,7 @@ def provider_name_from_id(id: str) -> str:
             return "OpenAI Compatible"
         case _:
             # triggers pyright warning if I miss a case
-
+            raise_exhaustive_enum_error(enum_id)
 
     return "Unknown provider: " + id
 
@@ -316,16 +375,12 @@ def provider_options_for_custom_model(
             )
         case _:
             # triggers pyright warning if I miss a case
-
+            raise_exhaustive_enum_error(enum_id)
 
     # Won't reach this, type checking will catch missed values
     return {"model": model_name}
 
 
-def raise_exhaustive_error(value: NoReturn) -> NoReturn:
-    raise ValueError(f"Unhandled enum value: {value}")
-
-
 @dataclass
 class ModelProviderWarning:
     required_config_keys: List[str]
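
Fine-tune model IDs now resolve through the cached finetune_from_id helper, and the resulting provider record carries the structured output mode the model was tuned with, with the OpenAI/Fireworks fallback shown above for older tunes. A sketch of the flow, assuming a completed fine-tune exists under the illustrative ID "proj::task::tune" (the same ID format the new tests use):

from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.provider_tools import (
    core_provider,
    finetune_from_id,
    kiln_model_provider_from,
)

model_id = "proj::task::tune"  # illustrative project::task::fine-tune ID

# First lookup loads the Finetune record; later lookups hit finetune_cache.
finetune = finetune_from_id(model_id)

# The kiln_fine_tune wrapper resolves to whichever provider the tune actually runs on.
runnable = core_provider(model_id, ModelProviderName.kiln_fine_tune)

# The provider record includes the tune's structured_output_mode, or the
# json_schema (OpenAI) / json_mode (Fireworks) fallback for older tunes.
provider = kiln_model_provider_from(model_id, provider_name=ModelProviderName.kiln_fine_tune)
print(runnable, provider.structured_output_mode)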

kiln_ai/adapters/repair/repair_task.py
@@ -3,7 +3,11 @@ from typing import Type
 
 from pydantic import BaseModel, Field
 
-from kiln_ai.adapters.prompt_builders import
+from kiln_ai.adapters.prompt_builders import (
+    BasePromptBuilder,
+    SavedPromptBuilder,
+    prompt_builder_from_id,
+)
 from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun
 
 
@@ -42,24 +46,19 @@ feedback describing what should be improved. Your job is to understand the evalu
 
     @classmethod
     def _original_prompt(cls, run: TaskRun, task: Task) -> str:
-
-
-
-
-
-
-
-
-
-        )
-
-
-
-        if not isinstance(prompt_builder, BasePromptBuilder):
-            raise ValueError(
-                f"Prompt builder {prompt_builder_name} is not a valid prompt builder"
-            )
-        return prompt_builder.build_prompt()
+        if run.output.source is None or run.output.source.properties is None:
+            raise ValueError("No source properties found")
+
+        # Get the prompt builder id. Need the second check because we used to store this in a prompt_builder_name field, so loading legacy runs will need this.
+        prompt_id = run.output.source.properties.get(
+            "prompt_id"
+        ) or run.output.source.properties.get("prompt_builder_name", None)
+        if prompt_id is not None and isinstance(prompt_id, str):
+            prompt_builder = prompt_builder_from_id(prompt_id, task)
+            if isinstance(prompt_builder, BasePromptBuilder):
+                return prompt_builder.build_prompt(include_json_instructions=False)
+
+        raise ValueError(f"Prompt builder '{prompt_id}' is not a valid prompt builder")
 
     @classmethod
     def build_repair_task_input(
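
The repair task now reads the prompt ID from the run's output source, falling back to the legacy prompt_builder_name key so runs saved by older versions still load. A small standalone sketch of that lookup (plain dictionaries with illustrative values, not the actual datamodel objects):

# Properties as saved by a current run vs. a legacy run (illustrative values).
current_props = {"adapter_name": "kiln_langchain_adapter", "prompt_id": "simple_prompt_builder"}
legacy_props = {"adapter_name": "kiln_langchain_adapter", "prompt_builder_name": "simple_prompt_builder"}


def prompt_id_from(properties: dict) -> str | None:
    # Same precedence as the code above: prefer prompt_id, fall back to the legacy key.
    return properties.get("prompt_id") or properties.get("prompt_builder_name", None)


assert prompt_id_from(current_props) == "simple_prompt_builder"
assert prompt_id_from(legacy_props) == "simple_prompt_builder"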

kiln_ai/adapters/repair/test_repair_task.py
@@ -6,8 +6,8 @@ import pytest
 from pydantic import ValidationError
 
 from kiln_ai.adapters.adapter_registry import adapter_for_task
-from kiln_ai.adapters.base_adapter import RunOutput
-from kiln_ai.adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.model_adapters.base_adapter import RunOutput
+from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
 from kiln_ai.adapters.repair.repair_task import (
     RepairTaskInput,
     RepairTaskRun,
@@ -95,7 +95,7 @@ def sample_task_run(sample_task):
                 "model_name": "gpt_4o",
                 "model_provider": "openai",
                 "adapter_name": "langchain_adapter",
-                "
+                "prompt_id": "simple_prompt_builder",
             },
         ),
     ),
@@ -201,7 +201,7 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
         "adapter_name": "kiln_langchain_adapter",
         "model_name": "llama_3_1_8b",
         "model_provider": "groq",
-        "
+        "prompt_id": "simple_prompt_builder",
     }
 
 
@@ -223,7 +223,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     )
 
     adapter = adapter_for_task(
-        repair_task, model_name="llama_3_1_8b", provider="
+        repair_task, model_name="llama_3_1_8b", provider="ollama"
     )
 
     run = await adapter.invoke(repair_task_input.model_dump())
@@ -237,8 +237,8 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     assert run.output.source.properties == {
         "adapter_name": "kiln_langchain_adapter",
         "model_name": "llama_3_1_8b",
-        "model_provider": "
-        "
+        "model_provider": "ollama",
+        "prompt_id": "simple_prompt_builder",
     }
     assert run.input_source.type == DataSourceType.human
     assert "created_by" in run.input_source.properties

kiln_ai/adapters/run_output.py (new file)
@@ -0,0 +1,11 @@
+from dataclasses import dataclass
+from typing import Dict
+
+from openai.types.chat.chat_completion import ChoiceLogprobs
+
+
+@dataclass
+class RunOutput:
+    output: Dict | str
+    intermediate_outputs: Dict[str, str] | None
+    output_logprobs: ChoiceLogprobs | None = None
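
RunOutput is a small dataclass now shared by the model adapters. A minimal construction example (the field values are illustrative):

from kiln_ai.adapters.run_output import RunOutput

# Structured output plus an intermediate output string (illustrative values).
run_output = RunOutput(
    output={"answer": 42},
    intermediate_outputs={"chain_of_thought": "step-by-step reasoning here"},
)
# output_logprobs defaults to None when the provider did not return logprobs.
assert run_output.output_logprobs is None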

kiln_ai/adapters/test_adapter_registry.py (new file)
@@ -0,0 +1,177 @@
+from unittest.mock import patch
+
+import pytest
+
+from kiln_ai import datamodel
+from kiln_ai.adapters.adapter_registry import adapter_for_task
+from kiln_ai.adapters.ml_model_list import ModelProviderName
+from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
+from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.model_adapters.openai_model_adapter import OpenAICompatibleAdapter
+from kiln_ai.adapters.prompt_builders import BasePromptBuilder
+from kiln_ai.adapters.provider_tools import kiln_model_provider_from
+
+
+@pytest.fixture
+def mock_config():
+    with patch("kiln_ai.adapters.adapter_registry.Config") as mock:
+        mock.shared.return_value.open_ai_api_key = "test-openai-key"
+        mock.shared.return_value.open_router_api_key = "test-openrouter-key"
+        yield mock
+
+
+@pytest.fixture
+def basic_task():
+    return datamodel.Task(
+        task_id="test-task",
+        task_type="test",
+        input_text="test input",
+        name="test-task",
+        instruction="test-task",
+    )
+
+
+@pytest.fixture
+def mock_finetune_from_id():
+    with patch("kiln_ai.adapters.provider_tools.finetune_from_id") as mock:
+        mock.return_value.provider = ModelProviderName.openai
+        mock.return_value.fine_tune_model_id = "test-model"
+        yield mock
+
+
+def test_openai_adapter_creation(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task, model_name="gpt-4", provider=ModelProviderName.openai
+    )
+
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    assert adapter.config.model_name == "gpt-4"
+    assert adapter.config.api_key == "test-openai-key"
+    assert adapter.config.provider_name == ModelProviderName.openai
+    assert adapter.config.base_url is None  # OpenAI url is default
+    assert adapter.config.default_headers is None
+
+
+def test_openrouter_adapter_creation(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="anthropic/claude-3-opus",
+        provider=ModelProviderName.openrouter,
+    )
+
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    assert adapter.config.model_name == "anthropic/claude-3-opus"
+    assert adapter.config.api_key == "test-openrouter-key"
+    assert adapter.config.provider_name == ModelProviderName.openrouter
+    assert adapter.config.base_url == "https://openrouter.ai/api/v1"
+    assert adapter.config.default_headers == {
+        "HTTP-Referer": "https://getkiln.ai/openrouter",
+        "X-Title": "KilnAI",
+    }
+
+
+@pytest.mark.parametrize(
+    "provider",
+    [
+        ModelProviderName.groq,
+        ModelProviderName.amazon_bedrock,
+        ModelProviderName.ollama,
+        ModelProviderName.fireworks_ai,
+    ],
+)
+def test_langchain_adapter_creation(mock_config, basic_task, provider):
+    adapter = adapter_for_task(
+        kiln_task=basic_task, model_name="test-model", provider=provider
+    )
+
+    assert isinstance(adapter, LangchainAdapter)
+    assert adapter.run_config.model_name == "test-model"
+
+
+# TODO should run for all cases
+def test_custom_prompt_builder(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="gpt-4",
+        provider=ModelProviderName.openai,
+        prompt_id="simple_chain_of_thought_prompt_builder",
+    )
+
+    assert adapter.run_config.prompt_id == "simple_chain_of_thought_prompt_builder"
+
+
+# TODO should run for all cases
+def test_tags_passed_through(mock_config, basic_task):
+    tags = ["test-tag-1", "test-tag-2"]
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="gpt-4",
+        provider=ModelProviderName.openai,
+        base_adapter_config=AdapterConfig(
+            default_tags=tags,
+        ),
+    )
+
+    assert adapter.base_adapter_config.default_tags == tags
+
+
+def test_invalid_provider(mock_config, basic_task):
+    with pytest.raises(ValueError, match="Unhandled enum value"):
+        adapter_for_task(
+            kiln_task=basic_task, model_name="test-model", provider="invalid"
+        )
+
+
+@patch("kiln_ai.adapters.adapter_registry.openai_compatible_config")
+def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_task):
+    mock_compatible_config.return_value.model_name = "test-model"
+    mock_compatible_config.return_value.api_key = "test-key"
+    mock_compatible_config.return_value.base_url = "https://test.com/v1"
+    mock_compatible_config.return_value.provider_name = "CustomProvider99"
+
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="provider::test-model",
+        provider=ModelProviderName.openai_compatible,
+    )
+
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    mock_compatible_config.assert_called_once_with("provider::test-model")
+    assert adapter.config.model_name == "test-model"
+    assert adapter.config.api_key == "test-key"
+    assert adapter.config.base_url == "https://test.com/v1"
+    assert adapter.config.provider_name == "CustomProvider99"
+
+
+def test_custom_openai_compatible_provider(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="openai::test-model",
+        provider=ModelProviderName.kiln_custom_registry,
+    )
+
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    assert adapter.config.model_name == "openai::test-model"
+    assert adapter.config.api_key == "test-openai-key"
+    assert adapter.config.base_url is None  # openai is none
+    assert adapter.config.provider_name == ModelProviderName.kiln_custom_registry
+
+
+async def test_fine_tune_provider(mock_config, basic_task, mock_finetune_from_id):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="proj::task::tune",
+        provider=ModelProviderName.kiln_fine_tune,
+    )
+
+    mock_finetune_from_id.assert_called_once_with("proj::task::tune")
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    assert adapter.config.provider_name == ModelProviderName.kiln_fine_tune
+    # Kiln model name here, but the underlying openai model id below
+    assert adapter.config.model_name == "proj::task::tune"
+
+    provider = kiln_model_provider_from(
+        "proj::task::tune", provider_name=ModelProviderName.kiln_fine_tune
+    )
+    # The actual model name from the fine tune object
+    assert provider.provider_options["model"] == "test-model"

kiln_ai/adapters/test_generate_docs.py (new file)
@@ -0,0 +1,69 @@
+import logging
+from typing import List
+
+import pytest
+
+from libs.core.kiln_ai.adapters.ml_model_list import KilnModelProvider, built_in_models
+from libs.core.kiln_ai.adapters.provider_tools import provider_name_from_id
+
+logger = logging.getLogger(__name__)
+
+
+def _all_providers_support(providers: List[KilnModelProvider], attribute: str) -> bool:
+    """Check if all providers support a given feature"""
+    return all(getattr(provider, attribute) for provider in providers)
+
+
+def _any_providers_support(providers: List[KilnModelProvider], attribute: str) -> bool:
+    """Check if any providers support a given feature"""
+    return any(getattr(provider, attribute) for provider in providers)
+
+
+def _get_support_status(providers: List[KilnModelProvider], attribute: str) -> str:
+    """Get the support status for a feature"""
+    if _all_providers_support(providers, attribute):
+        return "✅︎"
+    elif _any_providers_support(providers, attribute):
+        return "✅︎ (some providers)"
+    return ""
+
+
+def _has_finetune_support(providers: List[KilnModelProvider]) -> str:
+    """Check if any provider supports fine-tuning"""
+    return "✅︎" if any(p.provider_finetune_id for p in providers) else ""
+
+
+@pytest.mark.paid(reason="Marking as paid so it isn't run by default")
+def test_generate_model_table():
+    """Generate a markdown table of all models and their capabilities"""
+
+    # Table header
+    table = [
+        "| Model Name | Providers | Structured Output | Reasoning | Synthetic Data | API Fine-Tuneable |",
+        "|------------|-----------|-------------------|-----------|----------------|-------------------|",
+    ]
+
+    for model in built_in_models:
+        provider_names = ", ".join(
+            sorted(provider_name_from_id(p.name.value) for p in model.providers)
+        )
+        structured_output = _get_support_status(
+            model.providers, "supports_structured_output"
+        )
+        reasoning = _get_support_status(model.providers, "reasoning_capable")
+        data_gen = _get_support_status(model.providers, "supports_data_gen")
+        finetune = _has_finetune_support(model.providers)
+
+        row = f"| {model.friendly_name} | {provider_names} | {structured_output} | {reasoning} | {data_gen} | {finetune} |"
+        table.append(row)
+
+    # Print the table (useful for documentation)
+    logger.info("\nModel Capability Matrix:\n")
+    logger.info("\n".join(table))
+
+    # Basic assertions to ensure the table is well-formed
+    assert len(table) > 2, "Table should have header and at least one row"
+    assert all("|" in row for row in table), "All rows should be properly formatted"
+    assert len(table[0].split("|")) == len(table[1].split("|")), (
+        "Header and separator should have same number of columns"
+    )

kiln_ai/adapters/test_ollama_tools.py
@@ -10,7 +10,6 @@ from kiln_ai.adapters.ollama_tools import (
 def test_parse_ollama_tags_no_models():
     json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"},{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
     tags = json.loads(json_response)
-    print(json.dumps(tags, indent=2))
     conn = parse_ollama_tags(tags)
     assert "phi3.5:latest" in conn.supported_models
     assert "gemma2:2b" in conn.supported_models