kiln-ai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (42)
  1. kiln_ai/adapters/__init__.py +11 -1
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/__init__.py +11 -0
  4. kiln_ai/adapters/data_gen/data_gen_task.py +69 -1
  5. kiln_ai/adapters/data_gen/test_data_gen_task.py +30 -21
  6. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  7. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  8. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  9. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  10. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  11. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  12. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  13. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  14. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
  15. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  16. kiln_ai/adapters/langchain_adapters.py +103 -13
  17. kiln_ai/adapters/ml_model_list.py +218 -304
  18. kiln_ai/adapters/ollama_tools.py +114 -0
  19. kiln_ai/adapters/provider_tools.py +295 -0
  20. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  21. kiln_ai/adapters/test_langchain_adapter.py +46 -18
  22. kiln_ai/adapters/test_ollama_tools.py +42 -0
  23. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  24. kiln_ai/adapters/test_provider_tools.py +312 -0
  25. kiln_ai/adapters/test_structured_output.py +22 -43
  26. kiln_ai/datamodel/__init__.py +235 -22
  27. kiln_ai/datamodel/basemodel.py +30 -0
  28. kiln_ai/datamodel/registry.py +31 -0
  29. kiln_ai/datamodel/test_basemodel.py +29 -1
  30. kiln_ai/datamodel/test_dataset_split.py +234 -0
  31. kiln_ai/datamodel/test_example_models.py +12 -0
  32. kiln_ai/datamodel/test_models.py +91 -1
  33. kiln_ai/datamodel/test_registry.py +96 -0
  34. kiln_ai/utils/config.py +9 -0
  35. kiln_ai/utils/name_generator.py +125 -0
  36. kiln_ai/utils/test_name_geneator.py +47 -0
  37. {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
  38. kiln_ai-0.7.0.dist-info/RECORD +56 -0
  39. kiln_ai/adapters/test_ml_model_list.py +0 -181
  40. kiln_ai-0.6.0.dist-info/RECORD +0 -36
  41. {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
  42. {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -12,7 +12,15 @@ The prompt_builders submodule contains classes that build prompts for use with t
12
12
  The repair submodule contains an adapter for the repair task.
13
13
  """
14
14
 
15
- from . import base_adapter, langchain_adapters, ml_model_list, prompt_builders, repair
15
+ from . import (
16
+ base_adapter,
17
+ data_gen,
18
+ fine_tune,
19
+ langchain_adapters,
20
+ ml_model_list,
21
+ prompt_builders,
22
+ repair,
23
+ )
16
24
 
17
25
  __all__ = [
18
26
  "base_adapter",
@@ -20,4 +28,6 @@ __all__ = [
20
28
  "ml_model_list",
21
29
  "prompt_builders",
22
30
  "repair",
31
+ "data_gen",
32
+ "fine_tune",
23
33
  ]
@@ -0,0 +1,19 @@
1
+ from kiln_ai import datamodel
2
+ from kiln_ai.adapters.base_adapter import BaseAdapter
3
+ from kiln_ai.adapters.langchain_adapters import LangchainAdapter
4
+ from kiln_ai.adapters.prompt_builders import BasePromptBuilder
5
+
6
+
7
def adapter_for_task(
    kiln_task: datamodel.Task,
    model_name: str | None = None,
    provider: str | None = None,
    prompt_builder: BasePromptBuilder | None = None,
) -> BaseAdapter:
    """Build the adapter used to run `kiln_task`.

    Currently this always returns a LangchainAdapter; it is the single place
    where dispatch to other adapter implementations can be added.

    Args:
        kiln_task: The task to run.
        model_name: Optional model name, forwarded unchanged to the adapter.
        provider: Optional provider name, forwarded unchanged to the adapter.
        prompt_builder: Optional prompt builder, forwarded unchanged to the adapter.

    Returns:
        A LangchainAdapter wrapping the task.
    """
    adapter: BaseAdapter = LangchainAdapter(
        kiln_task,
        model_name=model_name,
        provider=provider,
        prompt_builder=prompt_builder,
    )
    return adapter
@@ -0,0 +1,11 @@
1
+ """
2
+ # Data Generation
3
+
4
+ A task to generate synthetic data for Kiln Tasks. This generates the inputs, which then can be run through the task.
5
+
6
+ Optional human guidance can be provided to guide the generation process.
7
+ """
8
+
9
+ from . import data_gen_task
10
+
11
+ __all__ = ["data_gen_task"]
@@ -1,8 +1,9 @@
1
1
  import json
2
2
 
3
+ from pydantic import BaseModel
4
+
3
5
  from kiln_ai.adapters.prompt_builders import SimplePromptBuilder
4
6
  from kiln_ai.datamodel import Project, Task
5
- from pydantic import BaseModel
6
7
 
7
8
  from .data_gen_prompts import (
8
9
  SAMPLE_GENERATION_PROMPT,
@@ -11,6 +12,16 @@ from .data_gen_prompts import (
11
12
 
12
13
 
13
14
  class DataGenCategoriesTaskInput(BaseModel):
15
+ """Input model for generating categories/subtopics.
16
+
17
+ Attributes:
18
+ node_path: List of strings representing the hierarchical path to current node
19
+ system_prompt: System prompt to guide the AI generation
20
+ num_subtopics: Number of subtopics to generate
21
+ human_guidance: Optional human guidance to influence generation
22
+ existing_topics: Optional list of existing topics to avoid duplication
23
+ """
24
+
14
25
  node_path: list[str]
15
26
  system_prompt: str
16
27
  num_subtopics: int
@@ -26,6 +37,18 @@ class DataGenCategoriesTaskInput(BaseModel):
26
37
  human_guidance: str | None = None,
27
38
  existing_topics: list[str] | None = None,
28
39
  ) -> "DataGenCategoriesTaskInput":
40
+ """Create a DataGenCategoriesTaskInput instance from a Task.
41
+
42
+ Args:
43
+ task: The source Task object
44
+ node_path: Path to current node in topic hierarchy
45
+ num_subtopics: Number of subtopics to generate
46
+ human_guidance: Optional guidance for generation
47
+ existing_topics: Optional list of existing topics
48
+
49
+ Returns:
50
+ A new DataGenCategoriesTaskInput instance
51
+ """
29
52
  prompt_builder = SimplePromptBuilder(task=task)
30
53
  return cls(
31
54
  node_path=node_path,
@@ -37,10 +60,22 @@ class DataGenCategoriesTaskInput(BaseModel):
37
60
 
38
61
 
39
62
  class DataGenCategoriesTaskOutput(BaseModel):
63
+ """Output model for generated categories/subtopics.
64
+
65
+ Attributes:
66
+ subtopics: List of generated subtopic strings
67
+ """
68
+
40
69
  subtopics: list[str]
41
70
 
42
71
 
43
72
  class DataGenCategoriesTask(Task, parent_of={}):
73
+ """Task for generating hierarchical categories/subtopics.
74
+
75
+ Generates synthetic data categories which can be used to generate
76
+ training data for model learning.
77
+ """
78
+
44
79
  def __init__(self):
45
80
  # Keep the typechecker happy. TODO: shouldn't need this or parent_of above.
46
81
  tmp_project = Project(name="DataGen")
@@ -59,6 +94,15 @@ class DataGenCategoriesTask(Task, parent_of={}):
59
94
 
60
95
 
61
96
  class DataGenSampleTaskInput(BaseModel):
97
+ """Input model for generating data samples for a kiln task.
98
+
99
+ Attributes:
100
+ topic: List of strings representing the topic path
101
+ system_prompt: System prompt to guide the AI generation
102
+ num_samples: Number of samples to generate
103
+ human_guidance: Optional human guidance to influence generation
104
+ """
105
+
62
106
  topic: list[str]
63
107
  system_prompt: str
64
108
  num_samples: int
@@ -72,6 +116,17 @@ class DataGenSampleTaskInput(BaseModel):
72
116
  num_samples: int = 8,
73
117
  human_guidance: str | None = None,
74
118
  ) -> "DataGenSampleTaskInput":
119
+ """Create a DataGenSampleTaskInput instance from a Task.
120
+
121
+ Args:
122
+ task: The source Task object
123
+ topic: Topic path for sample generation
124
+ num_samples: Number of samples to generate
125
+ human_guidance: Optional guidance for generation
126
+
127
+ Returns:
128
+ A new DataGenSampleTaskInput instance
129
+ """
75
130
  prompt_builder = SimplePromptBuilder(task=task)
76
131
  return cls(
77
132
  topic=topic,
@@ -82,6 +137,14 @@ class DataGenSampleTaskInput(BaseModel):
82
137
 
83
138
 
84
139
  def list_json_schema_for_task(task: Task) -> str:
140
+ """Generate a JSON schema for a list of task inputs (json schema)
141
+
142
+ Args:
143
+ task: Task object whose input schema will be used
144
+
145
+ Returns:
146
+ JSON string representing the schema for a list of task inputs
147
+ """
85
148
  if task.input_json_schema:
86
149
  items_schema = json.loads(task.input_json_schema)
87
150
  else:
@@ -104,6 +167,11 @@ def list_json_schema_for_task(task: Task) -> str:
104
167
 
105
168
 
106
169
  class DataGenSampleTask(Task, parent_of={}):
170
+ """Task for generating data samples for a given topic.
171
+
172
+ Generates synthetic data samples based on provided topics and subtopics.
173
+ """
174
+
107
175
  def __init__(self, target_task: Task, num_samples: int = 8):
108
176
  # Keep the typechecker happy. TODO: shouldn't need this or parent_of above.
109
177
  tmp_project = Project(name="DataGenSample")
@@ -1,6 +1,8 @@
1
1
  import json
2
2
 
3
3
  import pytest
4
+
5
+ from kiln_ai.adapters.adapter_registry import adapter_for_task
4
6
  from kiln_ai.adapters.data_gen.data_gen_task import (
5
7
  DataGenCategoriesTask,
6
8
  DataGenCategoriesTaskInput,
@@ -9,8 +11,7 @@ from kiln_ai.adapters.data_gen.data_gen_task import (
9
11
  DataGenSampleTaskInput,
10
12
  list_json_schema_for_task,
11
13
  )
12
- from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
13
- from kiln_ai.adapters.ml_model_list import get_model_and_provider
14
+ from kiln_ai.adapters.provider_tools import get_model_and_provider
14
15
  from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
15
16
  from kiln_ai.datamodel import Project, Task
16
17
 
@@ -107,7 +108,7 @@ async def test_data_gen_all_models_providers(
107
108
  data_gen_task = DataGenCategoriesTask()
108
109
  data_gen_input = DataGenCategoriesTaskInput.from_task(base_task, num_subtopics=6)
109
110
 
110
- adapter = LangChainPromptAdapter(
111
+ adapter = adapter_for_task(
111
112
  data_gen_task,
112
113
  model_name=model_name,
113
114
  provider=provider_name,
@@ -231,7 +232,7 @@ async def test_data_gen_sample_all_models_providers(
231
232
  base_task, topic=["riding horses"], num_samples=4
232
233
  )
233
234
 
234
- adapter = LangChainPromptAdapter(
235
+ adapter = adapter_for_task(
235
236
  data_gen_task,
236
237
  model_name=model_name,
237
238
  provider=provider_name,
@@ -250,17 +251,25 @@ async def test_data_gen_sample_all_models_providers(
250
251
  @pytest.mark.ollama
251
252
  @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
252
253
  async def test_data_gen_sample_all_models_providers_with_structured_output(
253
- tmp_path, model_name, provider_name, base_task
254
+ tmp_path, model_name, provider_name
254
255
  ):
255
- base_task.output_json_schema = json.dumps(
256
- {
257
- "type": "object",
258
- "properties": {
259
- "opening": {"type": "string"},
260
- "closing": {"type": "string"},
261
- },
262
- "required": ["opening", "closing"],
263
- }
256
+ project = Project(name="TestProject")
257
+ task = Task(
258
+ name="Summarize",
259
+ parent=project,
260
+ description="Explain if the username matches the tweet",
261
+ instruction="Explain if the username matches the tweet",
262
+ requirements=[],
263
+ input_json_schema=json.dumps(
264
+ {
265
+ "type": "object",
266
+ "properties": {
267
+ "username": {"type": "string"},
268
+ "tweet": {"type": "string"},
269
+ },
270
+ "required": ["username", "tweet"],
271
+ }
272
+ ),
264
273
  )
265
274
 
266
275
  _, provider = get_model_and_provider(model_name, provider_name)
@@ -268,12 +277,12 @@ async def test_data_gen_sample_all_models_providers_with_structured_output(
268
277
  # pass if the model doesn't support data gen (testing the support flag is part of this)
269
278
  return
270
279
 
271
- data_gen_task = DataGenSampleTask(target_task=base_task)
280
+ data_gen_task = DataGenSampleTask(target_task=task)
272
281
  data_gen_input = DataGenSampleTaskInput.from_task(
273
- base_task, topic=["riding horses"], num_samples=4
282
+ task, topic=["Food"], num_samples=4
274
283
  )
275
284
 
276
- adapter = LangChainPromptAdapter(
285
+ adapter = adapter_for_task(
277
286
  data_gen_task,
278
287
  model_name=model_name,
279
288
  provider=provider_name,
@@ -286,7 +295,7 @@ async def test_data_gen_sample_all_models_providers_with_structured_output(
286
295
  assert len(samples) == 4
287
296
  for sample in samples:
288
297
  assert isinstance(sample, dict)
289
- assert "opening" in sample
290
- assert "closing" in sample
291
- assert isinstance(sample["opening"], str)
292
- assert isinstance(sample["closing"], str)
298
+ assert "username" in sample
299
+ assert "tweet" in sample
300
+ assert isinstance(sample["username"], str)
301
+ assert isinstance(sample["tweet"], str)
@@ -0,0 +1,14 @@
1
"""
# Fine-Tuning

A set of classes for fine-tuning models: a base adapter, provider adapters
(OpenAI, Fireworks), dataset formatting, and a provider -> adapter registry.
"""

# fireworks_finetune is already imported transitively by finetune_registry;
# importing it here exposes it alongside openai_finetune.
from . import (
    base_finetune,
    dataset_formatter,
    finetune_registry,
    fireworks_finetune,
    openai_finetune,
)

__all__ = [
    "base_finetune",
    "openai_finetune",
    "fireworks_finetune",
    "dataset_formatter",
    "finetune_registry",
]
@@ -0,0 +1,186 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Literal
3
+
4
+ from pydantic import BaseModel
5
+
6
+ from kiln_ai.adapters.ml_model_list import built_in_models
7
+ from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType
8
+ from kiln_ai.datamodel import Finetune as FinetuneModel
9
+ from kiln_ai.utils.name_generator import generate_memorable_name
10
+
11
+
12
class FineTuneStatus(BaseModel):
    """
    Snapshot of a fine-tune job's state, plus an optional human-readable message.
    """

    # Current state of the fine-tune job.
    status: FineTuneStatusType
    # Optional user-facing explanation of the status; None when absent.
    message: str | None = None
19
+
20
+
21
class FineTuneParameter(BaseModel):
    """
    Describes one parameter (for example a hyperparameter) a fine-tune adapter accepts.
    """

    # Key under which the value appears in the fine-tune parameters dict.
    name: str
    # Declared value type; validation maps it to a Python type via TYPE_MAP.
    type: Literal["string", "int", "float", "bool"]
    # Human-readable explanation of the parameter.
    description: str
    # When False, callers must supply this parameter.
    optional: bool = True
30
+
31
+
32
# Maps the FineTuneParameter.type literals to the Python type a value must have.
TYPE_MAP = dict(string=str, int=int, float=float, bool=bool)
38
+
39
+
40
class BaseFinetuneAdapter(ABC):
    """
    A base class for fine-tuning adapters.

    Concrete adapters implement `_start` (launch the job with the provider) and
    `status` (report on it), and may override `available_parameters` to
    declare which parameters they accept.
    """

    def __init__(
        self,
        datamodel: FinetuneModel,
    ):
        # The Finetune record this adapter starts and reports on.
        self.datamodel = datamodel

    @classmethod
    async def create_and_start(
        cls,
        dataset: DatasetSplit,
        provider_id: str,
        provider_base_model_id: str,
        train_split_name: str,
        system_message: str,
        parameters: dict[str, str | int | float | bool] | None = None,
        name: str | None = None,
        description: str | None = None,
        validation_split_name: str | None = None,
    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
        """
        Create and start a fine-tune.

        Args:
            dataset: The dataset split definition to train on. Must have an id
                and a parent task with a path.
            provider_id: Provider name; matched against `provider.name` of the
                entries in `built_in_models`.
            provider_base_model_id: Base model id; matched against
                `provider.provider_finetune_id`.
            train_split_name: Split of `dataset` used for training.
            system_message: System message stored on the fine-tune datamodel.
            parameters: Parameters validated against `available_parameters`.
                Defaults to no parameters.
            name: Fine-tune name; a memorable name is generated when omitted.
            description: Optional description.
            validation_split_name: Optional split of `dataset` used for validation.

        Returns:
            The adapter and the saved Finetune datamodel.

        Raises:
            ValueError: If the provider/model pair, the dataset, a split name
                or the parameters are invalid.
        """
        # Fresh dict per call: avoids a shared mutable default and keeps the
        # caller's dict from being aliased by the datamodel.
        parameters = {} if parameters is None else dict(parameters)

        cls.check_valid_provider_model(provider_id, provider_base_model_id)

        if not dataset.id:
            raise ValueError("Dataset must have an id")

        if train_split_name not in dataset.split_contents:
            raise ValueError(f"Train split {train_split_name} not found in dataset")

        if (
            validation_split_name
            and validation_split_name not in dataset.split_contents
        ):
            raise ValueError(
                f"Validation split {validation_split_name} not found in dataset"
            )

        # Default name if not provided
        if name is None:
            name = generate_memorable_name()

        cls.validate_parameters(parameters)
        parent_task = dataset.parent_task()
        if parent_task is None or not parent_task.path:
            raise ValueError("Dataset must have a parent task with a path")

        datamodel = FinetuneModel(
            name=name,
            description=description,
            provider=provider_id,
            base_model_id=provider_base_model_id,
            dataset_split_id=dataset.id,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
            parameters=parameters,
            system_message=system_message,
            parent=parent_task,
        )

        adapter = cls(datamodel)
        # The datamodel is only persisted if _start returns without raising.
        await adapter._start(dataset)

        datamodel.save_to_file()

        return adapter, datamodel

    @abstractmethod
    async def _start(self, dataset: DatasetSplit) -> None:
        """
        Start the fine-tune.
        """
        pass

    @abstractmethod
    async def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
        """
        return []

    @classmethod
    def validate_parameters(
        cls, parameters: dict[str, str | int | float | bool]
    ) -> None:
        """
        Validate the parameters for this fine-tune.

        Every non-optional parameter from `available_parameters` must be
        present. Every supplied value must match its declared type exactly:
        ints are not promoted to float, and bools are not accepted for numeric
        parameters. No undeclared parameter may be passed.

        Raises:
            ValueError: On a missing, mistyped or unknown parameter.
        """
        available_parameters = cls.available_parameters()
        for parameter in available_parameters:
            if parameter.name not in parameters:
                if not parameter.optional:
                    raise ValueError(f"Parameter {parameter.name} is required")
                continue

            # check parameter is correct type
            expected_type = TYPE_MAP[parameter.type]
            value = parameters[parameter.name]

            # Strict type checking for numeric types. bool is a subclass of
            # int (isinstance(True, int) is True), so exclude it explicitly.
            if expected_type is float and not isinstance(value, float):
                raise ValueError(
                    f"Parameter {parameter.name} must be a float, got {type(value)}"
                )
            elif expected_type is int and (
                isinstance(value, bool) or not isinstance(value, int)
            ):
                raise ValueError(
                    f"Parameter {parameter.name} must be an integer, got {type(value)}"
                )
            elif not isinstance(value, expected_type):
                raise ValueError(
                    f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
                )

        allowed_parameters = {p.name for p in available_parameters}
        for parameter_key in parameters:
            if parameter_key not in allowed_parameters:
                raise ValueError(f"Parameter {parameter_key} is not available")

    @classmethod
    def check_valid_provider_model(
        cls, provider_id: str, provider_base_model_id: str
    ) -> None:
        """
        Check if the provider and base model are valid.

        Valid means some entry in `built_in_models` has a provider whose
        `name` equals `provider_id` and whose `provider_finetune_id` equals
        `provider_base_model_id`.

        Raises:
            ValueError: If no such provider/model pair exists.
        """
        for model in built_in_models:
            for provider in model.providers:
                if (
                    provider.name == provider_id
                    and provider.provider_finetune_id == provider_base_model_id
                ):
                    return
        raise ValueError(
            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
        )
@@ -0,0 +1,187 @@
1
+ import json
2
+ import tempfile
3
+ from enum import Enum
4
+ from pathlib import Path
5
+ from typing import Any, Dict, Protocol
6
+ from uuid import uuid4
7
+
8
+ from kiln_ai.datamodel import DatasetSplit, TaskRun
9
+
10
+
11
class DatasetFormat(str, Enum):
    """Supported dataset export formats: both the file format (e.g. JSONL) and the per-example structure (e.g. chat vs. tool call)."""

    # OpenAI chat format; the assistant turn is the plaintext output.
    OPENAI_CHAT_JSONL = "openai_chat_jsonl"
    # OpenAI chat format; the assistant turn is a tool call carrying the output.
    OPENAI_CHAT_TOOLCALL_JSONL = "openai_chat_toolcall_jsonl"
    # HuggingFace chat template, one example per JSONL line.
    HUGGINGFACE_CHAT_TEMPLATE_JSONL = "huggingface_chat_template_jsonl"
    # HuggingFace chat template with tool calls, one example per JSONL line.
    HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL = "huggingface_chat_template_toolcall_jsonl"
27
+
28
+
29
class FormatGenerator(Protocol):
    """Call signature shared by the dataset format generators.

    A generator turns one TaskRun plus the system message into the dict for a
    single example line.
    """

    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]:
        """Render `task_run` into one example dict."""
        ...
33
+
34
+
35
def generate_chat_message_response(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Build one OpenAI chat-format example whose assistant turn is the plain-text output.

    Returns a dict with a single "messages" list: system message, user turn
    (the run input) and assistant turn (the run output).
    """
    system_turn = {"role": "system", "content": system_message}
    user_turn = {"role": "user", "content": task_run.input}
    assistant_turn = {"role": "assistant", "content": task_run.output.output}
    return {"messages": [system_turn, user_turn, assistant_turn]}
46
+
47
+
48
def generate_chat_message_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Build one OpenAI chat-format example whose assistant turn is a tool call.

    The assistant message has content None and a single "task_response"
    function call (id "call_1"). Its arguments are the run output,
    re-serialized as single-line JSON.

    Raises:
        ValueError: If the run output is not valid JSON.
    """
    try:
        parsed = json.loads(task_run.output.output)
    except json.JSONDecodeError as err:
        raise ValueError(f"Invalid JSON in for tool call: {err}") from err

    # Round-tripping through json validates the output and collapses it to one line.
    serialized_args = json.dumps(parsed)
    assistant_turn = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "task_response", "arguments": serialized_args},
            }
        ],
    }
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            assistant_turn,
        ]
    }
