kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +239 -303
- kiln_ai/adapters/ollama_tools.py +115 -0
- kiln_ai/adapters/provider_tools.py +308 -0
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +229 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +531 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +287 -24
- kiln_ai/datamodel/basemodel.py +122 -38
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +167 -4
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +215 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +14 -1
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- kiln_ai-0.7.1.dist-info/METADATA +237 -0
- kiln_ai-0.7.1.dist-info/RECORD +58 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.1.dist-info/METADATA +0 -88
- kiln_ai-0.6.1.dist-info/RECORD +0 -37
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/__init__.py
CHANGED
|
@@ -15,6 +15,7 @@ The repair submodule contains an adapter for the repair task.
|
|
|
15
15
|
from . import (
|
|
16
16
|
base_adapter,
|
|
17
17
|
data_gen,
|
|
18
|
+
fine_tune,
|
|
18
19
|
langchain_adapters,
|
|
19
20
|
ml_model_list,
|
|
20
21
|
prompt_builders,
|
|
@@ -28,4 +29,5 @@ __all__ = [
|
|
|
28
29
|
"prompt_builders",
|
|
29
30
|
"repair",
|
|
30
31
|
"data_gen",
|
|
32
|
+
"fine_tune",
|
|
31
33
|
]
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from kiln_ai import datamodel
|
|
2
|
+
from kiln_ai.adapters.base_adapter import BaseAdapter
|
|
3
|
+
from kiln_ai.adapters.langchain_adapters import LangchainAdapter
|
|
4
|
+
from kiln_ai.adapters.prompt_builders import BasePromptBuilder
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def adapter_for_task(
    kiln_task: datamodel.Task,
    model_name: str | None = None,
    provider: str | None = None,
    prompt_builder: BasePromptBuilder | None = None,
) -> BaseAdapter:
    """
    Build the adapter used to run `kiln_task`.

    Every task is currently served by LangchainAdapter; this factory is the
    single place to branch on model/provider if other adapter implementations
    are added later.

    Args:
        kiln_task: The task the adapter will run.
        model_name: Optional model name, forwarded unchanged to the adapter.
        provider: Optional provider name, forwarded unchanged to the adapter.
        prompt_builder: Optional prompt builder, forwarded unchanged to the adapter.

    Returns:
        A LangchainAdapter constructed for `kiln_task` with the given options.
    """
    adapter_options = {
        "model_name": model_name,
        "provider": provider,
        "prompt_builder": prompt_builder,
    }
    return LangchainAdapter(kiln_task, **adapter_options)
|
|
@@ -2,6 +2,7 @@ import json
|
|
|
2
2
|
|
|
3
3
|
import pytest
|
|
4
4
|
|
|
5
|
+
from kiln_ai.adapters.adapter_registry import adapter_for_task
|
|
5
6
|
from kiln_ai.adapters.data_gen.data_gen_task import (
|
|
6
7
|
DataGenCategoriesTask,
|
|
7
8
|
DataGenCategoriesTaskInput,
|
|
@@ -10,8 +11,7 @@ from kiln_ai.adapters.data_gen.data_gen_task import (
|
|
|
10
11
|
DataGenSampleTaskInput,
|
|
11
12
|
list_json_schema_for_task,
|
|
12
13
|
)
|
|
13
|
-
from kiln_ai.adapters.
|
|
14
|
-
from kiln_ai.adapters.ml_model_list import get_model_and_provider
|
|
14
|
+
from kiln_ai.adapters.provider_tools import get_model_and_provider
|
|
15
15
|
from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
|
|
16
16
|
from kiln_ai.datamodel import Project, Task
|
|
17
17
|
|
|
@@ -108,7 +108,7 @@ async def test_data_gen_all_models_providers(
|
|
|
108
108
|
data_gen_task = DataGenCategoriesTask()
|
|
109
109
|
data_gen_input = DataGenCategoriesTaskInput.from_task(base_task, num_subtopics=6)
|
|
110
110
|
|
|
111
|
-
adapter =
|
|
111
|
+
adapter = adapter_for_task(
|
|
112
112
|
data_gen_task,
|
|
113
113
|
model_name=model_name,
|
|
114
114
|
provider=provider_name,
|
|
@@ -232,7 +232,7 @@ async def test_data_gen_sample_all_models_providers(
|
|
|
232
232
|
base_task, topic=["riding horses"], num_samples=4
|
|
233
233
|
)
|
|
234
234
|
|
|
235
|
-
adapter =
|
|
235
|
+
adapter = adapter_for_task(
|
|
236
236
|
data_gen_task,
|
|
237
237
|
model_name=model_name,
|
|
238
238
|
provider=provider_name,
|
|
@@ -251,17 +251,25 @@ async def test_data_gen_sample_all_models_providers(
|
|
|
251
251
|
@pytest.mark.ollama
|
|
252
252
|
@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
|
|
253
253
|
async def test_data_gen_sample_all_models_providers_with_structured_output(
|
|
254
|
-
tmp_path, model_name, provider_name
|
|
254
|
+
tmp_path, model_name, provider_name
|
|
255
255
|
):
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
256
|
+
project = Project(name="TestProject")
|
|
257
|
+
task = Task(
|
|
258
|
+
name="Summarize",
|
|
259
|
+
parent=project,
|
|
260
|
+
description="Explain if the username matches the tweet",
|
|
261
|
+
instruction="Explain if the username matches the tweet",
|
|
262
|
+
requirements=[],
|
|
263
|
+
input_json_schema=json.dumps(
|
|
264
|
+
{
|
|
265
|
+
"type": "object",
|
|
266
|
+
"properties": {
|
|
267
|
+
"username": {"type": "string"},
|
|
268
|
+
"tweet": {"type": "string"},
|
|
269
|
+
},
|
|
270
|
+
"required": ["username", "tweet"],
|
|
271
|
+
}
|
|
272
|
+
),
|
|
265
273
|
)
|
|
266
274
|
|
|
267
275
|
_, provider = get_model_and_provider(model_name, provider_name)
|
|
@@ -269,12 +277,12 @@ async def test_data_gen_sample_all_models_providers_with_structured_output(
|
|
|
269
277
|
# pass if the model doesn't support data gen (testing the support flag is part of this)
|
|
270
278
|
return
|
|
271
279
|
|
|
272
|
-
data_gen_task = DataGenSampleTask(target_task=
|
|
280
|
+
data_gen_task = DataGenSampleTask(target_task=task)
|
|
273
281
|
data_gen_input = DataGenSampleTaskInput.from_task(
|
|
274
|
-
|
|
282
|
+
task, topic=["Food"], num_samples=4
|
|
275
283
|
)
|
|
276
284
|
|
|
277
|
-
adapter =
|
|
285
|
+
adapter = adapter_for_task(
|
|
278
286
|
data_gen_task,
|
|
279
287
|
model_name=model_name,
|
|
280
288
|
provider=provider_name,
|
|
@@ -287,7 +295,7 @@ async def test_data_gen_sample_all_models_providers_with_structured_output(
|
|
|
287
295
|
assert len(samples) == 4
|
|
288
296
|
for sample in samples:
|
|
289
297
|
assert isinstance(sample, dict)
|
|
290
|
-
assert "
|
|
291
|
-
assert "
|
|
292
|
-
assert isinstance(sample["
|
|
293
|
-
assert isinstance(sample["
|
|
298
|
+
assert "username" in sample
|
|
299
|
+
assert "tweet" in sample
|
|
300
|
+
assert isinstance(sample["username"], str)
|
|
301
|
+
assert isinstance(sample["tweet"], str)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""
|
|
2
|
+
# Fine-Tuning
|
|
3
|
+
|
|
4
|
+
A set of classes for fine-tuning models.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
# Import every submodule so it is reachable as an attribute of the package.
# fireworks_finetune is listed alongside openai_finetune: both are provider
# adapters, and finetune_registry already imports fireworks_finetune, so this
# adds no new import-time work.
from . import (
    base_finetune,
    dataset_formatter,
    finetune_registry,
    fireworks_finetune,
    openai_finetune,
)

__all__ = [
    "base_finetune",
    "openai_finetune",
    "fireworks_finetune",
    "dataset_formatter",
    "finetune_registry",
]
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from kiln_ai.adapters.ml_model_list import built_in_models
|
|
7
|
+
from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType
|
|
8
|
+
from kiln_ai.datamodel import Finetune as FinetuneModel
|
|
9
|
+
from kiln_ai.utils.name_generator import generate_memorable_name
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class FineTuneStatus(BaseModel):
    """
    The status of a fine-tune, including a user friendly message.

    Returned by BaseFinetuneAdapter.status().
    """

    # Current state of the fine-tune job, as a FineTuneStatusType value.
    status: FineTuneStatusType
    # Optional human-readable detail to accompany the status; defaults to None.
    message: str | None = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FineTuneParameter(BaseModel):
    """
    A parameter for a fine-tune. Hyperparameters, etc.

    Adapters declare the parameters they accept by returning a list of these
    from BaseFinetuneAdapter.available_parameters(); user-supplied values are
    then checked against that list by BaseFinetuneAdapter.validate_parameters().
    """

    # Key under which the value appears in the parameters dict.
    name: str
    # Expected value type; mapped to a Python type through TYPE_MAP during validation.
    type: Literal["string", "int", "float", "bool"]
    # Human-readable explanation of what the parameter controls.
    description: str
    # When False, validation raises ValueError if the parameter is not supplied.
    optional: bool = True
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Translates each FineTuneParameter.type literal into the Python type that a
# supplied value must be an instance of.
TYPE_MAP = dict(
    string=str,
    int=int,
    float=float,
    bool=bool,
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class BaseFinetuneAdapter(ABC):
    """
    A base class for fine-tuning adapters.

    Subclasses implement `_start` (launch the provider-side job) and `status`
    (report its progress), and may override `available_parameters` to declare
    the hyperparameters they accept.
    """

    def __init__(
        self,
        datamodel: FinetuneModel,
    ):
        # The Finetune record this adapter operates on.
        self.datamodel = datamodel

    @classmethod
    async def create_and_start(
        cls,
        dataset: DatasetSplit,
        provider_id: str,
        provider_base_model_id: str,
        train_split_name: str,
        system_message: str,
        parameters: dict[str, str | int | float | bool] | None = None,
        name: str | None = None,
        description: str | None = None,
        validation_split_name: str | None = None,
    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
        """
        Create and start a fine-tune.

        Args:
            dataset: The dataset split to train on. Must have an id and a parent
                task with a path.
            provider_id: Provider name; together with provider_base_model_id it
                must match a built-in model provider's fine-tune id.
            provider_base_model_id: The provider's identifier for the base model.
            train_split_name: Key in dataset.split_contents used for training.
            system_message: System message stored on the Finetune record.
            parameters: Hyperparameters, validated against available_parameters().
                Defaults to no parameters.
            name: Name for the fine-tune; a memorable name is generated if None.
            description: Optional description stored on the Finetune record.
            validation_split_name: Optional key in dataset.split_contents used
                for validation.

        Returns:
            The adapter instance and the Finetune record. The record is saved via
            save_to_file() only after `_start` completes without raising.

        Raises:
            ValueError: If the provider/model pair is unknown, the dataset has no
                id, a split name is not in the dataset, a parameter is invalid,
                or the dataset has no parent task with a path.
        """
        # A None sentinel avoids sharing one mutable default dict across calls.
        if parameters is None:
            parameters = {}

        cls.check_valid_provider_model(provider_id, provider_base_model_id)

        if not dataset.id:
            raise ValueError("Dataset must have an id")

        if train_split_name not in dataset.split_contents:
            raise ValueError(f"Train split {train_split_name} not found in dataset")

        if (
            validation_split_name
            and validation_split_name not in dataset.split_contents
        ):
            raise ValueError(
                f"Validation split {validation_split_name} not found in dataset"
            )

        # Default name if not provided
        if name is None:
            name = generate_memorable_name()

        cls.validate_parameters(parameters)
        parent_task = dataset.parent_task()
        if parent_task is None or not parent_task.path:
            raise ValueError("Dataset must have a parent task with a path")

        datamodel = FinetuneModel(
            name=name,
            description=description,
            provider=provider_id,
            base_model_id=provider_base_model_id,
            dataset_split_id=dataset.id,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
            parameters=parameters,
            system_message=system_message,
            parent=parent_task,
        )

        adapter = cls(datamodel)
        await adapter._start(dataset)

        # Persist only after the provider job was started without error.
        datamodel.save_to_file()

        return adapter, datamodel

    @abstractmethod
    async def _start(self, dataset: DatasetSplit) -> None:
        """
        Start the fine-tune.
        """
        pass

    @abstractmethod
    async def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
        """
        return []

    @classmethod
    def validate_parameters(
        cls, parameters: dict[str, str | int | float | bool]
    ) -> None:
        """
        Validate the parameters for this fine-tune.

        Checks that every non-optional parameter is present, that each supplied
        value has its declared type, and that no undeclared parameters are given.
        Numeric checks are strict: a float parameter rejects ints, and an int
        parameter rejects bools (even though bool subclasses int in Python).

        Raises:
            ValueError: On a missing required parameter, a wrongly typed value,
                or an unknown parameter name.
        """
        available_parameters = cls.available_parameters()
        for parameter in available_parameters:
            if parameter.name not in parameters:
                if not parameter.optional:
                    raise ValueError(f"Parameter {parameter.name} is required")
                continue

            expected_type = TYPE_MAP[parameter.type]
            value = parameters[parameter.name]

            # Strict type checking for numeric types
            if expected_type is float and not isinstance(value, float):
                raise ValueError(
                    f"Parameter {parameter.name} must be a float, got {type(value)}"
                )
            elif expected_type is int and (
                # isinstance(True, int) is True, so exclude bool explicitly.
                isinstance(value, bool) or not isinstance(value, int)
            ):
                raise ValueError(
                    f"Parameter {parameter.name} must be an integer, got {type(value)}"
                )
            elif not isinstance(value, expected_type):
                raise ValueError(
                    f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
                )

        allowed_parameters = {p.name for p in available_parameters}
        for parameter_key in parameters:
            if parameter_key not in allowed_parameters:
                raise ValueError(f"Parameter {parameter_key} is not available")

    @classmethod
    def check_valid_provider_model(
        cls, provider_id: str, provider_base_model_id: str
    ) -> None:
        """
        Check if the provider and base model are valid.

        Valid means some built-in model has a provider whose name equals
        provider_id and whose provider_finetune_id equals provider_base_model_id.

        Raises:
            ValueError: If no such built-in model provider exists.
        """
        for model in built_in_models:
            for provider in model.providers:
                if (
                    provider.name == provider_id
                    and provider.provider_finetune_id == provider_base_model_id
                ):
                    return
        raise ValueError(
            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
        )
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import tempfile
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict, Protocol
|
|
6
|
+
from uuid import uuid4
|
|
7
|
+
|
|
8
|
+
from kiln_ai.datamodel import DatasetSplit, TaskRun
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DatasetFormat(str, Enum):
    """Formats for dataset generation: both the file format (like JSONL) and the internal structure (like chat/toolcall)."""

    # OpenAI chat format with plaintext response
    OPENAI_CHAT_JSONL = "openai_chat_jsonl"

    # OpenAI chat format with tool call response
    OPENAI_CHAT_TOOLCALL_JSONL = "openai_chat_toolcall_jsonl"

    # HuggingFace chat template in JSONL
    HUGGINGFACE_CHAT_TEMPLATE_JSONL = "huggingface_chat_template_jsonl"

    # HuggingFace chat template with tool calls in JSONL
    HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL = (
        "huggingface_chat_template_toolcall_jsonl"
    )
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class FormatGenerator(Protocol):
    """Protocol for format generators.

    A format generator turns one TaskRun plus the fine-tune system message into
    a single JSON-serializable training example; DatasetFormatter.dump_to_file
    writes each example as one line of the output JSONL file.
    """

    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def generate_chat_message_response(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with plaintext response.

    The example is a three-turn conversation under the "messages" key: the
    system message, the run's input as the user turn, and the run's output
    string verbatim as the assistant turn.
    """
    system_turn = {"role": "system", "content": system_message}
    user_turn = {"role": "user", "content": task_run.input}
    assistant_turn = {"role": "assistant", "content": task_run.output.output}
    return {"messages": [system_turn, user_turn, assistant_turn]}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def generate_chat_message_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with tool call response.

    The run's output must be JSON. It becomes the arguments of a single
    "task_response" function call in the assistant turn, whose content is None.

    Raises:
        ValueError: If the run's output is not valid JSON.
    """
    raw_output = task_run.output.output
    try:
        parsed_arguments = json.loads(raw_output)
    except json.JSONDecodeError as err:
        raise ValueError(f"Invalid JSON in for tool call: {err}") from err

    tool_call = {
        "id": "call_1",
        "type": "function",
        "function": {
            "name": "task_response",
            # Parse-then-dump validates the JSON and collapses it onto one line.
            "arguments": json.dumps(parsed_arguments),
        },
    }
    assistant_turn = {"role": "assistant", "content": None, "tool_calls": [tool_call]}

    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            assistant_turn,
        ]
    }
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def generate_huggingface_chat_template(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate HuggingFace chat template.

    Same three turns as the OpenAI plaintext format (system, user input,
    assistant output), but stored under the "conversations" key.
    """
    turns = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": task_run.input},
        {"role": "assistant", "content": task_run.output.output},
    ]
    return {"conversations": turns}
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def generate_huggingface_chat_template_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate HuggingFace chat template with tool calls.

    The run's output must be JSON. Unlike the OpenAI variant, the parsed value
    is embedded directly as the function "arguments" rather than re-serialized.
    Each call receives a random 9-hex-character id.

    Raises:
        ValueError: If the run's output is not valid JSON.
    """
    raw_output = task_run.output.output
    try:
        parsed_arguments = json.loads(raw_output)
    except json.JSONDecodeError as err:
        raise ValueError(f"Invalid JSON in for tool call: {err}") from err

    # See https://huggingface.co/docs/transformers/en/chat_templating
    function_call = {
        "name": "task_response",
        # UUID.hex is the dash-free lowercase form of str(uuid).
        "id": uuid4().hex[:9],
        "arguments": parsed_arguments,
    }
    assistant_turn = {
        "role": "assistant",
        "tool_calls": [{"type": "function", "function": function_call}],
    }

    return {
        "conversations": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            assistant_turn,
        ]
    }
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# Dispatch table: which generator renders a single TaskRun for each supported
# DatasetFormat. DatasetFormatter.dump_to_file rejects formats missing here.
FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = dict(
    [
        (DatasetFormat.OPENAI_CHAT_JSONL, generate_chat_message_response),
        (DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL, generate_chat_message_toolcall),
        (
            DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL,
            generate_huggingface_chat_template,
        ),
        (
            DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL,
            generate_huggingface_chat_template_toolcall,
        ),
    ]
)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class DatasetFormatter:
    """Handles formatting of datasets into various output formats"""

    def __init__(self, dataset: DatasetSplit, system_message: str):
        """
        Args:
            dataset: The dataset split to export. Must have a parent task.
            system_message: System message passed to the generator for every example.

        Raises:
            ValueError: If the dataset has no parent task.
        """
        self.dataset = dataset
        self.system_message = system_message

        task = dataset.parent_task()
        if task is None:
            raise ValueError("Dataset has no parent task")
        self.task = task

    def dump_to_file(
        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
    ) -> Path:
        """
        Format the dataset into the specified format.

        Args:
            split_name: Name of the split to dump
            format_type: Format to generate the dataset in
            path: Optional path to write to. If None, writes to temp directory

        Returns:
            Path to the generated file

        Raises:
            ValueError: If the format is unsupported, the split is not in the
                dataset, a run referenced by the split no longer exists on the
                task, or the generator rejects a run (the tool-call generators
                raise on non-JSON output).
        """
        if format_type not in FORMAT_GENERATORS:
            raise ValueError(f"Unsupported format: {format_type}")
        if split_name not in self.dataset.split_contents:
            raise ValueError(f"Split {split_name} not found in dataset")

        generator = FORMAT_GENERATORS[format_type]

        # Write to a temp file if no path is provided
        output_path = (
            path
            or Path(tempfile.gettempdir())
            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
        )

        runs_by_id = {run.id: run for run in self.task.runs()}

        # Resolve every run before opening the file, so a missing run raises the
        # descriptive ValueError (not a bare KeyError from indexing) and never
        # truncates or leaves a partial output file behind.
        task_runs = []
        for run_id in self.dataset.split_contents[split_name]:
            task_run = runs_by_id.get(run_id)
            if task_run is None:
                raise ValueError(
                    f"Task run {run_id} not found. This is required by this dataset."
                )
            task_runs.append(task_run)

        # Generate formatted output with UTF-8 encoding, one JSON object per line
        with open(output_path, "w", encoding="utf-8") as f:
            for task_run in task_runs:
                example = generator(task_run, self.system_message)
                f.write(json.dumps(example) + "\n")

        return output_path
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from typing import Type
|
|
2
|
+
|
|
3
|
+
from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter
|
|
4
|
+
from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
|
|
5
|
+
from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
|
|
6
|
+
from kiln_ai.adapters.ml_model_list import ModelProviderName
|
|
7
|
+
|
|
8
|
+
# Which adapter class handles fine-tuning for each provider that supports it.
# Providers without an entry here have no fine-tuning support.
finetune_registry: dict[ModelProviderName, Type[BaseFinetuneAdapter]] = dict(
    [
        (ModelProviderName.openai, OpenAIFinetune),
        (ModelProviderName.fireworks_ai, FireworksFinetune),
    ]
)
|