kiln-ai 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/adapter_registry.py +12 -13
- kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +267 -0
- kiln_ai/adapters/eval/g_eval.py +367 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +324 -0
- kiln_ai/adapters/eval/test_eval_runner.py +640 -0
- kiln_ai/adapters/eval/test_g_eval.py +497 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
- kiln_ai/adapters/ml_model_list.py +141 -29
- kiln_ai/adapters/model_adapters/base_adapter.py +50 -35
- kiln_ai/adapters/model_adapters/langchain_adapters.py +27 -20
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -1
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +93 -50
- kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +7 -14
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +55 -64
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
- kiln_ai/adapters/model_adapters/test_structured_output.py +36 -30
- kiln_ai/adapters/ollama_tools.py +0 -1
- kiln_ai/adapters/prompt_builders.py +80 -42
- kiln_ai/adapters/repair/repair_task.py +9 -21
- kiln_ai/adapters/repair/test_repair_task.py +3 -3
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +10 -10
- kiln_ai/adapters/test_generate_docs.py +6 -6
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +17 -14
- kiln_ai/adapters/test_prompt_builders.py +91 -31
- kiln_ai/datamodel/__init__.py +50 -952
- kiln_ai/datamodel/datamodel_enums.py +58 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +6 -0
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +321 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +10 -11
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +32 -8
- kiln_ai/datamodel/test_datasource.py +3 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +9 -13
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_models.py +2 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +6 -1
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +37 -1
- kiln_ai-0.12.0.dist-info/RECORD +100 -0
- kiln_ai-0.11.1.dist-info/RECORD +0 -76
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/test_structured_output.py
CHANGED

@@ -2,8 +2,6 @@ import json
 from pathlib import Path
 from typing import Dict

-import jsonschema
-import jsonschema.exceptions
 import pytest

 import kiln_ai.datamodel as datamodel

@@ -12,16 +10,13 @@ from kiln_ai.adapters.ml_model_list import (
     built_in_models,
 )
 from kiln_ai.adapters.model_adapters.base_adapter import (
-    AdapterInfo,
     BaseAdapter,
     RunOutput,
 )
 from kiln_ai.adapters.ollama_tools import ollama_online
-from kiln_ai.adapters.prompt_builders import (
-    BasePromptBuilder,
-    SimpleChainOfThoughtPromptBuilder,
-)
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
+from kiln_ai.datamodel import PromptId
+from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


@@ -39,9 +34,9 @@ async def test_structured_output_gpt_4o_mini(tmp_path):
     await run_structured_output_test(tmp_path, "gpt_4o_mini", "openai")


-@pytest.mark.parametrize("model_name", ["llama_3_1_8b"])
+@pytest.mark.parametrize("model_name", ["llama_3_1_8b", "gemma_2_2b"])
 @pytest.mark.ollama
-async def test_structured_output_ollama_llama(tmp_path, model_name):
+async def test_structured_output_ollama(tmp_path, model_name):
     if not await ollama_online():
         pytest.skip("Ollama API not running. Expect it running on localhost:11434")
     await run_structured_output_test(tmp_path, model_name, "ollama")

@@ -49,19 +44,21 @@ async def test_structured_output_ollama_llama(tmp_path, model_name):

 class MockAdapter(BaseAdapter):
     def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None):
-        super().__init__(
+        super().__init__(
+            run_config=RunConfig(
+                task=kiln_task,
+                model_name="phi_3_5",
+                model_provider_name="ollama",
+                prompt_id="simple_chain_of_thought_prompt_builder",
+            ),
+        )
         self.response = response

     async def _run(self, input: str) -> RunOutput:
         return RunOutput(output=self.response, intermediate_outputs=None)

-    def
-        return
-            adapter_name="mock_adapter",
-            model_name="mock_model",
-            model_provider="mock_provider",
-            prompt_builder_name="mock_prompt_builder",
-        )
+    def adapter_name(self) -> str:
+        return "mock_adapter"


 async def test_mock_unstructred_response(tmp_path):

@@ -204,15 +201,21 @@ async def run_structured_input_task(
     task: datamodel.Task,
     model_name: str,
     provider: str,
-
+    prompt_id: PromptId | None = None,
 ):
     a = adapter_for_task(
-        task,
+        task,
+        model_name=model_name,
+        provider=provider,
+        prompt_id=prompt_id,
     )
     with pytest.raises(ValueError):
         # not structured input in dictionary
         await a.invoke("a=1, b=2, c=3")
-    with pytest.raises(
+    with pytest.raises(
+        ValueError,
+        match="This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema.",
+    ):
         # invalid structured input
         await a.invoke({"a": 1, "b": 2, "d": 3})


@@ -229,13 +232,14 @@ async def run_structured_input_task(
         assert "[[equilateral]]" in response
     else:
         assert response["is_equilateral"] is True
-
+
     expected_pb_name = "simple_prompt_builder"
-    if
-        expected_pb_name =
-    assert
-
-    assert
+    if prompt_id is not None:
+        expected_pb_name = prompt_id
+    assert a.run_config.prompt_id == expected_pb_name
+
+    assert a.run_config.model_name == model_name
+    assert a.run_config.model_provider_name == provider


 @pytest.mark.paid

@@ -257,8 +261,9 @@ async def test_all_built_in_models_structured_input(
 @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
 async def test_structured_input_cot_prompt_builder(tmp_path, model_name, provider_name):
     task = build_structured_input_test_task(tmp_path)
-
-
+    await run_structured_input_task(
+        task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+    )


 @pytest.mark.paid

@@ -302,5 +307,6 @@ When asked for a final result, this is the format (for an equilateral example):
 """
     task.output_json_schema = json.dumps(triangle_schema)
     task.save_to_file()
-
-
+    await run_structured_input_task(
+        task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+    )
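The structured-output tests above switch from passing prompt builder objects to passing a prompt_id string and letting the adapter carry a RunConfig. A minimal sketch of that calling convention, assuming an existing `task` (a kiln_ai.datamodel.Task); the function and field names come from the hunks above:

# Sketch only: `task` is assumed to be an existing kiln_ai.datamodel.Task.
from kiln_ai.adapters.adapter_registry import adapter_for_task

adapter = adapter_for_task(
    task,
    model_name="llama_3_1_8b",
    provider="ollama",
    prompt_id="simple_chain_of_thought_prompt_builder",
)

# The chosen configuration is exposed on the adapter's run_config,
# which is what the assertions in run_structured_input_task check.
assert adapter.run_config.model_name == "llama_3_1_8b"
assert adapter.run_config.model_provider_name == "ollama"
assert adapter.run_config.prompt_id == "simple_chain_of_thought_prompt_builder"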
kiln_ai/adapters/ollama_tools.py
CHANGED

kiln_ai/adapters/prompt_builders.py
CHANGED

@@ -2,8 +2,8 @@ import json
 from abc import ABCMeta, abstractmethod
 from typing import Dict

-from kiln_ai.datamodel import Task, TaskRun
-from kiln_ai.utils.
+from kiln_ai.datamodel import PromptGenerators, PromptId, Task, TaskRun
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error


 class BasePromptBuilder(metaclass=ABCMeta):

@@ -53,17 +53,6 @@ class BasePromptBuilder(metaclass=ABCMeta):
         """
         pass

-    @classmethod
-    def prompt_builder_name(cls) -> str:
-        """Returns the name of the prompt builder, to be used for persisting into the datastore.
-
-        Default implementation gets the name of the prompt builder in snake case. If you change the class name, you should override this so prior saved data is compatible.
-
-        Returns:
-            str: The prompt builder name in snake_case format.
-        """
-        return snake_case(cls.__name__)
-
     def build_user_message(self, input: Dict | str) -> str:
         """Build a user message from the input.


@@ -300,6 +289,57 @@ class SavedPromptBuilder(BasePromptBuilder):
         return self.prompt_model.chain_of_thought_instructions


+class TaskRunConfigPromptBuilder(BasePromptBuilder):
+    """A prompt builder that looks up a static prompt in a task run config."""
+
+    def __init__(self, task: Task, run_config_prompt_id: str):
+        parts = run_config_prompt_id.split("::")
+        if len(parts) != 4:
+            raise ValueError(
+                f"Invalid task run config prompt ID: {run_config_prompt_id}. Expected format: 'task_run_config::[project_id]::[task_id]::[run_config_id]'."
+            )
+
+        task_id = parts[2]
+        if task_id != task.id:
+            raise ValueError(
+                f"Task run config prompt ID: {run_config_prompt_id}. Task ID mismatch. Expected: {task.id}, got: {task_id}."
+            )
+
+        run_config_id = parts[3]
+        run_config = next(
+            (
+                run_config
+                for run_config in task.run_configs(readonly=True)
+                if run_config.id == run_config_id
+            ),
+            None,
+        )
+        if not run_config:
+            raise ValueError(
+                f"Task run config ID not found: {run_config_id} for prompt id {run_config_prompt_id}"
+            )
+        if run_config.prompt is None:
+            raise ValueError(
+                f"Task run config ID {run_config_id} does not have a stored prompt. Used as prompt id {run_config_prompt_id}"
+            )
+
+        # Load the prompt from the model
+        self.prompt = run_config.prompt.prompt
+        self.cot_prompt = run_config.prompt.chain_of_thought_instructions
+        self.id = run_config_prompt_id
+
+        super().__init__(task)
+
+    def prompt_id(self) -> str | None:
+        return self.id
+
+    def build_base_prompt(self) -> str:
+        return self.prompt
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return self.cot_prompt
+
+
 class FineTunePromptBuilder(BasePromptBuilder):
     """A prompt builder that looks up a fine-tune prompt."""


@@ -337,25 +377,12 @@ class FineTunePromptBuilder(BasePromptBuilder):
         return self.fine_tune_model.thinking_instructions


-# TODO P2: we end up with 2 IDs for these: the keys here (ui_name) and the prompt_builder_name from the class
-# We end up maintaining this in _prompt_generators as well.
-prompt_builder_registry = {
-    "simple_prompt_builder": SimplePromptBuilder,
-    "multi_shot_prompt_builder": MultiShotPromptBuilder,
-    "few_shot_prompt_builder": FewShotPromptBuilder,
-    "repairs_prompt_builder": RepairsPromptBuilder,
-    "simple_chain_of_thought_prompt_builder": SimpleChainOfThoughtPromptBuilder,
-    "few_shot_chain_of_thought_prompt_builder": FewShotChainOfThoughtPromptBuilder,
-    "multi_shot_chain_of_thought_prompt_builder": MultiShotChainOfThoughtPromptBuilder,
-}
-
-
 # Our UI has some names that are not the same as the class names, which also hint parameters.
-def prompt_builder_from_ui_name(ui_name: str, task: Task) -> BasePromptBuilder:
+def prompt_builder_from_id(prompt_id: PromptId, task: Task) -> BasePromptBuilder:
     """Convert a name used in the UI to the corresponding prompt builder class.

     Args:
-
+        prompt_id (PromptId): The prompt ID.

     Returns:
         type[BasePromptBuilder]: The corresponding prompt builder class.

@@ -365,29 +392,40 @@ def prompt_builder_from_ui_name(ui_name: str, task: Task) -> BasePromptBuilder:
     """

     # Saved prompts are prefixed with "id::"
-    if
-        prompt_id =
+    if prompt_id.startswith("id::"):
+        prompt_id = prompt_id[4:]
         return SavedPromptBuilder(task, prompt_id)

+    # Task run config prompts are prefixed with "task_run_config::"
+    # task_run_config::[project_id]::[task_id]::[run_config_id]
+    if prompt_id.startswith("task_run_config::"):
+        return TaskRunConfigPromptBuilder(task, prompt_id)
+
     # Fine-tune prompts are prefixed with "fine_tune_prompt::"
-    if
-
-        return FineTunePromptBuilder(task,
+    if prompt_id.startswith("fine_tune_prompt::"):
+        prompt_id = prompt_id[18:]
+        return FineTunePromptBuilder(task, prompt_id)
+
+    # Check if the prompt_id matches any enum value
+    if prompt_id not in [member.value for member in PromptGenerators]:
+        raise ValueError(f"Unknown prompt generator: {prompt_id}")
+    typed_prompt_generator = PromptGenerators(prompt_id)

-    match
-        case
+    match typed_prompt_generator:
+        case PromptGenerators.SIMPLE:
             return SimplePromptBuilder(task)
-        case
+        case PromptGenerators.FEW_SHOT:
             return FewShotPromptBuilder(task)
-        case
+        case PromptGenerators.MULTI_SHOT:
             return MultiShotPromptBuilder(task)
-        case
+        case PromptGenerators.REPAIRS:
             return RepairsPromptBuilder(task)
-        case
+        case PromptGenerators.SIMPLE_CHAIN_OF_THOUGHT:
             return SimpleChainOfThoughtPromptBuilder(task)
-        case
+        case PromptGenerators.FEW_SHOT_CHAIN_OF_THOUGHT:
             return FewShotChainOfThoughtPromptBuilder(task)
-        case
+        case PromptGenerators.MULTI_SHOT_CHAIN_OF_THOUGHT:
             return MultiShotChainOfThoughtPromptBuilder(task)
         case _:
-
+            # Type checking will find missing cases
+            raise_exhaustive_enum_error(typed_prompt_generator)
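For reference, prompt_builder_from_id above dispatches on the shape of the ID string. A small illustrative sketch, assuming an existing `task`; the prefixed IDs in the comments are made-up placeholders that would need to reference real saved prompts, run configs, or fine-tunes:

from kiln_ai.adapters.prompt_builders import prompt_builder_from_id

# Built-in generators are referenced by their PromptGenerators enum value:
builder = prompt_builder_from_id("simple_prompt_builder", task)
prompt = builder.build_prompt(include_json_instructions=False)

# Other accepted formats (placeholder IDs, shown here only as comments):
#   "id::<saved_prompt_id>"                                      -> SavedPromptBuilder
#   "task_run_config::<project_id>::<task_id>::<run_config_id>"  -> TaskRunConfigPromptBuilder
#   "fine_tune_prompt::<fine_tune_id>"                           -> FineTunePromptBuilder
# Anything else must match a PromptGenerators value, or a ValueError is raised.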
kiln_ai/adapters/repair/repair_task.py
CHANGED

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
 from kiln_ai.adapters.prompt_builders import (
     BasePromptBuilder,
     SavedPromptBuilder,
-
+    prompt_builder_from_id,
 )
 from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun


@@ -49,28 +49,16 @@ feedback describing what should be improved. Your job is to understand the evalu
         if run.output.source is None or run.output.source.properties is None:
             raise ValueError("No source properties found")

-        #
-        prompt_id = run.output.source.properties.get(
+        # Get the prompt builder id. Need the second check because we used to store this in a prompt_builder_name field, so loading legacy runs will need this.
+        prompt_id = run.output.source.properties.get(
+            "prompt_id"
+        ) or run.output.source.properties.get("prompt_builder_name", None)
         if prompt_id is not None and isinstance(prompt_id, str):
-
-
+            prompt_builder = prompt_builder_from_id(prompt_id, task)
+            if isinstance(prompt_builder, BasePromptBuilder):
+                return prompt_builder.build_prompt(include_json_instructions=False)

-
-        prompt_builder_name = run.output.source.properties.get(
-            "prompt_builder_name", None
-        )
-        if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
-            prompt_builder_class = prompt_builder_registry.get(
-                prompt_builder_name, None
-            )
-            if prompt_builder_class is None:
-                raise ValueError(f"No prompt builder found for name: {prompt_builder_name}")
-            prompt_builder = prompt_builder_class(task=task)
-            if not isinstance(prompt_builder, BasePromptBuilder):
-                raise ValueError(
-                    f"Prompt builder {prompt_builder_name} is not a valid prompt builder"
-                )
-            return prompt_builder.build_prompt(include_json_instructions=False)
+        raise ValueError(f"Prompt builder '{prompt_id}' is not a valid prompt builder")

     @classmethod
     def build_repair_task_input(

kiln_ai/adapters/repair/test_repair_task.py
CHANGED

@@ -95,7 +95,7 @@ def sample_task_run(sample_task):
                     "model_name": "gpt_4o",
                     "model_provider": "openai",
                     "adapter_name": "langchain_adapter",
-                    "
+                    "prompt_id": "simple_prompt_builder",
                 },
             ),
         ),

@@ -201,7 +201,7 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
         "adapter_name": "kiln_langchain_adapter",
         "model_name": "llama_3_1_8b",
         "model_provider": "groq",
-        "
+        "prompt_id": "simple_prompt_builder",
     }


@@ -238,7 +238,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
         "adapter_name": "kiln_langchain_adapter",
         "model_name": "llama_3_1_8b",
         "model_provider": "ollama",
-        "
+        "prompt_id": "simple_prompt_builder",
     }
     assert run.input_source.type == DataSourceType.human
     assert "created_by" in run.input_source.properties
kiln_ai/adapters/run_output.py
CHANGED

@@ -1,8 +1,11 @@
 from dataclasses import dataclass
 from typing import Dict

+from openai.types.chat.chat_completion import ChoiceLogprobs
+

 @dataclass
 class RunOutput:
     output: Dict | str
     intermediate_outputs: Dict[str, str] | None
+    output_logprobs: ChoiceLogprobs | None = None
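A small usage sketch for the extended dataclass; the values are illustrative, and output_logprobs would normally be populated from an OpenAI-style chat completion response when logprobs are requested:

from kiln_ai.adapters.run_output import RunOutput

# Existing call sites keep working: the new field defaults to None.
plain = RunOutput(output="64", intermediate_outputs=None)
assert plain.output_logprobs is None

# Adapters that request logprobs can attach the provider's ChoiceLogprobs object.
with_cot = RunOutput(
    output={"answer": 64},
    intermediate_outputs={"chain_of_thought": "8 * 8 = 64"},
    output_logprobs=None,  # or a ChoiceLogprobs instance from the API response
)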
kiln_ai/adapters/test_adapter_registry.py
CHANGED

@@ -5,6 +5,7 @@ import pytest
 from kiln_ai import datamodel
 from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.ml_model_list import ModelProviderName
+from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
 from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
 from kiln_ai.adapters.model_adapters.openai_model_adapter import OpenAICompatibleAdapter
 from kiln_ai.adapters.prompt_builders import BasePromptBuilder

@@ -84,24 +85,19 @@ def test_langchain_adapter_creation(mock_config, basic_task, provider):
     )

     assert isinstance(adapter, LangchainAdapter)
-    assert adapter.model_name == "test-model"
+    assert adapter.run_config.model_name == "test-model"


 # TODO should run for all cases
 def test_custom_prompt_builder(mock_config, basic_task):
-    class TestPromptBuilder(BasePromptBuilder):
-        def build_base_prompt(self, kiln_task) -> str:
-            return "test-prompt"
-
-    prompt_builder = TestPromptBuilder(basic_task)
     adapter = adapter_for_task(
         kiln_task=basic_task,
         model_name="gpt-4",
         provider=ModelProviderName.openai,
-
+        prompt_id="simple_chain_of_thought_prompt_builder",
     )

-    assert adapter.
+    assert adapter.run_config.prompt_id == "simple_chain_of_thought_prompt_builder"


 # TODO should run for all cases

@@ -111,10 +107,12 @@ def test_tags_passed_through(mock_config, basic_task):
         kiln_task=basic_task,
         model_name="gpt-4",
         provider=ModelProviderName.openai,
-
+        base_adapter_config=AdapterConfig(
+            default_tags=tags,
+        ),
     )

-    assert adapter.default_tags == tags
+    assert adapter.base_adapter_config.default_tags == tags


 def test_invalid_provider(mock_config, basic_task):

@@ -129,6 +127,7 @@ def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_ta
     mock_compatible_config.return_value.model_name = "test-model"
     mock_compatible_config.return_value.api_key = "test-key"
     mock_compatible_config.return_value.base_url = "https://test.com/v1"
+    mock_compatible_config.return_value.provider_name = "CustomProvider99"

     adapter = adapter_for_task(
         kiln_task=basic_task,

@@ -141,6 +140,7 @@ def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_ta
     assert adapter.config.model_name == "test-model"
     assert adapter.config.api_key == "test-key"
     assert adapter.config.base_url == "https://test.com/v1"
+    assert adapter.config.provider_name == "CustomProvider99"


 def test_custom_openai_compatible_provider(mock_config, basic_task):
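The tag test above also shows where per-invocation defaults now live: they are wrapped in an AdapterConfig rather than passed loose. A sketch assuming a `basic_task` fixture like the one used in these tests; the tag values are arbitrary:

from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig

adapter = adapter_for_task(
    kiln_task=basic_task,
    model_name="gpt-4",
    provider=ModelProviderName.openai,
    base_adapter_config=AdapterConfig(default_tags=["synthetic", "eval"]),
)
# Default tags are carried on the adapter and applied to the runs it saves.
assert adapter.base_adapter_config.default_tags == ["synthetic", "eval"]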
kiln_ai/adapters/test_generate_docs.py
CHANGED

@@ -1,13 +1,13 @@
+import logging
 from typing import List

 import pytest

-from libs.core.kiln_ai.adapters.ml_model_list import (
-    KilnModelProvider,
-    built_in_models,
-)
+from libs.core.kiln_ai.adapters.ml_model_list import KilnModelProvider, built_in_models
 from libs.core.kiln_ai.adapters.provider_tools import provider_name_from_id

+logger = logging.getLogger(__name__)
+

 def _all_providers_support(providers: List[KilnModelProvider], attribute: str) -> bool:
     """Check if all providers support a given feature"""

@@ -58,8 +58,8 @@ def test_generate_model_table():
         table.append(row)

     # Print the table (useful for documentation)
-
-
+    logger.info("\nModel Capability Matrix:\n")
+    logger.info("\n".join(table))

     # Basic assertions to ensure the table is well-formed
     assert len(table) > 2, "Table should have header and at least one row"
kiln_ai/adapters/test_ollama_tools.py
CHANGED

@@ -10,7 +10,6 @@ from kiln_ai.adapters.ollama_tools import (
 def test_parse_ollama_tags_no_models():
     json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"},{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
     tags = json.loads(json_response)
-    print(json.dumps(tags, indent=2))
     conn = parse_ollama_tags(tags)
     assert "phi3.5:latest" in conn.supported_models
     assert "gemma2:2b" in conn.supported_models
kiln_ai/adapters/test_prompt_adaptors.py
CHANGED

@@ -13,6 +13,7 @@ from kiln_ai.adapters.prompt_builders import (
     BasePromptBuilder,
     SimpleChainOfThoughtPromptBuilder,
 )
+from kiln_ai.datamodel import PromptId


 def get_all_models_and_providers():

@@ -132,7 +133,7 @@ async def test_mock_returning_run(tmp_path):
         "adapter_name": "kiln_langchain_adapter",
         "model_name": "custom.langchain:unknown_model",
         "model_provider": "ollama",
-        "
+        "prompt_id": "simple_prompt_builder",
     }


@@ -149,8 +150,9 @@ async def test_all_models_providers_plaintext(tmp_path, model_name, provider_nam
 @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
 async def test_cot_prompt_builder(tmp_path, model_name, provider_name):
     task = build_test_task(tmp_path)
-
-
+    await run_simple_task(
+        task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+    )


 def build_test_task(tmp_path: Path):

@@ -186,20 +188,20 @@ async def run_simple_test(
     tmp_path: Path,
     model_name: str,
     provider: str | None = None,
-
+    prompt_id: PromptId | None = None,
 ):
     task = build_test_task(tmp_path)
-    return await run_simple_task(task, model_name, provider,
+    return await run_simple_task(task, model_name, provider, prompt_id)


 async def run_simple_task(
     task: datamodel.Task,
     model_name: str,
     provider: str,
-
+    prompt_id: PromptId | None = None,
 ) -> datamodel.TaskRun:
     adapter = adapter_for_task(
-        task, model_name=model_name, provider=provider,
+        task, model_name=model_name, provider=provider, prompt_id=prompt_id
     )

     run = await adapter.invoke(

@@ -212,13 +214,14 @@ async def run_simple_task(
     )
     assert "64" in run.output.output
     source_props = run.output.source.properties
-    assert source_props["adapter_name"]
+    assert source_props["adapter_name"] in [
+        "kiln_langchain_adapter",
+        "kiln_openai_compatible_adapter",
+    ]
     assert source_props["model_name"] == model_name
     assert source_props["model_provider"] == provider
-
-
-
-
-    )
-    assert source_props["prompt_builder_name"] == expected_prompt_builder_name
+    if prompt_id is None:
+        assert source_props["prompt_id"] == "simple_prompt_builder"
+    else:
+        assert source_props["prompt_id"] == prompt_id
     return run