kiln-ai 0.11.1__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (80)
  1. kiln_ai/adapters/__init__.py +4 -0
  2. kiln_ai/adapters/adapter_registry.py +163 -39
  3. kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
  4. kiln_ai/adapters/eval/__init__.py +28 -0
  5. kiln_ai/adapters/eval/base_eval.py +164 -0
  6. kiln_ai/adapters/eval/eval_runner.py +270 -0
  7. kiln_ai/adapters/eval/g_eval.py +368 -0
  8. kiln_ai/adapters/eval/registry.py +16 -0
  9. kiln_ai/adapters/eval/test_base_eval.py +325 -0
  10. kiln_ai/adapters/eval/test_eval_runner.py +641 -0
  11. kiln_ai/adapters/eval/test_g_eval.py +498 -0
  12. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
  14. kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
  15. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
  16. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
  17. kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
  18. kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
  19. kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
  20. kiln_ai/adapters/ml_model_list.py +758 -163
  21. kiln_ai/adapters/model_adapters/__init__.py +2 -4
  22. kiln_ai/adapters/model_adapters/base_adapter.py +61 -43
  23. kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
  24. kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
  25. kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
  26. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
  27. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
  28. kiln_ai/adapters/model_adapters/test_structured_output.py +59 -35
  29. kiln_ai/adapters/ollama_tools.py +3 -3
  30. kiln_ai/adapters/parsers/r1_parser.py +19 -14
  31. kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
  32. kiln_ai/adapters/prompt_builders.py +80 -42
  33. kiln_ai/adapters/provider_tools.py +50 -58
  34. kiln_ai/adapters/repair/repair_task.py +9 -21
  35. kiln_ai/adapters/repair/test_repair_task.py +6 -6
  36. kiln_ai/adapters/run_output.py +3 -0
  37. kiln_ai/adapters/test_adapter_registry.py +26 -29
  38. kiln_ai/adapters/test_generate_docs.py +4 -4
  39. kiln_ai/adapters/test_ollama_tools.py +0 -1
  40. kiln_ai/adapters/test_prompt_adaptors.py +47 -33
  41. kiln_ai/adapters/test_prompt_builders.py +91 -31
  42. kiln_ai/adapters/test_provider_tools.py +26 -81
  43. kiln_ai/datamodel/__init__.py +50 -952
  44. kiln_ai/datamodel/basemodel.py +2 -0
  45. kiln_ai/datamodel/datamodel_enums.py +60 -0
  46. kiln_ai/datamodel/dataset_filters.py +114 -0
  47. kiln_ai/datamodel/dataset_split.py +170 -0
  48. kiln_ai/datamodel/eval.py +298 -0
  49. kiln_ai/datamodel/finetune.py +105 -0
  50. kiln_ai/datamodel/json_schema.py +7 -1
  51. kiln_ai/datamodel/project.py +23 -0
  52. kiln_ai/datamodel/prompt.py +37 -0
  53. kiln_ai/datamodel/prompt_id.py +83 -0
  54. kiln_ai/datamodel/strict_mode.py +24 -0
  55. kiln_ai/datamodel/task.py +181 -0
  56. kiln_ai/datamodel/task_output.py +328 -0
  57. kiln_ai/datamodel/task_run.py +164 -0
  58. kiln_ai/datamodel/test_basemodel.py +19 -11
  59. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  60. kiln_ai/datamodel/test_dataset_split.py +32 -8
  61. kiln_ai/datamodel/test_datasource.py +22 -2
  62. kiln_ai/datamodel/test_eval_model.py +635 -0
  63. kiln_ai/datamodel/test_example_models.py +9 -13
  64. kiln_ai/datamodel/test_json_schema.py +23 -0
  65. kiln_ai/datamodel/test_models.py +2 -2
  66. kiln_ai/datamodel/test_prompt_id.py +129 -0
  67. kiln_ai/datamodel/test_task.py +159 -0
  68. kiln_ai/utils/config.py +43 -1
  69. kiln_ai/utils/dataset_import.py +232 -0
  70. kiln_ai/utils/test_dataset_import.py +596 -0
  71. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +86 -6
  72. kiln_ai-0.13.0.dist-info/RECORD +103 -0
  73. kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -302
  74. kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -11
  75. kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -246
  76. kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -350
  77. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -225
  78. kiln_ai-0.11.1.dist-info/RECORD +0 -76
  79. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
  80. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/utils/dataset_import.py (new file)
@@ -0,0 +1,232 @@
+ import csv
+ import logging
+ import time
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Dict, Protocol
+
+ from pydantic import BaseModel, Field, ValidationError, field_validator
+
+ from kiln_ai.datamodel import DataSource, DataSourceType, Task, TaskOutput, TaskRun
+
+ logger = logging.getLogger(__name__)
+
+
+ class DatasetImportFormat(str, Enum):
+     """
+     The format of the dataset to import.
+     """
+
+     CSV = "csv"
+
+
+ class Importer(Protocol):
+     """Protocol for dataset importers"""
+
+     def __call__(
+         self,
+         task: Task,
+         dataset_path: str,
+         dataset_name: str,
+     ) -> int: ...
+
+
+ class CSVRowSchema(BaseModel):
+     """Schema for validating rows in a CSV file."""
+
+     input: str = Field(description="The input to the model")
+     output: str = Field(description="The output of the model")
+     reasoning: str | None = Field(
+         description="The reasoning of the model (optional)",
+         default=None,
+     )
+     chain_of_thought: str | None = Field(
+         description="The chain of thought of the model (optional)",
+         default=None,
+     )
+     tags: list[str] = Field(
+         default_factory=list,
+         description="The tags of the run (optional)",
+     )
+
+
+ def generate_import_tags(session_id: str) -> list[str]:
+     return [
+         "imported",
+         f"imported_{session_id}",
+     ]
+
+
+ class KilnInvalidImportFormat(Exception):
+     """Raised when the import format is invalid"""
+
+     def __init__(self, message: str, row_number: int | None = None):
+         self.row_number = row_number
+         if row_number is not None:
+             message = f"Error in row {row_number}: {message}"
+         super().__init__(message)
+
+
+ def format_validation_error(e: ValidationError) -> str:
+     """Convert a Pydantic validation error into a human-readable message."""
+     error_messages = []
+     for error in e.errors():
+         location = " -> ".join(str(loc) for loc in error["loc"])
+         message = error["msg"]
+         error_messages.append(f"- {location}: {message}")
+
+     return "Validation failed:\n" + "\n".join(error_messages)
+
+
+ def deserialize_tags(tags_serialized: str | None) -> list[str]:
+     """Deserialize tags from a comma-separated string to a list of strings."""
+     if tags_serialized:
+         return [tag.strip() for tag in tags_serialized.split(",") if tag.strip()]
+     return []
+
+
+ def without_none_values(d: dict) -> dict:
+     """Return a copy of the dictionary with all None values removed."""
+     return {k: v for k, v in d.items() if v is not None}
+
+
+ def create_task_run_from_csv_row(
+     task: Task,
+     row: dict[str, str],
+     dataset_name: str,
+     session_id: str,
+ ) -> TaskRun:
+     """Validate and create a TaskRun from a CSV row, without saving to file"""
+
+     # first we validate the row from the CSV file
+     validated_row = CSVRowSchema.model_validate(
+         {
+             **row,
+             "tags": deserialize_tags(row.get("tags")),
+         }
+     )
+
+     tags = generate_import_tags(session_id)
+     if validated_row.tags:
+         tags.extend(validated_row.tags)
+
+     # note that we don't persist the run yet, we just create and validate it
+     # this instantiation may raise pydantic validation errors
+     run = TaskRun(
+         parent=task,
+         input=validated_row.input,
+         input_source=DataSource(
+             type=DataSourceType.file_import,
+             properties={
+                 "file_name": dataset_name,
+             },
+         ),
+         output=TaskOutput(
+             output=validated_row.output,
+             source=DataSource(
+                 type=DataSourceType.file_import,
+                 properties={
+                     "file_name": dataset_name,
+                 },
+             ),
+         ),
+         intermediate_outputs=without_none_values(
+             {
+                 "reasoning": validated_row.reasoning,
+                 "chain_of_thought": validated_row.chain_of_thought,
+             }
+         )
+         or None,
+         tags=tags,
+     )
+
+     return run
+
+
+ def import_csv(task: Task, dataset_path: str, dataset_name: str) -> int:
+     """Import a CSV dataset.
+
+     All rows are validated before any are persisted to files to avoid partial imports."""
+
+     session_id = str(int(time.time()))
+
+     required_headers = {"input", "output"}  # minimum required headers
+     optional_headers = {"reasoning", "tags", "chain_of_thought"}  # optional headers
+
+     rows: list[TaskRun] = []
+     with open(dataset_path, "r", newline="") as csvfile:
+         reader = csv.DictReader(csvfile)
+
+         # Check if we have headers
+         if not reader.fieldnames:
+             raise KilnInvalidImportFormat(
+                 "CSV file appears to be empty or missing headers"
+             )
+
+         # Check for required headers
+         actual_headers = set(reader.fieldnames)
+         missing_headers = required_headers - actual_headers
+         if missing_headers:
+             raise KilnInvalidImportFormat(
+                 f"Missing required headers: {', '.join(missing_headers)}. "
+                 f"Required headers are: {', '.join(required_headers)}"
+             )
+
+         # Warn about unknown headers (not required or optional)
+         unknown_headers = actual_headers - (required_headers | optional_headers)
+         if unknown_headers:
+             logger.warning(
+                 f"Unknown headers in CSV file will be ignored: {', '.join(unknown_headers)}"
+             )
+
+         # enumeration starts at 2 because row 1 is headers
+         for row_number, row in enumerate(reader, start=2):
+             try:
+                 run = create_task_run_from_csv_row(
+                     task=task,
+                     row=row,
+                     dataset_name=dataset_name,
+                     session_id=session_id,
+                 )
+             except ValidationError as e:
+                 logger.warning(f"Invalid row {row_number}: {row}", exc_info=True)
+                 human_readable = format_validation_error(e)
+                 raise KilnInvalidImportFormat(
+                     human_readable,
+                     row_number=row_number,
+                 ) from e
+             rows.append(run)
+
+     # now that we know all rows are valid, we can save them
+     for run in rows:
+         run.save_to_file()
+
+     return len(rows)
+
+
+ DATASET_IMPORTERS: Dict[DatasetImportFormat, Importer] = {
+     DatasetImportFormat.CSV: import_csv,
+ }
+
+
+ @dataclass
+ class ImportConfig:
+     """Configuration for importing a dataset"""
+
+     dataset_type: DatasetImportFormat
+     dataset_path: str
+     dataset_name: str
+
+
+ class DatasetFileImporter:
+     """Import a dataset from a file"""
+
+     def __init__(self, task: Task, config: ImportConfig):
+         self.task = task
+         self.dataset_type = config.dataset_type
+         self.dataset_path = config.dataset_path
+         self.dataset_name = config.dataset_name
+
+     def create_runs_from_file(self) -> int:
+         fn = DATASET_IMPORTERS[self.dataset_type]
+         return fn(self.task, self.dataset_path, self.dataset_name)
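
For orientation, the new kiln_ai/utils/dataset_import.py module above is driven through DatasetFileImporter, which dispatches on DatasetImportFormat. Below is a minimal usage sketch, not part of this diff: the CSV contents, file paths, and the Project/Task bootstrapping are illustrative assumptions based on the kiln_ai.datamodel API and may need adjusting to a real project setup.

# Hedged usage sketch (not part of the diff); names and paths are examples.
import tempfile
from pathlib import Path

from kiln_ai.datamodel import Project, Task
from kiln_ai.utils.dataset_import import (
    DatasetFileImporter,
    DatasetImportFormat,
    ImportConfig,
)

# A CSV with the required "input"/"output" headers, plus the optional
# "reasoning" and comma-separated "tags" columns the importer accepts.
csv_path = Path(tempfile.mkdtemp()) / "dataset.csv"
csv_path.write_text(
    'input,output,reasoning,tags\n'
    '"What is 2+2?","4","2 plus 2 is 4","math,easy"\n'
)

# Assumed project/task bootstrapping: TaskRun.save_to_file() needs a
# parent Task that itself lives on disk.
project = Project(name="demo_project", path=Path(tempfile.mkdtemp()) / "project.kiln")
project.save_to_file()
task = Task(name="demo_task", instruction="Answer the question.", parent=project)
task.save_to_file()

config = ImportConfig(
    dataset_type=DatasetImportFormat.CSV,
    dataset_path=str(csv_path),
    dataset_name="dataset.csv",
)
count = DatasetFileImporter(task, config).create_runs_from_file()
print(f"Imported {count} task runs")  # rows are tagged "imported" and "imported_<session_id>"

Note the all-or-nothing design visible in import_csv: every row is validated into an in-memory TaskRun before any run is saved, so a bad row anywhere in the file raises KilnInvalidImportFormat (with its 1-based row number, counting the header as row 1) and leaves the dataset untouched.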