kiln-ai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +11 -1
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/__init__.py +11 -0
- kiln_ai/adapters/data_gen/data_gen_task.py +69 -1
- kiln_ai/adapters/data_gen/test_data_gen_task.py +30 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +218 -304
- kiln_ai/adapters/ollama_tools.py +114 -0
- kiln_ai/adapters/provider_tools.py +295 -0
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +46 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +312 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +235 -22
- kiln_ai/datamodel/basemodel.py +30 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +29 -1
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_models.py +91 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +9 -0
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_generator.py +47 -0
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
- kiln_ai-0.7.0.dist-info/RECORD +56 -0
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.0.dist-info/RECORD +0 -36
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/__init__.py
CHANGED
@@ -12,7 +12,15 @@ The prompt_builders submodule contains classes that build prompts for use with t
 The repair submodule contains an adapter for the repair task.
 """

-from . import
+from . import (
+    base_adapter,
+    data_gen,
+    fine_tune,
+    langchain_adapters,
+    ml_model_list,
+    prompt_builders,
+    repair,
+)

 __all__ = [
     "base_adapter",
@@ -20,4 +28,6 @@ __all__ = [
     "ml_model_list",
     "prompt_builders",
     "repair",
+    "data_gen",
+    "fine_tune",
 ]
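The practical effect of this change is that the new submodules are reachable from the package root. A minimal sketch (import paths taken from the diff above; the comments are illustrative):

from kiln_ai import adapters

# "data_gen" and "fine_tune" are new in 0.7.0
print(adapters.__all__)
synthetic = adapters.data_gen   # synthetic data generation tasks
tuning = adapters.fine_tune     # fine-tuning adapters and dataset formatters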
kiln_ai/adapters/adapter_registry.py
ADDED
@@ -0,0 +1,19 @@
+from kiln_ai import datamodel
+from kiln_ai.adapters.base_adapter import BaseAdapter
+from kiln_ai.adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.prompt_builders import BasePromptBuilder
+
+
+def adapter_for_task(
+    kiln_task: datamodel.Task,
+    model_name: str | None = None,
+    provider: str | None = None,
+    prompt_builder: BasePromptBuilder | None = None,
+) -> BaseAdapter:
+    # We use langchain for everything right now, but can add any others here
+    return LangchainAdapter(
+        kiln_task,
+        model_name=model_name,
+        provider=provider,
+        prompt_builder=prompt_builder,
+    )
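A minimal usage sketch of the new entry point. adapter_for_task is taken verbatim from the diff above; the Task field values are invented for illustration, and the model/provider strings assume names defined in ml_model_list:

from kiln_ai import datamodel
from kiln_ai.adapters.adapter_registry import adapter_for_task

# Hypothetical task; field values are illustrative only
project = datamodel.Project(name="Demo")
task = datamodel.Task(
    name="Summarize",
    parent=project,
    description="Summarize the input text",
    instruction="Summarize the input text",
    requirements=[],
)

# Callers no longer construct LangchainAdapter directly; the registry
# hides the concrete adapter choice behind adapter_for_task.
adapter = adapter_for_task(task, model_name="llama_3_1_8b", provider="ollama")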
kiln_ai/adapters/data_gen/__init__.py
ADDED
@@ -0,0 +1,11 @@
+"""
+# Data Generation
+
+A task to generate synthetic data for Kiln Tasks. This generates the inputs, which then can be run through the task.
+
+Optional human guidance can be provided to guide the generation process.
+"""
+
+from . import data_gen_task
+
+__all__ = ["data_gen_task"]
kiln_ai/adapters/data_gen/data_gen_task.py
CHANGED
@@ -1,8 +1,9 @@
 import json

+from pydantic import BaseModel
+
 from kiln_ai.adapters.prompt_builders import SimplePromptBuilder
 from kiln_ai.datamodel import Project, Task
-from pydantic import BaseModel

 from .data_gen_prompts import (
     SAMPLE_GENERATION_PROMPT,
@@ -11,6 +12,16 @@ from .data_gen_prompts import (


 class DataGenCategoriesTaskInput(BaseModel):
+    """Input model for generating categories/subtopics.
+
+    Attributes:
+        node_path: List of strings representing the hierarchical path to current node
+        system_prompt: System prompt to guide the AI generation
+        num_subtopics: Number of subtopics to generate
+        human_guidance: Optional human guidance to influence generation
+        existing_topics: Optional list of existing topics to avoid duplication
+    """
+
     node_path: list[str]
     system_prompt: str
     num_subtopics: int
@@ -26,6 +37,18 @@ class DataGenCategoriesTaskInput(BaseModel):
         human_guidance: str | None = None,
         existing_topics: list[str] | None = None,
     ) -> "DataGenCategoriesTaskInput":
+        """Create a DataGenCategoriesTaskInput instance from a Task.
+
+        Args:
+            task: The source Task object
+            node_path: Path to current node in topic hierarchy
+            num_subtopics: Number of subtopics to generate
+            human_guidance: Optional guidance for generation
+            existing_topics: Optional list of existing topics
+
+        Returns:
+            A new DataGenCategoriesTaskInput instance
+        """
         prompt_builder = SimplePromptBuilder(task=task)
         return cls(
             node_path=node_path,
@@ -37,10 +60,22 @@ class DataGenCategoriesTaskInput(BaseModel):


 class DataGenCategoriesTaskOutput(BaseModel):
+    """Output model for generated categories/subtopics.
+
+    Attributes:
+        subtopics: List of generated subtopic strings
+    """
+
     subtopics: list[str]


 class DataGenCategoriesTask(Task, parent_of={}):
+    """Task for generating hierarchical categories/subtopics.
+
+    Generates synthetic data categories which can be used to generate
+    training data for model learning.
+    """
+
     def __init__(self):
         # Keep the typechecker happy. TODO: shouldn't need this or parent_of above.
         tmp_project = Project(name="DataGen")
@@ -59,6 +94,15 @@ class DataGenCategoriesTask(Task, parent_of={}):


 class DataGenSampleTaskInput(BaseModel):
+    """Input model for generating data samples for a kiln task.
+
+    Attributes:
+        topic: List of strings representing the topic path
+        system_prompt: System prompt to guide the AI generation
+        num_samples: Number of samples to generate
+        human_guidance: Optional human guidance to influence generation
+    """
+
     topic: list[str]
     system_prompt: str
     num_samples: int
@@ -72,6 +116,17 @@ class DataGenSampleTaskInput(BaseModel):
         num_samples: int = 8,
         human_guidance: str | None = None,
     ) -> "DataGenSampleTaskInput":
+        """Create a DataGenSampleTaskInput instance from a Task.
+
+        Args:
+            task: The source Task object
+            topic: Topic path for sample generation
+            num_samples: Number of samples to generate
+            human_guidance: Optional guidance for generation
+
+        Returns:
+            A new DataGenSampleTaskInput instance
+        """
         prompt_builder = SimplePromptBuilder(task=task)
         return cls(
             topic=topic,
@@ -82,6 +137,14 @@ class DataGenSampleTaskInput(BaseModel):


 def list_json_schema_for_task(task: Task) -> str:
+    """Generate a JSON schema for a list of task inputs (json schema)
+
+    Args:
+        task: Task object whose input schema will be used
+
+    Returns:
+        JSON string representing the schema for a list of task inputs
+    """
     if task.input_json_schema:
         items_schema = json.loads(task.input_json_schema)
     else:
@@ -104,6 +167,11 @@ def list_json_schema_for_task(task: Task) -> str:


 class DataGenSampleTask(Task, parent_of={}):
+    """Task for generating data samples for a given topic.
+
+    Generates synthetic data samples based on provided topics and subtopics.
+    """
+
     def __init__(self, target_task: Task, num_samples: int = 8):
         # Keep the typechecker happy. TODO: shouldn't need this or parent_of above.
         tmp_project = Project(name="DataGenSample")
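To see how these pieces compose, here is a short sketch mirroring the tests below: build an input from a target task, then run it through the generation task. The task field values are invented for illustration; actually executing it needs a configured model, as in the tests:

from kiln_ai.adapters.data_gen.data_gen_task import (
    DataGenCategoriesTask,
    DataGenCategoriesTaskInput,
)
from kiln_ai.datamodel import Project, Task

# Illustrative target task (field values invented for this sketch)
project = Project(name="Demo")
base_task = Task(
    name="Summarize",
    parent=project,
    description="Summarize the input text",
    instruction="Summarize the input text",
    requirements=[],
)

data_gen_task = DataGenCategoriesTask()
data_gen_input = DataGenCategoriesTaskInput.from_task(
    base_task, num_subtopics=6, human_guidance="Beginner-friendly topics only"
)
# data_gen_input is then run through data_gen_task via adapter_for_task(...),
# exactly as the tests below do; the structured result parses into
# DataGenCategoriesTaskOutput(subtopics=[...]).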
kiln_ai/adapters/data_gen/test_data_gen_task.py
CHANGED
@@ -1,6 +1,8 @@
 import json

 import pytest
+
+from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.data_gen.data_gen_task import (
     DataGenCategoriesTask,
     DataGenCategoriesTaskInput,
@@ -9,8 +11,7 @@ from kiln_ai.adapters.data_gen.data_gen_task import (
     DataGenSampleTaskInput,
     list_json_schema_for_task,
 )
-from kiln_ai.adapters.
-from kiln_ai.adapters.ml_model_list import get_model_and_provider
+from kiln_ai.adapters.provider_tools import get_model_and_provider
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import Project, Task

@@ -107,7 +108,7 @@ async def test_data_gen_all_models_providers(
     data_gen_task = DataGenCategoriesTask()
     data_gen_input = DataGenCategoriesTaskInput.from_task(base_task, num_subtopics=6)

-    adapter =
+    adapter = adapter_for_task(
         data_gen_task,
         model_name=model_name,
         provider=provider_name,
@@ -231,7 +232,7 @@ async def test_data_gen_sample_all_models_providers(
         base_task, topic=["riding horses"], num_samples=4
     )

-    adapter =
+    adapter = adapter_for_task(
         data_gen_task,
         model_name=model_name,
         provider=provider_name,
@@ -250,17 +251,25 @@ async def test_data_gen_sample_all_models_providers(
 @pytest.mark.ollama
 @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
 async def test_data_gen_sample_all_models_providers_with_structured_output(
-    tmp_path, model_name, provider_name
+    tmp_path, model_name, provider_name
 ):
-
-
-
-
-
-
-
-
-
+    project = Project(name="TestProject")
+    task = Task(
+        name="Summarize",
+        parent=project,
+        description="Explain if the username matches the tweet",
+        instruction="Explain if the username matches the tweet",
+        requirements=[],
+        input_json_schema=json.dumps(
+            {
+                "type": "object",
+                "properties": {
+                    "username": {"type": "string"},
+                    "tweet": {"type": "string"},
+                },
+                "required": ["username", "tweet"],
+            }
+        ),
     )

     _, provider = get_model_and_provider(model_name, provider_name)
@@ -268,12 +277,12 @@ async def test_data_gen_sample_all_models_providers_with_structured_output(
         # pass if the model doesn't support data gen (testing the support flag is part of this)
         return

-    data_gen_task = DataGenSampleTask(target_task=
+    data_gen_task = DataGenSampleTask(target_task=task)
     data_gen_input = DataGenSampleTaskInput.from_task(
-
+        task, topic=["Food"], num_samples=4
     )

-    adapter =
+    adapter = adapter_for_task(
         data_gen_task,
         model_name=model_name,
         provider=provider_name,
@@ -286,7 +295,7 @@ async def test_data_gen_sample_all_models_providers_with_structured_output(
     assert len(samples) == 4
     for sample in samples:
         assert isinstance(sample, dict)
-        assert "
-        assert "
-        assert isinstance(sample["
-        assert isinstance(sample["
+        assert "username" in sample
+        assert "tweet" in sample
+        assert isinstance(sample["username"], str)
+        assert isinstance(sample["tweet"], str)
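For reference, each sample the structured-output test accepts is a dict matching the task's input_json_schema; the values below are invented for illustration:

sample = {"username": "jane_doe", "tweet": "Just tried the new ramen spot!"}  # illustrative
assert "username" in sample and "tweet" in sample
assert isinstance(sample["username"], str) and isinstance(sample["tweet"], str)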
kiln_ai/adapters/fine_tune/__init__.py
ADDED
@@ -0,0 +1,14 @@
+"""
+# Fine-Tuning
+
+A set of classes for fine-tuning models.
+"""
+
+from . import base_finetune, dataset_formatter, finetune_registry, openai_finetune
+
+__all__ = [
+    "base_finetune",
+    "openai_finetune",
+    "dataset_formatter",
+    "finetune_registry",
+]
kiln_ai/adapters/fine_tune/base_finetune.py
ADDED
@@ -0,0 +1,186 @@
+from abc import ABC, abstractmethod
+from typing import Literal
+
+from pydantic import BaseModel
+
+from kiln_ai.adapters.ml_model_list import built_in_models
+from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType
+from kiln_ai.datamodel import Finetune as FinetuneModel
+from kiln_ai.utils.name_generator import generate_memorable_name
+
+
+class FineTuneStatus(BaseModel):
+    """
+    The status of a fine-tune, including a user friendly message.
+    """
+
+    status: FineTuneStatusType
+    message: str | None = None
+
+
+class FineTuneParameter(BaseModel):
+    """
+    A parameter for a fine-tune. Hyperparameters, etc.
+    """
+
+    name: str
+    type: Literal["string", "int", "float", "bool"]
+    description: str
+    optional: bool = True
+
+
+TYPE_MAP = {
+    "string": str,
+    "int": int,
+    "float": float,
+    "bool": bool,
+}
+
+
+class BaseFinetuneAdapter(ABC):
+    """
+    A base class for fine-tuning adapters.
+    """
+
+    def __init__(
+        self,
+        datamodel: FinetuneModel,
+    ):
+        self.datamodel = datamodel
+
+    @classmethod
+    async def create_and_start(
+        cls,
+        dataset: DatasetSplit,
+        provider_id: str,
+        provider_base_model_id: str,
+        train_split_name: str,
+        system_message: str,
+        parameters: dict[str, str | int | float | bool] = {},
+        name: str | None = None,
+        description: str | None = None,
+        validation_split_name: str | None = None,
+    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
+        """
+        Create and start a fine-tune.
+        """
+
+        cls.check_valid_provider_model(provider_id, provider_base_model_id)
+
+        if not dataset.id:
+            raise ValueError("Dataset must have an id")
+
+        if train_split_name not in dataset.split_contents:
+            raise ValueError(f"Train split {train_split_name} not found in dataset")
+
+        if (
+            validation_split_name
+            and validation_split_name not in dataset.split_contents
+        ):
+            raise ValueError(
+                f"Validation split {validation_split_name} not found in dataset"
+            )
+
+        # Default name if not provided
+        if name is None:
+            name = generate_memorable_name()
+
+        cls.validate_parameters(parameters)
+        parent_task = dataset.parent_task()
+        if parent_task is None or not parent_task.path:
+            raise ValueError("Dataset must have a parent task with a path")
+
+        datamodel = FinetuneModel(
+            name=name,
+            description=description,
+            provider=provider_id,
+            base_model_id=provider_base_model_id,
+            dataset_split_id=dataset.id,
+            train_split_name=train_split_name,
+            validation_split_name=validation_split_name,
+            parameters=parameters,
+            system_message=system_message,
+            parent=parent_task,
+        )
+
+        adapter = cls(datamodel)
+        await adapter._start(dataset)
+
+        datamodel.save_to_file()
+
+        return adapter, datamodel
+
+    @abstractmethod
+    async def _start(self, dataset: DatasetSplit) -> None:
+        """
+        Start the fine-tune.
+        """
+        pass
+
+    @abstractmethod
+    async def status(self) -> FineTuneStatus:
+        """
+        Get the status of the fine-tune.
+        """
+        pass
+
+    @classmethod
+    def available_parameters(cls) -> list[FineTuneParameter]:
+        """
+        Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
+        """
+        return []
+
+    @classmethod
+    def validate_parameters(
+        cls, parameters: dict[str, str | int | float | bool]
+    ) -> None:
+        """
+        Validate the parameters for this fine-tune.
+        """
+        # Check required parameters and parameter types
+        available_parameters = cls.available_parameters()
+        for parameter in available_parameters:
+            if not parameter.optional and parameter.name not in parameters:
+                raise ValueError(f"Parameter {parameter.name} is required")
+            elif parameter.name in parameters:
+                # check parameter is correct type
+                expected_type = TYPE_MAP[parameter.type]
+                value = parameters[parameter.name]
+
+                # Strict type checking for numeric types
+                if expected_type is float and not isinstance(value, float):
+                    raise ValueError(
+                        f"Parameter {parameter.name} must be a float, got {type(value)}"
+                    )
+                elif expected_type is int and not isinstance(value, int):
+                    raise ValueError(
+                        f"Parameter {parameter.name} must be an integer, got {type(value)}"
+                    )
+                elif not isinstance(value, expected_type):
+                    raise ValueError(
+                        f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
+                    )
+
+        allowed_parameters = [p.name for p in available_parameters]
+        for parameter_key in parameters:
+            if parameter_key not in allowed_parameters:
+                raise ValueError(f"Parameter {parameter_key} is not available")
+
+    @classmethod
+    def check_valid_provider_model(
+        cls, provider_id: str, provider_base_model_id: str
+    ) -> None:
+        """
+        Check if the provider and base model are valid.
+        """
+        for model in built_in_models:
+            for provider in model.providers:
+                if (
+                    provider.name == provider_id
+                    and provider.provider_finetune_id == provider_base_model_id
+                ):
+                    return
+        raise ValueError(
+            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
+        )
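A sketch of how a concrete adapter plugs into this base class. DemoFinetune and its "epochs" parameter are invented for illustration (the real adapters in this release are the OpenAI and Fireworks ones), and "pending" is assumed to be a valid FineTuneStatusType value:

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneParameter,
    FineTuneStatus,
)
from kiln_ai.datamodel import DatasetSplit


class DemoFinetune(BaseFinetuneAdapter):  # hypothetical adapter
    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="epochs",
                type="int",
                description="Number of training epochs",
                optional=True,
            )
        ]

    async def _start(self, dataset: DatasetSplit) -> None:
        pass  # a real adapter would submit the job to its provider here

    async def status(self) -> FineTuneStatus:
        # "pending" assumed to be a valid FineTuneStatusType value
        return FineTuneStatus(status="pending", message="queued")


DemoFinetune.validate_parameters({"epochs": 3})        # passes: int as declared
# DemoFinetune.validate_parameters({"epochs": 3.5})    # raises: floats rejected for "int"
# DemoFinetune.validate_parameters({"batch_size": 8})  # raises: parameter not available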
kiln_ai/adapters/fine_tune/dataset_formatter.py
ADDED
@@ -0,0 +1,187 @@
+import json
+import tempfile
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, Protocol
+from uuid import uuid4
+
+from kiln_ai.datamodel import DatasetSplit, TaskRun
+
+
+class DatasetFormat(str, Enum):
+    """Formats for dataset generation. Both for file format (like JSONL), and internal structure (like chat/toolcall)"""
+
+    """OpenAI chat format with plaintext response"""
+    OPENAI_CHAT_JSONL = "openai_chat_jsonl"
+
+    """OpenAI chat format with tool call response"""
+    OPENAI_CHAT_TOOLCALL_JSONL = "openai_chat_toolcall_jsonl"
+
+    """HuggingFace chat template in JSONL"""
+    HUGGINGFACE_CHAT_TEMPLATE_JSONL = "huggingface_chat_template_jsonl"
+
+    """HuggingFace chat template with tool calls in JSONL"""
+    HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL = (
+        "huggingface_chat_template_toolcall_jsonl"
+    )
+
+
+class FormatGenerator(Protocol):
+    """Protocol for format generators"""
+
+    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...
+
+
+def generate_chat_message_response(
+    task_run: TaskRun, system_message: str
+) -> Dict[str, Any]:
+    """Generate OpenAI chat format with plaintext response"""
+    return {
+        "messages": [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": task_run.input},
+            {"role": "assistant", "content": task_run.output.output},
+        ]
+    }
+
+
+def generate_chat_message_toolcall(
+    task_run: TaskRun, system_message: str
+) -> Dict[str, Any]:
+    """Generate OpenAI chat format with tool call response"""
+    try:
+        arguments = json.loads(task_run.output.output)
+    except json.JSONDecodeError as e:
+        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
+
+    return {
+        "messages": [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": task_run.input},
+            {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [
+                    {
+                        "id": "call_1",
+                        "type": "function",
+                        "function": {
+                            "name": "task_response",
+                            # Yes we parse then dump again. This ensures it's valid JSON, and ensures it goes to 1 line
+                            "arguments": json.dumps(arguments),
+                        },
+                    }
+                ],
+            },
+        ]
+    }
+
+
+def generate_huggingface_chat_template(
+    task_run: TaskRun, system_message: str
+) -> Dict[str, Any]:
+    """Generate HuggingFace chat template"""
+    return {
+        "conversations": [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": task_run.input},
+            {"role": "assistant", "content": task_run.output.output},
+        ]
+    }
+
+
+def generate_huggingface_chat_template_toolcall(
+    task_run: TaskRun, system_message: str
+) -> Dict[str, Any]:
+    """Generate HuggingFace chat template with tool calls"""
+    try:
+        arguments = json.loads(task_run.output.output)
+    except json.JSONDecodeError as e:
+        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
+
+    # See https://huggingface.co/docs/transformers/en/chat_templating
+    return {
+        "conversations": [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": task_run.input},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "task_response",
+                            "id": str(uuid4()).replace("-", "")[:9],
+                            "arguments": arguments,
+                        },
+                    }
+                ],
+            },
+        ]
+    }
+
+
+FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
+    DatasetFormat.OPENAI_CHAT_JSONL: generate_chat_message_response,
+    DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: generate_chat_message_toolcall,
+    DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: generate_huggingface_chat_template,
+    DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: generate_huggingface_chat_template_toolcall,
+}
+
+
+class DatasetFormatter:
+    """Handles formatting of datasets into various output formats"""
+
+    def __init__(self, dataset: DatasetSplit, system_message: str):
+        self.dataset = dataset
+        self.system_message = system_message
+
+        task = dataset.parent_task()
+        if task is None:
+            raise ValueError("Dataset has no parent task")
+        self.task = task
+
+    def dump_to_file(
+        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
+    ) -> Path:
+        """
+        Format the dataset into the specified format.
+
+        Args:
+            split_name: Name of the split to dump
+            format_type: Format to generate the dataset in
+            path: Optional path to write to. If None, writes to temp directory
+
+        Returns:
+            Path to the generated file
+        """
+        if format_type not in FORMAT_GENERATORS:
+            raise ValueError(f"Unsupported format: {format_type}")
+        if split_name not in self.dataset.split_contents:
+            raise ValueError(f"Split {split_name} not found in dataset")
+
+        generator = FORMAT_GENERATORS[format_type]
+
+        # Write to a temp file if no path is provided
+        output_path = (
+            path
+            or Path(tempfile.gettempdir())
+            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
+        )
+
+        runs = self.task.runs()
+        runs_by_id = {run.id: run for run in runs}
+
+        # Generate formatted output with UTF-8 encoding
+        with open(output_path, "w", encoding="utf-8") as f:
+            for run_id in self.dataset.split_contents[split_name]:
+                task_run = runs_by_id[run_id]
+                if task_run is None:
+                    raise ValueError(
+                        f"Task run {run_id} not found. This is required by this dataset."
+                    )
+
+                example = generator(task_run, self.system_message)
+                f.write(json.dumps(example) + "\n")
+
+        return output_path
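A usage sketch for the formatter, assuming `dataset` is an existing DatasetSplit whose parent task holds the referenced runs; the system message is illustrative:

from kiln_ai.adapters.fine_tune.dataset_formatter import (
    DatasetFormat,
    DatasetFormatter,
)

# dataset: an existing kiln DatasetSplit with a "train" split (assumed)
formatter = DatasetFormatter(dataset, system_message="You are a helpful assistant.")
path = formatter.dump_to_file("train", DatasetFormat.OPENAI_CHAT_JSONL)

# Each JSONL line is one example shaped by generate_chat_message_response:
# {"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]}
print(path)  # e.g. <tmpdir>/<dataset-name>_train_openai_chat_jsonl.jsonl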
kiln_ai/adapters/fine_tune/finetune_registry.py
ADDED
@@ -0,0 +1,11 @@
+from typing import Type
+
+from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter
+from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
+from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
+from kiln_ai.adapters.ml_model_list import ModelProviderName
+
+finetune_registry: dict[ModelProviderName, Type[BaseFinetuneAdapter]] = {
+    ModelProviderName.openai: OpenAIFinetune,
+    ModelProviderName.fireworks_ai: FireworksFinetune,
+}
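A dispatch sketch through the registry. `dataset` and the base model id are assumed/illustrative; create_and_start comes from the base class above:

import asyncio

from kiln_ai.adapters.fine_tune.finetune_registry import finetune_registry
from kiln_ai.adapters.ml_model_list import ModelProviderName


async def start_job(dataset):  # dataset: an existing DatasetSplit (assumed)
    adapter_class = finetune_registry[ModelProviderName.openai]  # -> OpenAIFinetune
    # Validates the provider/model pair, splits, and parameters, persists a
    # Finetune record, then starts the provider job.
    adapter, finetune_model = await adapter_class.create_and_start(
        dataset=dataset,
        provider_id="openai",
        provider_base_model_id="gpt-4o-mini-2024-07-18",  # illustrative model id
        train_split_name="train",
        system_message="You are a helpful assistant.",
    )
    return finetune_model

# asyncio.run(start_job(dataset))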