kiln-ai 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic; the source page links to more details.

Files changed (47)
  1. kiln_ai/adapters/__init__.py +4 -0
  2. kiln_ai/adapters/adapter_registry.py +153 -28
  3. kiln_ai/adapters/eval/__init__.py +28 -0
  4. kiln_ai/adapters/eval/eval_runner.py +4 -1
  5. kiln_ai/adapters/eval/g_eval.py +2 -1
  6. kiln_ai/adapters/eval/test_base_eval.py +1 -0
  7. kiln_ai/adapters/eval/test_eval_runner.py +1 -0
  8. kiln_ai/adapters/eval/test_g_eval.py +1 -0
  9. kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
  10. kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
  11. kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
  12. kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
  13. kiln_ai/adapters/ml_model_list.py +638 -155
  14. kiln_ai/adapters/model_adapters/__init__.py +2 -4
  15. kiln_ai/adapters/model_adapters/base_adapter.py +14 -11
  16. kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
  17. kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
  18. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
  19. kiln_ai/adapters/model_adapters/test_structured_output.py +23 -5
  20. kiln_ai/adapters/ollama_tools.py +3 -2
  21. kiln_ai/adapters/parsers/r1_parser.py +19 -14
  22. kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
  23. kiln_ai/adapters/provider_tools.py +50 -58
  24. kiln_ai/adapters/repair/test_repair_task.py +3 -3
  25. kiln_ai/adapters/run_output.py +1 -1
  26. kiln_ai/adapters/test_adapter_registry.py +17 -20
  27. kiln_ai/adapters/test_generate_docs.py +2 -2
  28. kiln_ai/adapters/test_prompt_adaptors.py +30 -19
  29. kiln_ai/adapters/test_provider_tools.py +26 -81
  30. kiln_ai/datamodel/basemodel.py +2 -0
  31. kiln_ai/datamodel/datamodel_enums.py +2 -0
  32. kiln_ai/datamodel/json_schema.py +1 -1
  33. kiln_ai/datamodel/task_output.py +13 -6
  34. kiln_ai/datamodel/test_basemodel.py +9 -0
  35. kiln_ai/datamodel/test_datasource.py +19 -0
  36. kiln_ai/utils/config.py +37 -0
  37. kiln_ai/utils/dataset_import.py +232 -0
  38. kiln_ai/utils/test_dataset_import.py +596 -0
  39. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +51 -7
  40. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/RECORD +42 -39
  41. kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -309
  42. kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -10
  43. kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -289
  44. kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -343
  45. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -216
  46. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
  47. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/json_schema.py CHANGED
@@ -49,7 +49,7 @@ def validate_schema(instance: Dict, schema_str: str) -> None:
49
49
  v.validate(instance)
50
50
  except jsonschema.exceptions.ValidationError as e:
51
51
  raise ValueError(
52
- f"This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information. The error from the schema check was: {e.message}"
52
+ f"This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information. The error from the schema check was: {e.message}. The JSON was: \n```json\n{instance}\n```"
53
53
  ) from e
54
54
 
55
55
 
kiln_ai/datamodel/task_output.py CHANGED
@@ -171,6 +171,7 @@ class DataSourceType(str, Enum):
171
171
 
172
172
  human = "human"
173
173
  synthetic = "synthetic"
174
+ file_import = "file_import"
174
175
 
175
176
 
176
177
  class DataSourceProperty(BaseModel):
@@ -206,37 +207,43 @@ class DataSource(BaseModel):
206
207
  name="created_by",
207
208
  type=str,
208
209
  required_for=[DataSourceType.human],
209
- not_allowed_for=[DataSourceType.synthetic],
210
+ not_allowed_for=[DataSourceType.synthetic, DataSourceType.file_import],
210
211
  ),
211
212
  DataSourceProperty(
212
213
  name="model_name",
213
214
  type=str,
214
215
  required_for=[DataSourceType.synthetic],
215
- not_allowed_for=[DataSourceType.human],
216
+ not_allowed_for=[DataSourceType.human, DataSourceType.file_import],
216
217
  ),
217
218
  DataSourceProperty(
218
219
  name="model_provider",
219
220
  type=str,
220
221
  required_for=[DataSourceType.synthetic],
221
- not_allowed_for=[DataSourceType.human],
222
+ not_allowed_for=[DataSourceType.human, DataSourceType.file_import],
222
223
  ),
