kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (44)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
  4. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  5. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  7. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  8. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  9. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  10. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  11. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  12. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +455 -0
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  14. kiln_ai/adapters/langchain_adapters.py +103 -13
  15. kiln_ai/adapters/ml_model_list.py +239 -303
  16. kiln_ai/adapters/ollama_tools.py +115 -0
  17. kiln_ai/adapters/provider_tools.py +308 -0
  18. kiln_ai/adapters/repair/repair_task.py +4 -2
  19. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  20. kiln_ai/adapters/test_langchain_adapter.py +229 -18
  21. kiln_ai/adapters/test_ollama_tools.py +42 -0
  22. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  23. kiln_ai/adapters/test_provider_tools.py +531 -0
  24. kiln_ai/adapters/test_structured_output.py +22 -43
  25. kiln_ai/datamodel/__init__.py +287 -24
  26. kiln_ai/datamodel/basemodel.py +122 -38
  27. kiln_ai/datamodel/model_cache.py +116 -0
  28. kiln_ai/datamodel/registry.py +31 -0
  29. kiln_ai/datamodel/test_basemodel.py +167 -4
  30. kiln_ai/datamodel/test_dataset_split.py +234 -0
  31. kiln_ai/datamodel/test_example_models.py +12 -0
  32. kiln_ai/datamodel/test_model_cache.py +244 -0
  33. kiln_ai/datamodel/test_models.py +215 -1
  34. kiln_ai/datamodel/test_registry.py +96 -0
  35. kiln_ai/utils/config.py +14 -1
  36. kiln_ai/utils/name_generator.py +125 -0
  37. kiln_ai/utils/test_name_generator.py +47 -0
  38. kiln_ai-0.7.1.dist-info/METADATA +237 -0
  39. kiln_ai-0.7.1.dist-info/RECORD +58 -0
  40. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
  41. kiln_ai/adapters/test_ml_model_list.py +0 -181
  42. kiln_ai-0.6.1.dist-info/METADATA +0 -88
  43. kiln_ai-0.6.1.dist-info/RECORD +0 -37
  44. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/adapters/__init__.py

@@ -15,6 +15,7 @@ The repair submodule contains an adapter for the repair task.
 from . import (
     base_adapter,
     data_gen,
+    fine_tune,
     langchain_adapters,
     ml_model_list,
     prompt_builders,
@@ -28,4 +29,5 @@ __all__ = [
     "prompt_builders",
     "repair",
     "data_gen",
+    "fine_tune",
 ]

kiln_ai/adapters/adapter_registry.py

@@ -0,0 +1,19 @@
+from kiln_ai import datamodel
+from kiln_ai.adapters.base_adapter import BaseAdapter
+from kiln_ai.adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.prompt_builders import BasePromptBuilder
+
+
+def adapter_for_task(
+    kiln_task: datamodel.Task,
+    model_name: str | None = None,
+    provider: str | None = None,
+    prompt_builder: BasePromptBuilder | None = None,
+) -> BaseAdapter:
+    # We use langchain for everything right now, but can add any others here
+    return LangchainAdapter(
+        kiln_task,
+        model_name=model_name,
+        provider=provider,
+        prompt_builder=prompt_builder,
+    )
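
A minimal sketch of calling the new entry point (the task fields and the model/provider names below are illustrative placeholders, not taken from this diff; only the adapter_for_task signature comes from the code above):

from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.datamodel import Project, Task

project = Project(name="DemoProject")
task = Task(
    name="Summarize",
    parent=project,
    instruction="Summarize the input text in one sentence.",
    requirements=[],
)

# adapter_for_task currently always returns a LangchainAdapter; other backends can be added later.
adapter = adapter_for_task(task, model_name="llama_3_1_8b", provider="ollama")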

kiln_ai/adapters/data_gen/test_data_gen_task.py

@@ -2,6 +2,7 @@ import json
 
 import pytest
 
+from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.data_gen.data_gen_task import (
     DataGenCategoriesTask,
     DataGenCategoriesTaskInput,
@@ -10,8 +11,7 @@ from kiln_ai.adapters.data_gen.data_gen_task import (
     DataGenSampleTaskInput,
     list_json_schema_for_task,
 )
-from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
-from kiln_ai.adapters.ml_model_list import get_model_and_provider
+from kiln_ai.adapters.provider_tools import get_model_and_provider
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import Project, Task
 
@@ -108,7 +108,7 @@ async def test_data_gen_all_models_providers(
     data_gen_task = DataGenCategoriesTask()
     data_gen_input = DataGenCategoriesTaskInput.from_task(base_task, num_subtopics=6)
 
-    adapter = LangChainPromptAdapter(
+    adapter = adapter_for_task(
         data_gen_task,
         model_name=model_name,
         provider=provider_name,
@@ -232,7 +232,7 @@ async def test_data_gen_sample_all_models_providers(
         base_task, topic=["riding horses"], num_samples=4
     )
 
-    adapter = LangChainPromptAdapter(
+    adapter = adapter_for_task(
         data_gen_task,
         model_name=model_name,
         provider=provider_name,
@@ -251,17 +251,25 @@ async def test_data_gen_sample_all_models_providers(
 @pytest.mark.ollama
 @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
 async def test_data_gen_sample_all_models_providers_with_structured_output(
-    tmp_path, model_name, provider_name, base_task
+    tmp_path, model_name, provider_name
 ):
-    base_task.output_json_schema = json.dumps(
-        {
-            "type": "object",
-            "properties": {
-                "opening": {"type": "string"},
-                "closing": {"type": "string"},
-            },
-            "required": ["opening", "closing"],
-        }
+    project = Project(name="TestProject")
+    task = Task(
+        name="Summarize",
+        parent=project,
+        description="Explain if the username matches the tweet",
+        instruction="Explain if the username matches the tweet",
+        requirements=[],
+        input_json_schema=json.dumps(
+            {
+                "type": "object",
+                "properties": {
+                    "username": {"type": "string"},
+                    "tweet": {"type": "string"},
+                },
+                "required": ["username", "tweet"],
+            }
+        ),
     )
 
     _, provider = get_model_and_provider(model_name, provider_name)
@@ -269,12 +277,12 @@ async def test_data_gen_sample_all_models_providers_with_structured_output(
         # pass if the model doesn't support data gen (testing the support flag is part of this)
         return
 
-    data_gen_task = DataGenSampleTask(target_task=base_task)
+    data_gen_task = DataGenSampleTask(target_task=task)
     data_gen_input = DataGenSampleTaskInput.from_task(
-        base_task, topic=["riding horses"], num_samples=4
+        task, topic=["Food"], num_samples=4
    )
 
-    adapter = LangChainPromptAdapter(
+    adapter = adapter_for_task(
         data_gen_task,
         model_name=model_name,
         provider=provider_name,
@@ -287,7 +295,7 @@ async def test_data_gen_sample_all_models_providers_with_structured_output(
     assert len(samples) == 4
     for sample in samples:
         assert isinstance(sample, dict)
-        assert "opening" in sample
-        assert "closing" in sample
-        assert isinstance(sample["opening"], str)
-        assert isinstance(sample["closing"], str)
+        assert "username" in sample
+        assert "tweet" in sample
+        assert isinstance(sample["username"], str)
+        assert isinstance(sample["tweet"], str)

kiln_ai/adapters/fine_tune/__init__.py

@@ -0,0 +1,14 @@
+"""
+# Fine-Tuning
+
+A set of classes for fine-tuning models.
+"""
+
+from . import base_finetune, dataset_formatter, finetune_registry, openai_finetune
+
+__all__ = [
+    "base_finetune",
+    "openai_finetune",
+    "dataset_formatter",
+    "finetune_registry",
+]

kiln_ai/adapters/fine_tune/base_finetune.py

@@ -0,0 +1,186 @@
+from abc import ABC, abstractmethod
+from typing import Literal
+
+from pydantic import BaseModel
+
+from kiln_ai.adapters.ml_model_list import built_in_models
+from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType
+from kiln_ai.datamodel import Finetune as FinetuneModel
+from kiln_ai.utils.name_generator import generate_memorable_name
+
+
+class FineTuneStatus(BaseModel):
+    """
+    The status of a fine-tune, including a user friendly message.
+    """
+
+    status: FineTuneStatusType
+    message: str | None = None
+
+
+class FineTuneParameter(BaseModel):
+    """
+    A parameter for a fine-tune. Hyperparameters, etc.
+    """
+
+    name: str
+    type: Literal["string", "int", "float", "bool"]
+    description: str
+    optional: bool = True
+
+
+TYPE_MAP = {
+    "string": str,
+    "int": int,
+    "float": float,
+    "bool": bool,
+}
+
+
+class BaseFinetuneAdapter(ABC):
+    """
+    A base class for fine-tuning adapters.
+    """
+
+    def __init__(
+        self,
+        datamodel: FinetuneModel,
+    ):
+        self.datamodel = datamodel
+
+    @classmethod
+    async def create_and_start(
+        cls,
+        dataset: DatasetSplit,
+        provider_id: str,
+        provider_base_model_id: str,
+        train_split_name: str,
+        system_message: str,
+        parameters: dict[str, str | int | float | bool] = {},
+        name: str | None = None,
+        description: str | None = None,
+        validation_split_name: str | None = None,
+    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
+        """
+        Create and start a fine-tune.
+        """
+
+        cls.check_valid_provider_model(provider_id, provider_base_model_id)
+
+        if not dataset.id:
+            raise ValueError("Dataset must have an id")
+
+        if train_split_name not in dataset.split_contents:
+            raise ValueError(f"Train split {train_split_name} not found in dataset")
+
+        if (
+            validation_split_name
+            and validation_split_name not in dataset.split_contents
+        ):
+            raise ValueError(
+                f"Validation split {validation_split_name} not found in dataset"
+            )
+
+        # Default name if not provided
+        if name is None:
+            name = generate_memorable_name()
+
+        cls.validate_parameters(parameters)
+        parent_task = dataset.parent_task()
+        if parent_task is None or not parent_task.path:
+            raise ValueError("Dataset must have a parent task with a path")
+
+        datamodel = FinetuneModel(
+            name=name,
+            description=description,
+            provider=provider_id,
+            base_model_id=provider_base_model_id,
+            dataset_split_id=dataset.id,
+            train_split_name=train_split_name,
+            validation_split_name=validation_split_name,
+            parameters=parameters,
+            system_message=system_message,
+            parent=parent_task,
+        )
+
+        adapter = cls(datamodel)
+        await adapter._start(dataset)
+
+        datamodel.save_to_file()
+
+        return adapter, datamodel
+
+    @abstractmethod
+    async def _start(self, dataset: DatasetSplit) -> None:
+        """
+        Start the fine-tune.
+        """
+        pass
+
+    @abstractmethod
+    async def status(self) -> FineTuneStatus:
+        """
+        Get the status of the fine-tune.
+        """
+        pass
+
+    @classmethod
+    def available_parameters(cls) -> list[FineTuneParameter]:
+        """
+        Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
+        """
+        return []
+
+    @classmethod
+    def validate_parameters(
+        cls, parameters: dict[str, str | int | float | bool]
+    ) -> None:
+        """
+        Validate the parameters for this fine-tune.
+        """
+        # Check required parameters and parameter types
+        available_parameters = cls.available_parameters()
+        for parameter in available_parameters:
+            if not parameter.optional and parameter.name not in parameters:
+                raise ValueError(f"Parameter {parameter.name} is required")
+            elif parameter.name in parameters:
+                # check parameter is correct type
+                expected_type = TYPE_MAP[parameter.type]
+                value = parameters[parameter.name]
+
+                # Strict type checking for numeric types
+                if expected_type is float and not isinstance(value, float):
+                    raise ValueError(
+                        f"Parameter {parameter.name} must be a float, got {type(value)}"
+                    )
+                elif expected_type is int and not isinstance(value, int):
+                    raise ValueError(
+                        f"Parameter {parameter.name} must be an integer, got {type(value)}"
+                    )
+                elif not isinstance(value, expected_type):
+                    raise ValueError(
+                        f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
+                    )
+
+        allowed_parameters = [p.name for p in available_parameters]
+        for parameter_key in parameters:
+            if parameter_key not in allowed_parameters:
+                raise ValueError(f"Parameter {parameter_key} is not available")
+
+    @classmethod
+    def check_valid_provider_model(
+        cls, provider_id: str, provider_base_model_id: str
+    ) -> None:
+        """
+        Check if the provider and base model are valid.
+        """
+        for model in built_in_models:
+            for provider in model.providers:
+                if (
+                    provider.name == provider_id
+                    and provider.provider_finetune_id == provider_base_model_id
+                ):
+                    return
+        raise ValueError(
+            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
+        )

kiln_ai/adapters/fine_tune/dataset_formatter.py

@@ -0,0 +1,187 @@
+import json
+import tempfile
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, Protocol
+from uuid import uuid4
+
+from kiln_ai.datamodel import DatasetSplit, TaskRun
+
+
+class DatasetFormat(str, Enum):
+    """Formats for dataset generation. Both for file format (like JSONL), and internal structure (like chat/toolcall)"""
+
+    """OpenAI chat format with plaintext response"""
+    OPENAI_CHAT_JSONL = "openai_chat_jsonl"
+
+    """OpenAI chat format with tool call response"""
+    OPENAI_CHAT_TOOLCALL_JSONL = "openai_chat_toolcall_jsonl"
+
+    """HuggingFace chat template in JSONL"""
+    HUGGINGFACE_CHAT_TEMPLATE_JSONL = "huggingface_chat_template_jsonl"
+
+    """HuggingFace chat template with tool calls in JSONL"""
+    HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL = (
+        "huggingface_chat_template_toolcall_jsonl"
+    )
+
+
+class FormatGenerator(Protocol):
+    """Protocol for format generators"""
+
+    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...
+
+
+def generate_chat_message_response(
+    task_run: TaskRun, system_message: str
+) -> Dict[str, Any]:
+    """Generate OpenAI chat format with plaintext response"""
+    return {
+        "messages": [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": task_run.input},
+            {"role": "assistant", "content": task_run.output.output},
+        ]
+    }
+
+
+def generate_chat_message_toolcall(
+    task_run: TaskRun, system_message: str
+) -> Dict[str, Any]:
+    """Generate OpenAI chat format with tool call response"""
+    try:
+        arguments = json.loads(task_run.output.output)
+    except json.JSONDecodeError as e:
+        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
+
+    return {
+        "messages": [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": task_run.input},
+            {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [
+                    {
+                        "id": "call_1",
+                        "type": "function",
+                        "function": {
+                            "name": "task_response",
+                            # Yes we parse then dump again. This ensures it's valid JSON, and ensures it goes to 1 line
+                            "arguments": json.dumps(arguments),
+                        },
+                    }
+                ],
+            },
+        ]
+    }
+
+
+def generate_huggingface_chat_template(
+    task_run: TaskRun, system_message: str
+) -> Dict[str, Any]:
+    """Generate HuggingFace chat template"""
+    return {
+        "conversations": [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": task_run.input},
+            {"role": "assistant", "content": task_run.output.output},
+        ]
+    }
+
+
+def generate_huggingface_chat_template_toolcall(
+    task_run: TaskRun, system_message: str
+) -> Dict[str, Any]:
+    """Generate HuggingFace chat template with tool calls"""
+    try:
+        arguments = json.loads(task_run.output.output)
+    except json.JSONDecodeError as e:
+        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
+
+    # See https://huggingface.co/docs/transformers/en/chat_templating
+    return {
+        "conversations": [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": task_run.input},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "task_response",
+                            "id": str(uuid4()).replace("-", "")[:9],
+                            "arguments": arguments,
+                        },
+                    }
+                ],
+            },
+        ]
+    }
+
+
+FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
+    DatasetFormat.OPENAI_CHAT_JSONL: generate_chat_message_response,
+    DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: generate_chat_message_toolcall,
+    DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: generate_huggingface_chat_template,
+    DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: generate_huggingface_chat_template_toolcall,
+}
+
+
+class DatasetFormatter:
+    """Handles formatting of datasets into various output formats"""
+
+    def __init__(self, dataset: DatasetSplit, system_message: str):
+        self.dataset = dataset
+        self.system_message = system_message
+
+        task = dataset.parent_task()
+        if task is None:
+            raise ValueError("Dataset has no parent task")
+        self.task = task
+
+    def dump_to_file(
+        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
+    ) -> Path:
+        """
+        Format the dataset into the specified format.
+
+        Args:
+            split_name: Name of the split to dump
+            format_type: Format to generate the dataset in
+            path: Optional path to write to. If None, writes to temp directory
+
+        Returns:
+            Path to the generated file
+        """
+        if format_type not in FORMAT_GENERATORS:
+            raise ValueError(f"Unsupported format: {format_type}")
+        if split_name not in self.dataset.split_contents:
+            raise ValueError(f"Split {split_name} not found in dataset")
+
+        generator = FORMAT_GENERATORS[format_type]
+
+        # Write to a temp file if no path is provided
+        output_path = (
+            path
+            or Path(tempfile.gettempdir())
+            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
+        )
+
+        runs = self.task.runs()
+        runs_by_id = {run.id: run for run in runs}
+
+        # Generate formatted output with UTF-8 encoding
+        with open(output_path, "w", encoding="utf-8") as f:
+            for run_id in self.dataset.split_contents[split_name]:
+                task_run = runs_by_id[run_id]
+                if task_run is None:
+                    raise ValueError(
+                        f"Task run {run_id} not found. This is required by this dataset."
+                    )
+
+                example = generator(task_run, self.system_message)
+                f.write(json.dumps(example) + "\n")
+
+        return output_path
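
A short sketch of driving the formatter, assuming a DatasetSplit that already has a parent task and a split actually named "train" (both assumptions; only the constructor and dump_to_file signature come from the code above):

from pathlib import Path

from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
from kiln_ai.datamodel import DatasetSplit


def dump_train_split(dataset: DatasetSplit) -> Path:
    # The formatter resolves the parent task itself and raises if the split is missing.
    formatter = DatasetFormatter(dataset, system_message="You are a helpful assistant.")
    # With no explicit path, the JSONL file is written to the system temp directory,
    # named from the dataset name, split name, and format.
    return formatter.dump_to_file("train", DatasetFormat.OPENAI_CHAT_JSONL)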

kiln_ai/adapters/fine_tune/finetune_registry.py

@@ -0,0 +1,11 @@
+from typing import Type
+
+from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter
+from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
+from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
+from kiln_ai.adapters.ml_model_list import ModelProviderName
+
+finetune_registry: dict[ModelProviderName, Type[BaseFinetuneAdapter]] = {
+    ModelProviderName.openai: OpenAIFinetune,
+    ModelProviderName.fireworks_ai: FireworksFinetune,
+}
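
The registry maps a ModelProviderName to its adapter class so callers can dispatch without importing provider modules directly. A hedged sketch of starting an OpenAI fine-tune through it (the base model id and split name are placeholders and must match a provider_finetune_id in ml_model_list and a real split in the dataset):

from kiln_ai.adapters.fine_tune.finetune_registry import finetune_registry
from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.datamodel import DatasetSplit


async def start_openai_finetune(dataset: DatasetSplit):
    adapter_class = finetune_registry[ModelProviderName.openai]
    # create_and_start validates the provider/model pair, split names and parameters,
    # saves the Finetune record under the dataset's parent task, then calls _start().
    return await adapter_class.create_and_start(
        dataset=dataset,
        provider_id=ModelProviderName.openai,
        provider_base_model_id="gpt-4o-mini-2024-07-18",  # placeholder base model id
        train_split_name="train",  # placeholder split name
        system_message="You are a helpful assistant.",
    )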