78
+
79
+
80
def generate_huggingface_chat_template(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Build one HuggingFace chat-template example with a plain-text assistant turn.

    Returns a dict with a single "conversations" list of system, user (run
    input) and assistant (run output) messages.
    """
    roles_and_content = (
        ("system", system_message),
        ("user", task_run.input),
        ("assistant", task_run.output.output),
    )
    return {
        "conversations": [
            {"role": role, "content": content} for role, content in roles_and_content
        ]
    }
91
+
92
+
93
def generate_huggingface_chat_template_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Build one HuggingFace chat-template example whose assistant turn is a tool call.

    The run output is parsed as JSON and passed as the tool call's arguments
    object; it is not re-serialized to a string. Each call gets a fresh random
    9-character hex id.

    Raises:
        ValueError: If the run output is not valid JSON.
    """
    try:
        parsed_args = json.loads(task_run.output.output)
    except json.JSONDecodeError as err:
        raise ValueError(f"Invalid JSON in for tool call: {err}") from err

    # NOTE(review): "id" is placed inside "function" here, while the HuggingFace
    # chat templating docs show the id on the tool call itself — confirm which
    # layout the consuming templates expect.
    function_payload = {
        "name": "task_response",
        "id": uuid4().hex[:9],
        "arguments": parsed_args,
    }
    tool_call = {"type": "function", "function": function_payload}

    # See https://huggingface.co/docs/transformers/en/chat_templating
    return {
        "conversations": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "tool_calls": [tool_call]},
        ]
    }
