kiln-ai 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +153 -28
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/eval_runner.py +4 -1
- kiln_ai/adapters/eval/g_eval.py +2 -1
- kiln_ai/adapters/eval/test_base_eval.py +1 -0
- kiln_ai/adapters/eval/test_eval_runner.py +1 -0
- kiln_ai/adapters/eval/test_g_eval.py +1 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
- kiln_ai/adapters/ml_model_list.py +638 -155
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +14 -11
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +23 -5
- kiln_ai/adapters/ollama_tools.py +3 -2
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/provider_tools.py +50 -58
- kiln_ai/adapters/repair/test_repair_task.py +3 -3
- kiln_ai/adapters/run_output.py +1 -1
- kiln_ai/adapters/test_adapter_registry.py +17 -20
- kiln_ai/adapters/test_generate_docs.py +2 -2
- kiln_ai/adapters/test_prompt_adaptors.py +30 -19
- kiln_ai/adapters/test_provider_tools.py +26 -81
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +2 -0
- kiln_ai/datamodel/json_schema.py +1 -1
- kiln_ai/datamodel/task_output.py +13 -6
- kiln_ai/datamodel/test_basemodel.py +9 -0
- kiln_ai/datamodel/test_datasource.py +19 -0
- kiln_ai/utils/config.py +37 -0
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +51 -7
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/RECORD +42 -39
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -309
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -10
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -289
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -343
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -216
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/json_schema.py
CHANGED

@@ -49,7 +49,7 @@ def validate_schema(instance: Dict, schema_str: str) -> None:
        v.validate(instance)
    except jsonschema.exceptions.ValidationError as e:
        raise ValueError(
-            f"This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information. The error from the schema check was: {e.message}"
+            f"This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information. The error from the schema check was: {e.message}. The JSON was: \n```json\n{instance}\n```"
        ) from e

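The only functional change above is that the schema-mismatch error now embeds the offending JSON. A minimal sketch of what a caller would see (the schema and values below are made up for illustration):

import json

from kiln_ai.datamodel.json_schema import validate_schema

# A toy schema: output must be an object with an integer "rating".
schema = json.dumps(
    {
        "type": "object",
        "properties": {"rating": {"type": "integer"}},
        "required": ["rating"],
    }
)

try:
    # Valid JSON, but "rating" is a string rather than the required integer.
    validate_schema({"rating": "five"}, schema)
except ValueError as e:
    # In 0.13.0 the message now ends with the offending JSON, roughly:
    # "... The error from the schema check was: 'five' is not of type 'integer'.
    #  The JSON was: ```json {'rating': 'five'} ```"
    print(e)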
kiln_ai/datamodel/task_output.py
CHANGED

@@ -171,6 +171,7 @@ class DataSourceType(str, Enum):

    human = "human"
    synthetic = "synthetic"
+    file_import = "file_import"


class DataSourceProperty(BaseModel):

@@ -206,37 +207,43 @@ class DataSource(BaseModel):
            name="created_by",
            type=str,
            required_for=[DataSourceType.human],
-            not_allowed_for=[DataSourceType.synthetic],
+            not_allowed_for=[DataSourceType.synthetic, DataSourceType.file_import],
        ),
        DataSourceProperty(
            name="model_name",
            type=str,
            required_for=[DataSourceType.synthetic],
-            not_allowed_for=[DataSourceType.human],
+            not_allowed_for=[DataSourceType.human, DataSourceType.file_import],
        ),
        DataSourceProperty(
            name="model_provider",
            type=str,
            required_for=[DataSourceType.synthetic],
-            not_allowed_for=[DataSourceType.human],
+            not_allowed_for=[DataSourceType.human, DataSourceType.file_import],
        ),
        DataSourceProperty(
            name="adapter_name",
            type=str,
            required_for=[DataSourceType.synthetic],
-            not_allowed_for=[DataSourceType.human],
+            not_allowed_for=[DataSourceType.human, DataSourceType.file_import],
        ),
        DataSourceProperty(
            # Legacy field -- allow loading from old runs, but we shouldn't be setting it.
            name="prompt_builder_name",
            type=str,
-            not_allowed_for=[DataSourceType.human],
+            not_allowed_for=[DataSourceType.human, DataSourceType.file_import],
        ),
        DataSourceProperty(
            # The PromptId of the prompt. Can be a saved prompt, fine-tune, generator name, etc. See PromptId type for more details.
            name="prompt_id",
            type=str,
-            not_allowed_for=[DataSourceType.human],
+            not_allowed_for=[DataSourceType.human, DataSourceType.file_import],
+        ),
+        DataSourceProperty(
+            name="file_name",
+            type=str,
+            required_for=[DataSourceType.file_import],
+            not_allowed_for=[DataSourceType.human, DataSourceType.synthetic],
        ),
    ]

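Together with the new file_import enum value, these property rules mean a file-import data source carries only a file_name, and model-related properties are rejected. A small sketch (file names are placeholders):

from pydantic import ValidationError

from kiln_ai.datamodel import DataSource, DataSourceType

# Valid: file_import sources only require "file_name".
source = DataSource(
    type=DataSourceType.file_import,
    properties={"file_name": "my_dataset.csv"},
)

# Invalid: model_name is not allowed for file_import sources, so this raises.
try:
    DataSource(
        type=DataSourceType.file_import,
        properties={"model_name": "GPT-4"},
    )
except ValidationError as e:
    print(e)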
@@ -510,6 +510,7 @@ async def test_invoke_parsing_flow(adapter):
    # Mock dependencies
    mock_provider = MagicMock()
    mock_provider.parser = "test_parser"
+    mock_provider.reasoning_capable = False

    mock_parser = MagicMock()
    mock_parser.parse_output.return_value = RunOutput(

@@ -547,3 +548,11 @@ async def test_invoke_parsing_flow(adapter):
    assert result.output.output == "parsed test output"
    assert result.intermediate_outputs == {"key": "value"}
    assert result.input == "test input"
+
+    # Test with reasoning required, that we error if no reasoning is returned
+    mock_provider.reasoning_capable = True
+    with pytest.raises(
+        RuntimeError,
+        match="Reasoning is required for this model, but no reasoning was returned.",
+    ):
+        await adapter.invoke("test input")

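The new assertions encode a behavior change, not just test plumbing: when a provider is marked reasoning_capable but the response contains no reasoning, invoke now fails with a RuntimeError. A hypothetical guard around an already-constructed adapter (adapter setup elided):

async def invoke_with_reasoning_guard(adapter, prompt: str):
    # "adapter" stands in for any Kiln model adapter built for a reasoning-capable model.
    try:
        return await adapter.invoke(prompt)
    except RuntimeError as e:
        # Message per the test above:
        # "Reasoning is required for this model, but no reasoning was returned."
        print(f"Run failed: {e}")
        return None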
kiln_ai/datamodel/test_datasource.py
CHANGED

@@ -29,11 +29,30 @@ def test_valid_synthetic_data_source():
    assert data_source.properties["adapter_name"] == "langchain"


+def test_valid_file_import_data_source():
+    data_source = DataSource(
+        type=DataSourceType.file_import,
+        properties={"file_name": "test.txt"},
+    )
+    assert data_source.type == DataSourceType.file_import
+    assert data_source.properties["file_name"] == "test.txt"
+
+
def test_missing_required_property():
    with pytest.raises(ValidationError, match="'created_by' is required for"):
        DataSource(type=DataSourceType.human)


+def test_missing_required_property_file_import():
+    with pytest.raises(ValidationError, match="'file_name' is required for"):
+        DataSource(type=DataSourceType.file_import)
+
+
+def test_not_allowed_property_file_import():
+    with pytest.raises(ValidationError, match="'model_name' is not allowed for"):
+        DataSource(type=DataSourceType.file_import, properties={"model_name": "GPT-4"})
+
+
def test_wrong_property_type():
    with pytest.raises(
        ValidationError,

kiln_ai/utils/config.py
CHANGED

@@ -78,10 +78,47 @@ class Config:
                str,
                env_var="FIREWORKS_ACCOUNT_ID",
            ),
+            "anthropic_api_key": ConfigProperty(
+                str,
+                env_var="ANTHROPIC_API_KEY",
+                sensitive=True,
+            ),
+            "gemini_api_key": ConfigProperty(
+                str,
+                env_var="GEMINI_API_KEY",
+                sensitive=True,
+            ),
            "projects": ConfigProperty(
                list,
                default_lambda=lambda: [],
            ),
+            "azure_openai_api_key": ConfigProperty(
+                str,
+                env_var="AZURE_OPENAI_API_KEY",
+                sensitive=True,
+            ),
+            "azure_openai_endpoint": ConfigProperty(
+                str,
+                env_var="AZURE_OPENAI_ENDPOINT",
+            ),
+            "huggingface_api_key": ConfigProperty(
+                str,
+                env_var="HUGGINGFACE_API_KEY",
+                sensitive=True,
+            ),
+            "vertex_project_id": ConfigProperty(
+                str,
+                env_var="VERTEX_PROJECT_ID",
+            ),
+            "vertex_location": ConfigProperty(
+                str,
+                env_var="VERTEX_LOCATION",
+            ),
+            "together_api_key": ConfigProperty(
+                str,
+                env_var="TOGETHERAI_API_KEY",
+                sensitive=True,
+            ),
            "custom_models": ConfigProperty(
                list,
                default_lambda=lambda: [],

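Each new entry maps a config key to an environment variable (and marks API keys as sensitive), so these credentials can presumably also be supplied via the environment rather than the Kiln config file. The variable names below come straight from the diff; the values are placeholders:

import os

# Placeholder values -- set real credentials outside of source control.
os.environ["ANTHROPIC_API_KEY"] = "sk-ant-placeholder"
os.environ["GEMINI_API_KEY"] = "your-gemini-key"
os.environ["AZURE_OPENAI_API_KEY"] = "your-azure-openai-key"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://example-resource.openai.azure.com"
os.environ["HUGGINGFACE_API_KEY"] = "hf_placeholder"
os.environ["VERTEX_PROJECT_ID"] = "example-gcp-project"
os.environ["VERTEX_LOCATION"] = "us-central1"
os.environ["TOGETHERAI_API_KEY"] = "your-together-key"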
kiln_ai/utils/dataset_import.py
ADDED

@@ -0,0 +1,232 @@
+import csv
+import logging
+import time
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict, Protocol
+
+from pydantic import BaseModel, Field, ValidationError, field_validator
+
+from kiln_ai.datamodel import DataSource, DataSourceType, Task, TaskOutput, TaskRun
+
+logger = logging.getLogger(__name__)
+
+
+class DatasetImportFormat(str, Enum):
+    """
+    The format of the dataset to import.
+    """
+
+    CSV = "csv"
+
+
+class Importer(Protocol):
+    """Protocol for dataset importers"""
+
+    def __call__(
+        self,
+        task: Task,
+        dataset_path: str,
+        dataset_name: str,
+    ) -> int: ...
+
+
+class CSVRowSchema(BaseModel):
+    """Schema for validating rows in a CSV file."""
+
+    input: str = Field(description="The input to the model")
+    output: str = Field(description="The output of the model")
+    reasoning: str | None = Field(
+        description="The reasoning of the model (optional)",
+        default=None,
+    )
+    chain_of_thought: str | None = Field(
+        description="The chain of thought of the model (optional)",
+        default=None,
+    )
+    tags: list[str] = Field(
+        default_factory=list,
+        description="The tags of the run (optional)",
+    )
+
+
+def generate_import_tags(session_id: str) -> list[str]:
+    return [
+        "imported",
+        f"imported_{session_id}",
+    ]
+
+
+class KilnInvalidImportFormat(Exception):
+    """Raised when the import format is invalid"""
+
+    def __init__(self, message: str, row_number: int | None = None):
+        self.row_number = row_number
+        if row_number is not None:
+            message = f"Error in row {row_number}: {message}"
+        super().__init__(message)
+
+
+def format_validation_error(e: ValidationError) -> str:
+    """Convert a Pydantic validation error into a human-readable message."""
+    error_messages = []
+    for error in e.errors():
+        location = " -> ".join(str(loc) for loc in error["loc"])
+        message = error["msg"]
+        error_messages.append(f"- {location}: {message}")
+
+    return "Validation failed:\n" + "\n".join(error_messages)
+
+
+def deserialize_tags(tags_serialized: str | None) -> list[str]:
+    """Deserialize tags from a comma-separated string to a list of strings."""
+    if tags_serialized:
+        return [tag.strip() for tag in tags_serialized.split(",") if tag.strip()]
+    return []
+
+
+def without_none_values(d: dict) -> dict:
+    """Return a copy of the dictionary with all None values removed."""
+    return {k: v for k, v in d.items() if v is not None}
+
+
+def create_task_run_from_csv_row(
+    task: Task,
+    row: dict[str, str],
+    dataset_name: str,
+    session_id: str,
+) -> TaskRun:
+    """Validate and create a TaskRun from a CSV row, without saving to file"""
+
+    # first we validate the row from the CSV file
+    validated_row = CSVRowSchema.model_validate(
+        {
+            **row,
+            "tags": deserialize_tags(row.get("tags")),
+        }
+    )
+
+    tags = generate_import_tags(session_id)
+    if validated_row.tags:
+        tags.extend(validated_row.tags)
+
+    # note that we don't persist the run yet, we just create and validate it
+    # this instantiation may raise pydantic validation errors
+    run = TaskRun(
+        parent=task,
+        input=validated_row.input,
+        input_source=DataSource(
+            type=DataSourceType.file_import,
+            properties={
+                "file_name": dataset_name,
+            },
+        ),
+        output=TaskOutput(
+            output=validated_row.output,
+            source=DataSource(
+                type=DataSourceType.file_import,
+                properties={
+                    "file_name": dataset_name,
+                },
+            ),
+        ),
+        intermediate_outputs=without_none_values(
+            {
+                "reasoning": validated_row.reasoning,
+                "chain_of_thought": validated_row.chain_of_thought,
+            }
+        )
+        or None,
+        tags=tags,
+    )
+
+    return run
+
+
+def import_csv(task: Task, dataset_path: str, dataset_name: str) -> int:
+    """Import a CSV dataset.
+
+    All rows are validated before any are persisted to files to avoid partial imports."""
+
+    session_id = str(int(time.time()))
+
+    required_headers = {"input", "output"}  # minimum required headers
+    optional_headers = {"reasoning", "tags", "chain_of_thought"}  # optional headers
+
+    rows: list[TaskRun] = []
+    with open(dataset_path, "r", newline="") as csvfile:
+        reader = csv.DictReader(csvfile)
+
+        # Check if we have headers
+        if not reader.fieldnames:
+            raise KilnInvalidImportFormat(
+                "CSV file appears to be empty or missing headers"
+            )
+
+        # Check for required headers
+        actual_headers = set(reader.fieldnames)
+        missing_headers = required_headers - actual_headers
+        if missing_headers:
+            raise KilnInvalidImportFormat(
+                f"Missing required headers: {', '.join(missing_headers)}. "
+                f"Required headers are: {', '.join(required_headers)}"
+            )
+
+        # Warn about unknown headers (not required or optional)
+        unknown_headers = actual_headers - (required_headers | optional_headers)
+        if unknown_headers:
+            logger.warning(
+                f"Unknown headers in CSV file will be ignored: {', '.join(unknown_headers)}"
+            )
+
+        # enumeration starts at 2 because row 1 is headers
+        for row_number, row in enumerate(reader, start=2):
+            try:
+                run = create_task_run_from_csv_row(
+                    task=task,
+                    row=row,
+                    dataset_name=dataset_name,
+                    session_id=session_id,
+                )
+            except ValidationError as e:
+                logger.warning(f"Invalid row {row_number}: {row}", exc_info=True)
+                human_readable = format_validation_error(e)
+                raise KilnInvalidImportFormat(
+                    human_readable,
+                    row_number=row_number,
+                ) from e
+            rows.append(run)
+
+    # now that we know all rows are valid, we can save them
+    for run in rows:
+        run.save_to_file()
+
+    return len(rows)
+
+
+DATASET_IMPORTERS: Dict[DatasetImportFormat, Importer] = {
+    DatasetImportFormat.CSV: import_csv,
+}
+
+
+@dataclass
+class ImportConfig:
+    """Configuration for importing a dataset"""
+
+    dataset_type: DatasetImportFormat
+    dataset_path: str
+    dataset_name: str
+
+
+class DatasetFileImporter:
+    """Import a dataset from a file"""
+
+    def __init__(self, task: Task, config: ImportConfig):
+        self.task = task
+        self.dataset_type = config.dataset_type
+        self.dataset_path = config.dataset_path
+        self.dataset_name = config.dataset_name
+
+    def create_runs_from_file(self) -> int:
+        fn = DATASET_IMPORTERS[self.dataset_type]
+        return fn(self.task, self.dataset_path, self.dataset_name)
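A minimal usage sketch for the new importer, assuming an existing Kiln Task (the task-loading call and the file paths below are illustrative assumptions, not part of this diff):

from kiln_ai.datamodel import Task
from kiln_ai.utils.dataset_import import (
    DatasetFileImporter,
    DatasetImportFormat,
    ImportConfig,
    KilnInvalidImportFormat,
)

# Assumed API for loading an existing, saved task; adjust to however your project obtains a Task.
task = Task.load_from_file("/path/to/project/tasks/my_task/task.kiln")

importer = DatasetFileImporter(
    task,
    ImportConfig(
        dataset_type=DatasetImportFormat.CSV,
        # CSV headers: input, output (required); reasoning, chain_of_thought, tags (optional)
        dataset_path="/path/to/my_dataset.csv",
        dataset_name="my_dataset.csv",
    ),
)

try:
    count = importer.create_runs_from_file()
    print(f"Imported {count} runs")
except KilnInvalidImportFormat as e:
    # Includes the failing row number; nothing is saved unless every row validates.
    print(f"Import failed: {e}")

Note that import_csv tags every run with "imported" plus an "imported_<session_id>" tag, so an entire import session can be identified later.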