kiln-ai 0.16.0__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +22 -44
- kiln_ai/adapters/chat/__init__.py +8 -0
- kiln_ai/adapters/chat/chat_formatter.py +233 -0
- kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
- kiln_ai/adapters/data_gen/data_gen_prompts.py +121 -36
- kiln_ai/adapters/data_gen/data_gen_task.py +49 -36
- kiln_ai/adapters/data_gen/test_data_gen_task.py +330 -40
- kiln_ai/adapters/eval/base_eval.py +7 -6
- kiln_ai/adapters/eval/eval_runner.py +9 -2
- kiln_ai/adapters/eval/g_eval.py +40 -17
- kiln_ai/adapters/eval/test_base_eval.py +174 -17
- kiln_ai/adapters/eval/test_eval_runner.py +3 -0
- kiln_ai/adapters/eval/test_g_eval.py +116 -5
- kiln_ai/adapters/fine_tune/base_finetune.py +3 -8
- kiln_ai/adapters/fine_tune/dataset_formatter.py +135 -273
- kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +287 -353
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +6 -11
- kiln_ai/adapters/fine_tune/together_finetune.py +13 -2
- kiln_ai/adapters/ml_model_list.py +370 -84
- kiln_ai/adapters/model_adapters/base_adapter.py +73 -26
- kiln_ai/adapters/model_adapters/litellm_adapter.py +88 -97
- kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
- kiln_ai/adapters/model_adapters/test_base_adapter.py +235 -61
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +104 -21
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +44 -12
- kiln_ai/adapters/parsers/parser_registry.py +0 -2
- kiln_ai/adapters/parsers/r1_parser.py +0 -1
- kiln_ai/adapters/prompt_builders.py +0 -16
- kiln_ai/adapters/provider_tools.py +27 -9
- kiln_ai/adapters/remote_config.py +66 -0
- kiln_ai/adapters/repair/repair_task.py +1 -6
- kiln_ai/adapters/repair/test_repair_task.py +24 -3
- kiln_ai/adapters/test_adapter_registry.py +88 -28
- kiln_ai/adapters/test_ml_model_list.py +176 -0
- kiln_ai/adapters/test_prompt_adaptors.py +17 -7
- kiln_ai/adapters/test_prompt_builders.py +3 -16
- kiln_ai/adapters/test_provider_tools.py +69 -20
- kiln_ai/adapters/test_remote_config.py +100 -0
- kiln_ai/datamodel/__init__.py +0 -2
- kiln_ai/datamodel/datamodel_enums.py +38 -13
- kiln_ai/datamodel/eval.py +32 -0
- kiln_ai/datamodel/finetune.py +12 -8
- kiln_ai/datamodel/task.py +68 -7
- kiln_ai/datamodel/task_output.py +0 -2
- kiln_ai/datamodel/task_run.py +0 -2
- kiln_ai/datamodel/test_basemodel.py +2 -1
- kiln_ai/datamodel/test_dataset_split.py +0 -8
- kiln_ai/datamodel/test_eval_model.py +146 -4
- kiln_ai/datamodel/test_models.py +33 -10
- kiln_ai/datamodel/test_task.py +168 -2
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/dataset_import.py +1 -1
- kiln_ai/utils/logging.py +166 -0
- kiln_ai/utils/test_config.py +23 -0
- kiln_ai/utils/test_dataset_import.py +30 -0
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/METADATA +2 -2
- kiln_ai-0.18.0.dist-info/RECORD +115 -0
- kiln_ai-0.16.0.dist-info/RECORD +0 -108
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/licenses/LICENSE.txt +0 -0
```diff
@@ -17,7 +17,7 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId
-from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.datamodel.task import RunConfig, RunConfigProperties
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


@@ -51,6 +51,7 @@ class MockAdapter(BaseAdapter):
                 model_name="phi_3_5",
                 model_provider_name="ollama",
                 prompt_id="simple_chain_of_thought_prompt_builder",
+                structured_output_mode="json_schema",
             ),
         )
         self.response = response
@@ -146,7 +147,15 @@ def build_structured_output_test_task(tmp_path: Path):

 async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
     task = build_structured_output_test_task(tmp_path)
-    a = adapter_for_task(
+    a = adapter_for_task(
+        task,
+        run_config_properties=RunConfigProperties(
+            model_name=model_name,
+            model_provider_name=provider,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="unknown",
+        ),
+    )
     try:
         run = await a.invoke("Cows")  # a joke about cows
         parsed = json.loads(run.output.output)
@@ -197,10 +206,12 @@ def build_structured_input_test_task(tmp_path: Path):
     return task


-async def run_structured_input_test(
+async def run_structured_input_test(
+    tmp_path: Path, model_name: str, provider: str, prompt_id: PromptId
+):
     task = build_structured_input_test_task(tmp_path)
     try:
-        await run_structured_input_task(task, model_name, provider)
+        await run_structured_input_task(task, model_name, provider, prompt_id)
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
@@ -209,17 +220,20 @@ async def run_structured_input_test(tmp_path: Path, model_name: str, provider: s
         raise e


-async def
+async def run_structured_input_task_no_validation(
     task: datamodel.Task,
     model_name: str,
     provider: str,
-    prompt_id: PromptId
+    prompt_id: PromptId,
 ):
     a = adapter_for_task(
         task,
-
-
-
+        run_config_properties=RunConfigProperties(
+            model_name=model_name,
+            model_provider_name=provider,
+            prompt_id=prompt_id,
+            structured_output_mode="unknown",
+        ),
     )
     with pytest.raises(ValueError):
         # not structured input in dictionary
@@ -231,18 +245,29 @@ async def run_structured_input_task(
     try:
         run = await a.invoke({"a": 2, "b": 2, "c": 2})
         response = run.output.output
+        return response, a, run
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
                 f"Skipping {model_name} {provider} because Ollama is not running"
             )
         raise e
+
+
+async def run_structured_input_task(
+    task: datamodel.Task,
+    model_name: str,
+    provider: str,
+    prompt_id: PromptId,
+):
+    response, a, run = await run_structured_input_task_no_validation(
+        task, model_name, provider, prompt_id
+    )
     assert response is not None
     if isinstance(response, str):
         assert "[[equilateral]]" in response
     else:
         assert response["is_equilateral"] is True
-
     expected_pb_name = "simple_prompt_builder"
     if prompt_id is not None:
         expected_pb_name = prompt_id
@@ -269,7 +294,9 @@ async def test_structured_input_gpt_4o_mini(tmp_path):
 async def test_all_built_in_models_structured_input(
     tmp_path, model_name, provider_name
 ):
-    await run_structured_input_test(
+    await run_structured_input_test(
+        tmp_path, model_name, provider_name, "simple_prompt_builder"
+    )


 @pytest.mark.paid
@@ -323,6 +350,11 @@ When asked for a final result, this is the format (for an equilateral example):
     """
     task.output_json_schema = json.dumps(triangle_schema)
     task.save_to_file()
-    await
+    response, adapter, _ = await run_structured_input_task_no_validation(
         task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
     )
+
+    formatted_response = json.loads(response)
+    assert formatted_response["is_equilateral"] is True
+    assert formatted_response["is_scalene"] is False
+    assert formatted_response["is_obtuse"] is False
```
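Across these hunks the pattern is consistent: `adapter_for_task` no longer takes loose model/provider/prompt keyword arguments and instead receives a single `RunConfigProperties` object, which now also carries `structured_output_mode`. A minimal sketch of the new calling convention, assuming an existing task and input; the model, provider, and prompt values are the illustrative ones used in the tests above, not required defaults:

```python
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.datamodel.task import RunConfigProperties


async def run_task_once(task, task_input):
    # `task` is assumed to be an existing kiln_ai.datamodel.Task.
    adapter = adapter_for_task(
        task,
        run_config_properties=RunConfigProperties(
            model_name="llama_3_1_8b",
            model_provider_name="ollama",
            prompt_id="simple_prompt_builder",
            structured_output_mode="json_schema",
        ),
    )
    # invoke() accepts a plain string or a dict matching the task's input schema.
    return await adapter.invoke(task_input)
```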
```diff
@@ -1,6 +1,4 @@
-import json
 from abc import ABCMeta, abstractmethod
-from typing import Dict

 from kiln_ai.datamodel import PromptGenerators, PromptId, Task, TaskRun
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
@@ -53,20 +51,6 @@ class BasePromptBuilder(metaclass=ABCMeta):
         """
         pass

-    def build_user_message(self, input: Dict | str) -> str:
-        """Build a user message from the input.
-
-        Args:
-            input (Union[Dict, str]): The input to format into a message.
-
-        Returns:
-            str: The formatted user message.
-        """
-        if isinstance(input, Dict):
-            return f"The input is:\n{json.dumps(input, indent=2, ensure_ascii=False)}"
-
-        return f"The input is:\n{input}"
-
     def chain_of_thought_prompt(self) -> str | None:
         """Build and return the chain of thought prompt string.

```
```diff
@@ -1,3 +1,4 @@
+import logging
 from dataclasses import dataclass
 from typing import Dict, List

@@ -16,11 +17,15 @@ from kiln_ai.adapters.model_adapters.litellm_config import (
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
-from kiln_ai.datamodel import Finetune,
+from kiln_ai.datamodel import Finetune, Task
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 from kiln_ai.datamodel.registry import project_from_id
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

+logger = logging.getLogger(__name__)
+


 async def provider_enabled(provider_name: ModelProviderName) -> bool:
     if provider_name == ModelProviderName.ollama:
@@ -163,6 +168,10 @@ def kiln_model_provider_from(
     # For custom registry, get the provider name and model name from the model id
     if provider_name == ModelProviderName.kiln_custom_registry:
         provider_name, name = parse_custom_model_id(name)
+    else:
+        logger.warning(
+            f"Unexpected model/provider pair. Will treat as custom model but check your model settings. Provider: {provider_name}/{name}"
+        )

     # Custom/untested model. Set untested, and build a ModelProvider at runtime
     if provider_name is None:
@@ -177,12 +186,15 @@
         supports_data_gen=False,
         untested_model=True,
         model_id=name,
+        # We don't know the structured output mode for custom models, so we default to json_instructions which is the only one that works everywhere.
+        structured_output_mode=StructuredOutputMode.json_instructions,
     )


-def
-
+def lite_llm_config_for_openai_compatible(
+    run_config_properties: RunConfigProperties,
 ) -> LiteLlmConfig:
+    model_id = run_config_properties.model_name
     try:
         openai_provider_name, model_id = model_id.split("::")
     except Exception:
@@ -206,10 +218,16 @@ def lite_llm_config(
             f"OpenAI compatible provider {openai_provider_name} has no base URL"
         )

+    # Update a copy of the run config properties to use the openai compatible provider
+    updated_run_config_properties = run_config_properties.model_copy(deep=True)
+    updated_run_config_properties.model_provider_name = (
+        ModelProviderName.openai_compatible
+    )
+    updated_run_config_properties.model_name = model_id
+
     return LiteLlmConfig(
         # OpenAI compatible, with a custom base URL
-
-        provider_name=ModelProviderName.openai_compatible,
+        run_config_properties=updated_run_config_properties,
         base_url=base_url,
         additional_body_options={
             "api_key": api_key,
@@ -259,9 +277,9 @@ def finetune_from_id(model_id: str) -> Finetune:


 def parser_from_data_strategy(
-    data_strategy:
+    data_strategy: ChatStrategy,
 ) -> ModelParserID | None:
-    if data_strategy ==
+    if data_strategy == ChatStrategy.single_turn_r1_thinking:
         return ModelParserID.r1_thinking
     return None

@@ -279,10 +297,10 @@ def finetune_provider_model(
         reasoning_capable=(
             fine_tune.data_strategy
             in [
-
-                FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+                ChatStrategy.single_turn_r1_thinking,
             ]
         ),
+        tuned_chat_strategy=fine_tune.data_strategy,
     )

     if provider == ModelProviderName.vertex and fine_tune.fine_tune_model_id:
```
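The OpenAI-compatible and custom-registry code paths above both pack the provider and the underlying model into a single id separated by `::` (see the `model_id.split("::")` call in `lite_llm_config_for_openai_compatible` and the `openai::test-model` ids used in the registry tests below). A small illustration of that id format; the provider and model names here are made up:

```python
# Illustrative only: names are hypothetical, the "provider::model" format
# mirrors the split("::") parsing shown in the diff above.
model_id = "my_openai_compatible_provider::my-model"
provider_name, underlying_model = model_id.split("::")
assert provider_name == "my_openai_compatible_provider"
assert underlying_model == "my-model"
```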
```diff
@@ -0,0 +1,66 @@
+import argparse
+import json
+import logging
+import os
+import threading
+from pathlib import Path
+from typing import List
+
+import requests
+
+from .ml_model_list import KilnModel, built_in_models
+
+logger = logging.getLogger(__name__)
+
+
+def serialize_config(models: List[KilnModel], path: str | Path) -> None:
+    data = {"model_list": [m.model_dump(mode="json") for m in models]}
+    Path(path).write_text(json.dumps(data, indent=2, sort_keys=True))
+
+
+def deserialize_config(path: str | Path) -> List[KilnModel]:
+    raw = json.loads(Path(path).read_text())
+    model_data = raw.get("model_list", raw if isinstance(raw, list) else [])
+    return [KilnModel.model_validate(item) for item in model_data]
+
+
+def load_from_url(url: str) -> List[KilnModel]:
+    response = requests.get(url, timeout=10)
+    response.raise_for_status()
+    data = response.json()
+    if isinstance(data, list):
+        model_data = data
+    else:
+        model_data = data.get("model_list", [])
+    return [KilnModel.model_validate(item) for item in model_data]
+
+
+def dump_builtin_config(path: str | Path) -> None:
+    serialize_config(built_in_models, path)
+
+
+def load_remote_models(url: str) -> None:
+    if os.environ.get("KILN_SKIP_REMOTE_MODEL_LIST") == "true":
+        return
+
+    def fetch_and_replace() -> None:
+        try:
+            models = load_from_url(url)
+            built_in_models[:] = models
+        except Exception as exc:
+            # Do not crash startup, but surface the issue
+            logger.warning("Failed to fetch remote model list from %s: %s", url, exc)
+
+    thread = threading.Thread(target=fetch_and_replace, daemon=True)
+    thread.start()
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("path", help="output path")
+    args = parser.parse_args()
+    dump_builtin_config(args.path)
+
+
+if __name__ == "__main__":
+    main()
```
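The new `remote_config` module lets the built-in model list be serialized to JSON, read back, or swapped out from a remote URL on a background thread. A short usage sketch built only from the functions shown above; the file name and URL are illustrative:

```python
from kiln_ai.adapters.remote_config import (
    deserialize_config,
    dump_builtin_config,
    load_remote_models,
)

# Write the built-in model list to disk, then read it back as KilnModel objects.
dump_builtin_config("model_list.json")
models = deserialize_config("model_list.json")
print(len(models), "models loaded")

# Replace the built-in list from a remote URL on a daemon thread.
# Failures are logged as warnings rather than raised; set the
# KILN_SKIP_REMOTE_MODEL_LIST=true environment variable to skip the fetch.
load_remote_models("https://example.com/model_list.json")
```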
```diff
@@ -1,13 +1,8 @@
 import json
-from typing import Type

 from pydantic import BaseModel, Field

-from kiln_ai.adapters.prompt_builders import (
-    BasePromptBuilder,
-    SavedPromptBuilder,
-    prompt_builder_from_id,
-)
+from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_from_id
 from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun

```
```diff
@@ -21,6 +21,7 @@ from kiln_ai.datamodel import (
     TaskRequirement,
     TaskRun,
 )
+from kiln_ai.datamodel.task import RunConfigProperties

 json_joke_schema = """{
     "type": "object",
@@ -189,7 +190,15 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
     repair_task_input = RepairTaskRun.build_repair_task_input(**sample_repair_data)
     assert isinstance(repair_task_input, RepairTaskInput)

-    adapter = adapter_for_task(
+    adapter = adapter_for_task(
+        repair_task,
+        RunConfigProperties(
+            model_name="llama_3_1_8b",
+            model_provider_name="groq",
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="default",
+        ),
+    )

     run = await adapter.invoke(repair_task_input.model_dump())
     assert run is not None
@@ -198,10 +207,13 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
     assert "setup" in parsed_output
     assert "punchline" in parsed_output
     assert run.output.source.properties == {
-        "adapter_name": "
+        "adapter_name": "kiln_openai_compatible_adapter",
         "model_name": "llama_3_1_8b",
         "model_provider": "groq",
         "prompt_id": "simple_prompt_builder",
+        "structured_output_mode": "default",
+        "temperature": 1.0,
+        "top_p": 1.0,
     }


@@ -224,7 +236,13 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     )

     adapter = adapter_for_task(
-        repair_task,
+        repair_task,
+        RunConfigProperties(
+            model_name="llama_3_1_8b",
+            model_provider_name="ollama",
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     run = await adapter.invoke(repair_task_input.model_dump())
@@ -240,6 +258,9 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
         "model_name": "llama_3_1_8b",
         "model_provider": "ollama",
         "prompt_id": "simple_prompt_builder",
+        "structured_output_mode": "json_schema",
+        "temperature": 1.0,
+        "top_p": 1.0,
     }
     assert run.input_source.type == DataSourceType.human
     assert "created_by" in run.input_source.properties
```
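These repair-task tests also show that saved runs now record the full generation settings in the output source properties. A sketch of the shape a saved run's properties take after this release, mirroring the live-run test above; the values are the test's own and only the keys matter for consumers reading run history:

```python
# The last three keys are new in this release.
saved_source_properties = {
    "adapter_name": "kiln_openai_compatible_adapter",
    "model_name": "llama_3_1_8b",
    "model_provider": "groq",
    "prompt_id": "simple_prompt_builder",
    "structured_output_mode": "default",
    "temperature": 1.0,
    "top_p": 1.0,
}
```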
```diff
@@ -7,8 +7,8 @@ from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
 from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
-from kiln_ai.adapters.prompt_builders import BasePromptBuilder
 from kiln_ai.adapters.provider_tools import kiln_model_provider_from
+from kiln_ai.datamodel.task import RunConfigProperties


 @pytest.fixture
@@ -35,18 +35,28 @@ def mock_finetune_from_id():
     with patch("kiln_ai.adapters.provider_tools.finetune_from_id") as mock:
         mock.return_value.provider = ModelProviderName.openai
         mock.return_value.fine_tune_model_id = "test-model"
+        mock.return_value.data_strategy = "final_only"
         yield mock


 def test_openai_adapter_creation(mock_config, basic_task):
     adapter = adapter_for_task(
-        kiln_task=basic_task,
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="gpt-4",
+            model_provider_name=ModelProviderName.openai,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     assert isinstance(adapter, LiteLlmAdapter)
-    assert adapter.config.model_name == "gpt-4"
+    assert adapter.config.run_config_properties.model_name == "gpt-4"
     assert adapter.config.additional_body_options == {"api_key": "test-openai-key"}
-    assert
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.openai
+    )
     assert adapter.config.base_url is None  # OpenAI url is default
     assert adapter.config.default_headers is None

@@ -54,14 +64,21 @@ def test_openai_adapter_creation(mock_config, basic_task):
 def test_openrouter_adapter_creation(mock_config, basic_task):
     adapter = adapter_for_task(
         kiln_task=basic_task,
-
-
+        run_config_properties=RunConfigProperties(
+            model_name="anthropic/claude-3-opus",
+            model_provider_name=ModelProviderName.openrouter,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     assert isinstance(adapter, LiteLlmAdapter)
-    assert adapter.config.model_name == "anthropic/claude-3-opus"
+    assert adapter.config.run_config_properties.model_name == "anthropic/claude-3-opus"
     assert adapter.config.additional_body_options == {"api_key": "test-openrouter-key"}
-    assert
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.openrouter
+    )
     assert adapter.config.default_headers == {
         "HTTP-Referer": "https://getkiln.ai/openrouter",
         "X-Title": "KilnAI",
@@ -79,7 +96,13 @@ def test_openrouter_adapter_creation(mock_config, basic_task):
 )
 def test_openai_compatible_adapter_creation(mock_config, basic_task, provider):
     adapter = adapter_for_task(
-        kiln_task=basic_task,
+        kiln_task=basic_task,
+        run_config_properties=RunConfigProperties(
+            model_name="test-model",
+            model_provider_name=provider,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     assert isinstance(adapter, LiteLlmAdapter)
@@ -90,9 +113,12 @@ def test_openai_compatible_adapter_creation(mock_config, basic_task, provider):
 def test_custom_prompt_builder(mock_config, basic_task):
     adapter = adapter_for_task(
         kiln_task=basic_task,
-
-
-
+        run_config_properties=RunConfigProperties(
+            model_name="gpt-4",
+            model_provider_name=ModelProviderName.openai,
+            prompt_id="simple_chain_of_thought_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     assert adapter.run_config.prompt_id == "simple_chain_of_thought_prompt_builder"
@@ -103,8 +129,12 @@ def test_tags_passed_through(mock_config, basic_task):
     tags = ["test-tag-1", "test-tag-2"]
     adapter = adapter_for_task(
         kiln_task=basic_task,
-
-
+        run_config_properties=RunConfigProperties(
+            model_name="gpt-4",
+            model_provider_name=ModelProviderName.openai,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
         base_adapter_config=AdapterConfig(
             default_tags=tags,
         ),
@@ -114,13 +144,19 @@ def test_tags_passed_through(mock_config, basic_task):


 def test_invalid_provider(mock_config, basic_task):
-    with pytest.raises(ValueError, match="
+    with pytest.raises(ValueError, match="Input should be"):
         adapter_for_task(
-            kiln_task=basic_task,
+            kiln_task=basic_task,
+            run_config_properties=RunConfigProperties(
+                model_name="test-model",
+                model_provider_name="invalid",
+                prompt_id="simple_prompt_builder",
+                structured_output_mode="json_schema",
+            ),
         )


-@patch("kiln_ai.adapters.adapter_registry.
+@patch("kiln_ai.adapters.adapter_registry.lite_llm_config_for_openai_compatible")
 def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_task):
     mock_compatible_config.return_value.model_name = "test-model"
     mock_compatible_config.return_value.additional_body_options = {
@@ -128,44 +164,68 @@ def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_ta
     }
     mock_compatible_config.return_value.base_url = "https://test.com/v1"
     mock_compatible_config.return_value.provider_name = "CustomProvider99"
+    mock_compatible_config.return_value.run_config_properties = RunConfigProperties(
+        model_name="provider::test-model",
+        model_provider_name=ModelProviderName.openai_compatible,
+        prompt_id="simple_prompt_builder",
+        structured_output_mode="json_schema",
+    )

     adapter = adapter_for_task(
         kiln_task=basic_task,
-
-
+        run_config_properties=RunConfigProperties(
+            model_name="provider::test-model",
+            model_provider_name=ModelProviderName.openai_compatible,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     assert isinstance(adapter, LiteLlmAdapter)
-    mock_compatible_config.
+    mock_compatible_config.assert_called_once()
     assert adapter.config == mock_compatible_config.return_value


 def test_custom_openai_compatible_provider(mock_config, basic_task):
     adapter = adapter_for_task(
         kiln_task=basic_task,
-
-
+        run_config_properties=RunConfigProperties(
+            model_name="openai::test-model",
+            model_provider_name=ModelProviderName.kiln_custom_registry,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     assert isinstance(adapter, LiteLlmAdapter)
-    assert adapter.config.model_name == "openai::test-model"
+    assert adapter.config.run_config_properties.model_name == "openai::test-model"
     assert adapter.config.additional_body_options == {"api_key": "test-openai-key"}
     assert adapter.config.base_url is None  # openai is none
-    assert
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.kiln_custom_registry
+    )


 async def test_fine_tune_provider(mock_config, basic_task, mock_finetune_from_id):
     adapter = adapter_for_task(
         kiln_task=basic_task,
-
-
+        run_config_properties=RunConfigProperties(
+            model_name="proj::task::tune",
+            model_provider_name=ModelProviderName.kiln_fine_tune,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
+        ),
     )

     mock_finetune_from_id.assert_called_once_with("proj::task::tune")
     assert isinstance(adapter, LiteLlmAdapter)
-    assert
+    assert (
+        adapter.config.run_config_properties.model_provider_name
+        == ModelProviderName.kiln_fine_tune
+    )
     # Kiln model name here, but the underlying openai model id below
-    assert adapter.config.model_name == "proj::task::tune"
+    assert adapter.config.run_config_properties.model_name == "proj::task::tune"

     provider = kiln_model_provider_from(
         "proj::task::tune", provider_name=ModelProviderName.kiln_fine_tune
```