122
+
123
+
124
# Dispatch table from each DatasetFormat to the function that renders a single
# TaskRun in that format; DatasetFormatter.dump_to_file looks generators up here.
FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = dict(
    [
        (DatasetFormat.OPENAI_CHAT_JSONL, generate_chat_message_response),
        (DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL, generate_chat_message_toolcall),
        (
            DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL,
            generate_huggingface_chat_template,
        ),
        (
            DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL,
            generate_huggingface_chat_template_toolcall,
        ),
    ]
)
130
+
131
+
132
class DatasetFormatter:
    """Handles formatting of datasets into various output formats"""

    def __init__(self, dataset: DatasetSplit, system_message: str):
        """
        Args:
            dataset: The dataset split definition to export. Its runs are read
                from the dataset's parent task.
            system_message: System message passed to the format generator for
                every example.

        Raises:
            ValueError: If the dataset has no parent task.
        """
        self.dataset = dataset
        self.system_message = system_message

        task = dataset.parent_task()
        if task is None:
            raise ValueError("Dataset has no parent task")
        self.task = task

    def dump_to_file(
        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
    ) -> Path:
        """
        Format the dataset into the specified format.

        Each run listed in the split is rendered by the format's generator and
        written as one JSON object per line, UTF-8 encoded.

        Args:
            split_name: Name of the split to dump
            format_type: Format to generate the dataset in
            path: Optional path to write to. If None, writes to temp directory

        Returns:
            Path to the generated file

        Raises:
            ValueError: If the format or split is unknown, or if a run listed in
                the split is not among the task's runs. A ValueError raised by
                the generator (the tool-call formats reject non-JSON output)
                propagates.
        """
        if format_type not in FORMAT_GENERATORS:
            raise ValueError(f"Unsupported format: {format_type}")
        if split_name not in self.dataset.split_contents:
            raise ValueError(f"Split {split_name} not found in dataset")

        generator = FORMAT_GENERATORS[format_type]

        # Write to a temp file if no path is provided
        output_path = (
            path
            or Path(tempfile.gettempdir())
            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
        )

        # Resolve every run before opening the file. Using .get makes a missing
        # run raise the intended ValueError instead of a bare KeyError. Doing it
        # up front also avoids truncating or half-writing the output file.
        runs_by_id = {run.id: run for run in self.task.runs()}
        task_runs = []
        for run_id in self.dataset.split_contents[split_name]:
            task_run = runs_by_id.get(run_id)
            if task_run is None:
                raise ValueError(
                    f"Task run {run_id} not found. This is required by this dataset."
                )
            task_runs.append(task_run)

        # Generate formatted output with UTF-8 encoding
        with open(output_path, "w", encoding="utf-8") as f:
            for task_run in task_runs:
                example = generator(task_run, self.system_message)
                f.write(json.dumps(example) + "\n")

        return output_path
@@ -0,0 +1,11 @@
1
+ from typing import Type
2
+
3
+ from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter
4
+ from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
5
+ from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
6
+ from kiln_ai.adapters.ml_model_list import ModelProviderName
7
+
8
# Fine-tune adapter class to use for each provider that supports fine-tuning.
finetune_registry: dict[ModelProviderName, Type[BaseFinetuneAdapter]] = dict(
    [
        (ModelProviderName.openai, OpenAIFinetune),
        (ModelProviderName.fireworks_ai, FireworksFinetune),
    ]
)