223
224
  DataSourceProperty(
224
225
  name="adapter_name",
225
226
  type=str,
226
227
  required_for=[DataSourceType.synthetic],
227
- not_allowed_for=[DataSourceType.human],
228
+ not_allowed_for=[DataSourceType.human, DataSourceType.file_import],
228
229
  ),
229
230
  DataSourceProperty(
230
231
  # Legacy field -- allow loading from old runs, but we shouldn't be setting it.
231
232
  name="prompt_builder_name",
232
233
  type=str,
233
- not_allowed_for=[DataSourceType.human],
234
+ not_allowed_for=[DataSourceType.human, DataSourceType.file_import],
234
235
  ),
235
236
  DataSourceProperty(
236
237
  # The PromptId of the prompt. Can be a saved prompt, fine-tune, generator name, etc. See PromptId type for more details.
237
238
  name="prompt_id",
238
239
  type=str,
239
- not_allowed_for=[DataSourceType.human],
240
+ not_allowed_for=[DataSourceType.human, DataSourceType.file_import],
241
+ ),
242
+ DataSourceProperty(
243
+ name="file_name",
244
+ type=str,
245
+ required_for=[DataSourceType.file_import],
246
+ not_allowed_for=[DataSourceType.human, DataSourceType.synthetic],
240
247
  ),
241
248
  ]
242
249
 
kiln_ai/datamodel/test_basemodel.py CHANGED
@@ -510,6 +510,7 @@ async def test_invoke_parsing_flow(adapter):
510
510
  # Mock dependencies
511
511
  mock_provider = MagicMock()
512
512
  mock_provider.parser = "test_parser"
513
+ mock_provider.reasoning_capable = False
513
514
 
514
515
  mock_parser = MagicMock()
515
516
  mock_parser.parse_output.return_value = RunOutput(
@@ -547,3 +548,11 @@ async def test_invoke_parsing_flow(adapter):
547
548
  assert result.output.output == "parsed test output"
548
549
  assert result.intermediate_outputs == {"key": "value"}
549
550
  assert result.input == "test input"
551
+
552
+ # Test with reasoning required, that we error if no reasoning is returned
553
+ mock_provider.reasoning_capable = True
554
+ with pytest.raises(
555
+ RuntimeError,
556
+ match="Reasoning is required for this model, but no reasoning was returned.",
557
+ ):
558
+ await adapter.invoke("test input")
kiln_ai/datamodel/test_datasource.py CHANGED
@@ -29,11 +29,30 @@ def test_valid_synthetic_data_source():
29
29
  assert data_source.properties["adapter_name"] == "langchain"
30
30
 
31
31
 
32
+ def test_valid_file_import_data_source():
33
+ data_source = DataSource(
34
+ type=DataSourceType.file_import,
35
+ properties={"file_name": "test.txt"},
36
+ )
37
+ assert data_source.type == DataSourceType.file_import
38
+ assert data_source.properties["file_name"] == "test.txt"
39
+
40
+
32
41
  def test_missing_required_property():
33
42
  with pytest.raises(ValidationError, match="'created_by' is required for"):
34
43
  DataSource(type=DataSourceType.human)
35
44
 
36
45
 
46
+ def test_missing_required_property_file_import():
47
+ with pytest.raises(ValidationError, match="'file_name' is required for"):
48
+ DataSource(type=DataSourceType.file_import)
49
+
50
+
51
+ def test_not_allowed_property_file_import():
52
+ with pytest.raises(ValidationError, match="'model_name' is not allowed for"):
53
+ DataSource(type=DataSourceType.file_import, properties={"model_name": "GPT-4"})
54
+
55
+
37
56
  def test_wrong_property_type():
38
57
  with pytest.raises(
39
58
  ValidationError,
kiln_ai/utils/config.py CHANGED
@@ -78,10 +78,47 @@ class Config:
78
78
  str,
79
79
  env_var="FIREWORKS_ACCOUNT_ID",
80
80
  ),
81
+ "anthropic_api_key": ConfigProperty(
82
+ str,
83
+ env_var="ANTHROPIC_API_KEY",
84
+ sensitive=True,
85
+ ),
86
+ "gemini_api_key": ConfigProperty(
87
+ str,
88
+ env_var="GEMINI_API_KEY",
89
+ sensitive=True,
90
+ ),
81
91
  "projects": ConfigProperty(
82
92
  list,
83
93
  default_lambda=lambda: [],
84
94
  ),
95
+ "azure_openai_api_key": ConfigProperty(
96
+ str,
97
+ env_var="AZURE_OPENAI_API_KEY",
98
+ sensitive=True,
99
+ ),
100
+ "azure_openai_endpoint": ConfigProperty(
101
+ str,
102
+ env_var="AZURE_OPENAI_ENDPOINT",
103
+ ),
104
+ "huggingface_api_key": ConfigProperty(
105
+ str,
106
+ env_var="HUGGINGFACE_API_KEY",
107
+ sensitive=True,
108
+ ),
109
+ "vertex_project_id": ConfigProperty(
110
+ str,
111
+ env_var="VERTEX_PROJECT_ID",
112
+ ),
113
+ "vertex_location": ConfigProperty(
114
+ str,
115
+ env_var="VERTEX_LOCATION",
116
+ ),
117
+ "together_api_key": ConfigProperty(
118
+ str,
119
+ env_var="TOGETHERAI_API_KEY",
120
+ sensitive=True,
121
+ ),
85
122
  "custom_models": ConfigProperty(
86
123
  list,
87
124
  default_lambda=lambda: [],
kiln_ai/utils/dataset_import.py ADDED
@@ -0,0 +1,232 @@
1
+ import csv
2
+ import logging
3
+ import time
4
+ from dataclasses import dataclass
5
+ from enum import Enum
6
+ from typing import Dict, Protocol
7
+
8
+ from pydantic import BaseModel, Field, ValidationError, field_validator
9
+
10
+ from kiln_ai.datamodel import DataSource, DataSourceType, Task, TaskOutput, TaskRun
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class DatasetImportFormat(str, Enum):
16
+ """
17
+ The format of the dataset to import.
18
+ """
19
+
20
+ CSV = "csv"
21
+
22
+
23
+ class Importer(Protocol):
24
+ """Protocol for dataset importers"""
25
+
26
+ def __call__(
27
+ self,
28
+ task: Task,
29
+ dataset_path: str,
30
+ dataset_name: str,
31
+ ) -> int: ...
32
+
33
+
34
+ class CSVRowSchema(BaseModel):
35
+ """Schema for validating rows in a CSV file."""
36
+
37
+ input: str = Field(description="The input to the model")
38
+ output: str = Field(description="The output of the model")
39
+ reasoning: str | None = Field(
40
+ description="The reasoning of the model (optional)",
41
+ default=None,
42
+ )
43
+ chain_of_thought: str | None = Field(
44
+ description="The chain of thought of the model (optional)",
45
+ default=None,
46
+ )
47
+ tags: list[str] = Field(
48
+ default_factory=list,
49
+ description="The tags of the run (optional)",
50
+ )
51
+
52
+
53
+ def generate_import_tags(session_id: str) -> list[str]:
54
+ return [
55
+ "imported",
56
+ f"imported_{session_id}",
57
+ ]
58
+
59
+
60
+ class KilnInvalidImportFormat(Exception):
61
+ """Raised when the import format is invalid"""
62
+
63
+ def __init__(self, message: str, row_number: int | None = None):
64
+ self.row_number = row_number
65
+ if row_number is not None:
66
+ message = f"Error in row {row_number}: {message}"
67
+ super().__init__(message)
68
+
69
+
70
+ def format_validation_error(e: ValidationError) -> str:
71
+ """Convert a Pydantic validation error into a human-readable message."""
72
+ error_messages = []
73
+ for error in e.errors():
74
+ location = " -> ".join(str(loc) for loc in error["loc"])
75
+ message = error["msg"]
76
+ error_messages.append(f"- {location}: {message}")
77
+
78
+ return "Validation failed:\n" + "\n".join(error_messages)
79
+
80
+
81
+ def deserialize_tags(tags_serialized: str | None) -> list[str]:
82
+ """Deserialize tags from a comma-separated string to a list of strings."""
83
+ if tags_serialized:
84
+ return [tag.strip() for tag in tags_serialized.split(",") if tag.strip()]
85
+ return []
86
+
87
+
88
+ def without_none_values(d: dict) -> dict:
89
+ """Return a copy of the dictionary with all None values removed."""
90
+ return {k: v for k, v in d.items() if v is not None}
91
+
92
+
93
+ def create_task_run_from_csv_row(
94
+ task: Task,
95
+ row: dict[str, str],
96
+ dataset_name: str,
97
+ session_id: str,
98
+ ) -> TaskRun:
99
+ """Validate and create a TaskRun from a CSV row, without saving to file"""
100
+
101
+ # first we validate the row from the CSV file
102
+ validated_row = CSVRowSchema.model_validate(
103
+ {
104
+ **row,
105
+ "tags": deserialize_tags(row.get("tags")),
106
+ }
107
+ )
108
+
109
+ tags = generate_import_tags(session_id)
110
+ if validated_row.tags:
111
+ tags.extend(validated_row.tags)
112
+
113
+ # note that we don't persist the run yet, we just create and validate it
114
+ # this instantiation may raise pydantic validation errors
115
+ run = TaskRun(
116
+ parent=task,
117
+ input=validated_row.input,
118
+ input_source=DataSource(
119
+ type=DataSourceType.file_import,
120
+ properties={
121
+ "file_name": dataset_name,
122
+ },
123
+ ),
124
+ output=TaskOutput(
125
+ output=validated_row.output,
126
+ source=DataSource(
127
+ type=DataSourceType.file_import,
128
+ properties={
129
+ "file_name": dataset_name,
130
+ },
131
+ ),
132
+ ),
133
+ intermediate_outputs=without_none_values(
134
+ {
135
+ "reasoning": validated_row.reasoning,
136
+ "chain_of_thought": validated_row.chain_of_thought,
137
+ }
138
+ )
139
+ or None,
140
+ tags=tags,
141
+ )
142
+
143
+ return run
144
+
145
+
146
+ def import_csv(task: Task, dataset_path: str, dataset_name: str) -> int:
147
+ """Import a CSV dataset.
148
+
149
+ All rows are validated before any are persisted to files to avoid partial imports."""
150
+
151
+ session_id = str(int(time.time()))
152
+
153
+ required_headers = {"input", "output"} # minimum required headers
154
+ optional_headers = {"reasoning", "tags", "chain_of_thought"} # optional headers
155
+
156
+ rows: list[TaskRun] = []
157
+ with open(dataset_path, "r", newline="") as csvfile:
158
+ reader = csv.DictReader(csvfile)
159
+
160
+ # Check if we have headers
161
+ if not reader.fieldnames:
162
+ raise KilnInvalidImportFormat(
163
+ "CSV file appears to be empty or missing headers"
164
+ )
165
+
166
+ # Check for required headers
167
+ actual_headers = set(reader.fieldnames)
168
+ missing_headers = required_headers - actual_headers
169
+ if missing_headers:
170
+ raise KilnInvalidImportFormat(
171
+ f"Missing required headers: {', '.join(missing_headers)}. "
172
+ f"Required headers are: {', '.join(required_headers)}"
173
+ )
174
+
175
+ # Warn about unknown headers (not required or optional)
176
+ unknown_headers = actual_headers - (required_headers | optional_headers)
177
+ if unknown_headers:
178
+ logger.warning(
179
+ f"Unknown headers in CSV file will be ignored: {', '.join(unknown_headers)}"
180
+ )
181
+
182
+ # enumeration starts at 2 because row 1 is headers
183
+ for row_number, row in enumerate(reader, start=2):
184
+ try:
185
+ run = create_task_run_from_csv_row(
186
+ task=task,
187
+ row=row,
188
+ dataset_name=dataset_name,
189
+ session_id=session_id,
190
+ )
191
+ except ValidationError as e:
192
+ logger.warning(f"Invalid row {row_number}: {row}", exc_info=True)
193
+ human_readable = format_validation_error(e)
194
+ raise KilnInvalidImportFormat(
195
+ human_readable,
196
+ row_number=row_number,
197
+ ) from e
198
+ rows.append(run)
199
+
200
+ # now that we know all rows are valid, we can save them
201
+ for run in rows:
202
+ run.save_to_file()
203
+
204
+ return len(rows)
205
+
206
+
207
+ DATASET_IMPORTERS: Dict[DatasetImportFormat, Importer] = {
208
+ DatasetImportFormat.CSV: import_csv,
209
+ }
210
+
211
+
212
+ @dataclass
213
+ class ImportConfig:
214
+ """Configuration for importing a dataset"""
215
+
216
+ dataset_type: DatasetImportFormat
217
+ dataset_path: str
218
+ dataset_name: str
219
+
220
+
221
+ class DatasetFileImporter:
222
+ """Import a dataset from a file"""
223
+
224
+ def __init__(self, task: Task, config: ImportConfig):
225
+ self.task = task
226
+ self.dataset_type = config.dataset_type
227
+ self.dataset_path = config.dataset_path
228
+ self.dataset_name = config.dataset_name
229
+
230
+ def create_runs_from_file(self) -> int:
231
+ fn = DATASET_IMPORTERS[self.dataset_type]
232
+ return fn(self.task, self.dataset_path, self.dataset_name)