kiln-ai 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. kiln_ai/adapters/eval/base_eval.py +7 -2
  2. kiln_ai/adapters/eval/eval_runner.py +5 -64
  3. kiln_ai/adapters/eval/g_eval.py +3 -3
  4. kiln_ai/adapters/fine_tune/base_finetune.py +6 -3
  5. kiln_ai/adapters/fine_tune/dataset_formatter.py +128 -38
  6. kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
  7. kiln_ai/adapters/fine_tune/fireworks_finetune.py +2 -1
  8. kiln_ai/adapters/fine_tune/test_base_finetune.py +7 -0
  9. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +267 -10
  10. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +1 -1
  11. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +586 -0
  12. kiln_ai/adapters/fine_tune/vertex_finetune.py +217 -0
  13. kiln_ai/adapters/ml_model_list.py +817 -62
  14. kiln_ai/adapters/model_adapters/base_adapter.py +33 -10
  15. kiln_ai/adapters/model_adapters/litellm_adapter.py +51 -12
  16. kiln_ai/adapters/model_adapters/test_base_adapter.py +74 -2
  17. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +65 -1
  18. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +3 -2
  19. kiln_ai/adapters/model_adapters/test_structured_output.py +4 -6
  20. kiln_ai/adapters/parsers/base_parser.py +0 -3
  21. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  22. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  23. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  24. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  25. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  26. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  27. kiln_ai/adapters/prompt_builders.py +14 -1
  28. kiln_ai/adapters/provider_tools.py +25 -1
  29. kiln_ai/adapters/repair/test_repair_task.py +3 -2
  30. kiln_ai/adapters/test_prompt_builders.py +24 -3
  31. kiln_ai/adapters/test_provider_tools.py +86 -1
  32. kiln_ai/datamodel/__init__.py +2 -0
  33. kiln_ai/datamodel/datamodel_enums.py +14 -0
  34. kiln_ai/datamodel/dataset_filters.py +69 -1
  35. kiln_ai/datamodel/dataset_split.py +4 -0
  36. kiln_ai/datamodel/eval.py +8 -0
  37. kiln_ai/datamodel/finetune.py +1 -0
  38. kiln_ai/datamodel/json_schema.py +24 -7
  39. kiln_ai/datamodel/prompt_id.py +1 -0
  40. kiln_ai/datamodel/task_output.py +10 -6
  41. kiln_ai/datamodel/task_run.py +68 -12
  42. kiln_ai/datamodel/test_basemodel.py +3 -7
  43. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  44. kiln_ai/datamodel/test_dataset_split.py +2 -0
  45. kiln_ai/datamodel/test_example_models.py +158 -3
  46. kiln_ai/datamodel/test_json_schema.py +22 -3
  47. kiln_ai/datamodel/test_model_perf.py +3 -2
  48. kiln_ai/datamodel/test_models.py +50 -2
  49. kiln_ai/utils/async_job_runner.py +106 -0
  50. kiln_ai/utils/dataset_import.py +80 -18
  51. kiln_ai/utils/test_async_job_runner.py +199 -0
  52. kiln_ai/utils/test_dataset_import.py +242 -10
  53. {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/METADATA +3 -2
  54. kiln_ai-0.16.0.dist-info/RECORD +108 -0
  55. kiln_ai/adapters/test_generate_docs.py +0 -69
  56. kiln_ai-0.14.0.dist-info/RECORD +0 -103
  57. {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/WHEEL +0 -0
  58. {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,106 @@
1
+ import asyncio
2
+ import logging
3
+ from dataclasses import dataclass
4
+ from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar
5
+
6
logger = logging.getLogger(__name__)

T = TypeVar("T")


@dataclass
class Progress:
    """Snapshot of a run's progress, as yielded by AsyncJobRunner.run."""

    # jobs that finished and reported success
    complete: int
    # total number of jobs submitted to the run
    total: int
    # jobs that reported failure (returned False) or raised an exception
    errors: int


class AsyncJobRunner:
    """Run async jobs across a fixed number of concurrent worker tasks."""

    def __init__(self, concurrency: int = 1):
        """
        Args:
            concurrency: number of worker tasks run in parallel.

        Raises:
            ValueError: if concurrency is less than 1.
        """
        if concurrency < 1:
            raise ValueError("concurrency must be ≥ 1")
        self.concurrency = concurrency

    async def run(
        self,
        jobs: List[T],
        run_job: Callable[[T], Awaitable[bool]],
    ) -> AsyncGenerator[Progress, None]:
        """
        Runs the jobs with parallel workers and yields progress updates.

        One Progress is yielded before any job starts and one after each job
        finishes, so a full run yields len(jobs) + 1 updates. A job counts as an
        error when run_job returns False or raises an Exception.

        If the consumer stops early and the generator is closed, outstanding
        workers are cancelled and that cancellation is NOT re-raised to the
        caller. An exception escaping a worker task itself (outside run_job) is
        re-raised once all workers have stopped.
        """
        complete = 0
        errors = 0
        total = len(jobs)

        # Send initial status
        yield Progress(complete=complete, total=total, errors=errors)

        worker_queue: asyncio.Queue[T] = asyncio.Queue()
        for job in jobs:
            worker_queue.put_nowait(job)

        # simple status queue to return progress. True=success, False=error
        status_queue: asyncio.Queue[bool] = asyncio.Queue()

        workers = [
            asyncio.create_task(self._run_worker(worker_queue, status_queue, run_job))
            for _ in range(self.concurrency)
        ]

        try:
            # Send status updates until workers are done, and they are all sent
            while not status_queue.empty() or not all(w.done() for w in workers):
                try:
                    # Use timeout to prevent hanging if all workers complete
                    # between our while condition check and get()
                    success = await asyncio.wait_for(status_queue.get(), timeout=0.1)
                except asyncio.TimeoutError:
                    # Timeout is expected, just loop back to recheck worker status
                    continue

                if success:
                    complete += 1
                else:
                    errors += 1
                yield Progress(complete=complete, total=total, errors=errors)
        finally:
            # Cancel outstanding workers on early exit or error (no-op for
            # workers that already finished)
            for w in workers:
                w.cancel()

            # return_exceptions=True so the CancelledError we just caused is not
            # raised out of the generator (e.g. from aclose() after a consumer
            # breaks out of the loop); real worker failures are re-raised below.
            results = await asyncio.gather(*workers, return_exceptions=True)
            for result in results:
                if isinstance(result, BaseException) and not isinstance(
                    result, asyncio.CancelledError
                ):
                    # A worker crashed outside of run_job: surface it to the caller
                    raise result

            # Only when every worker ran to completion is the queue fully
            # drained; after a cancellation unprocessed jobs remain and join()
            # would never return.
            if not any(isinstance(r, asyncio.CancelledError) for r in results):
                await worker_queue.join()

    async def _run_worker(
        self,
        worker_queue: asyncio.Queue[T],
        status_queue: asyncio.Queue[bool],
        run_job: Callable[[T], Awaitable[bool]],
    ):
        """Pull jobs until the queue is empty, reporting each result on status_queue.

        Exceptions raised by run_job are logged and reported as a failure (False).
        """
        while True:
            try:
                job = worker_queue.get_nowait()
            except asyncio.QueueEmpty:
                # worker can end when the queue is empty
                break

            try:
                try:
                    success = await run_job(job)
                except Exception:
                    logger.error("Job failed to complete", exc_info=True)
                    success = False

                try:
                    await status_queue.put(success)
                except Exception:
                    logger.error("Failed to enqueue status for job", exc_info=True)
            finally:
                # Always mark the dequeued job as done, even if cancelled mid-job
                worker_queue.task_done()
@@ -1,11 +1,12 @@
1
1
  import csv
2
2
  import logging
3
+ import random
3
4
  import time
4
5
  from dataclasses import dataclass
5
6
  from enum import Enum
6
7
  from typing import Dict, Protocol
7
8
 
8
- from pydantic import BaseModel, Field, ValidationError, field_validator
9
+ from pydantic import BaseModel, Field, ValidationError
9
10
 
10
11
  from kiln_ai.datamodel import DataSource, DataSourceType, Task, TaskOutput, TaskRun
11
12
 
@@ -20,14 +21,36 @@ class DatasetImportFormat(str, Enum):
20
21
  CSV = "csv"
21
22
 
22
23
 
24
@dataclass
class ImportConfig:
    """Configuration for importing a dataset"""

    dataset_type: DatasetImportFormat
    dataset_path: str
    dataset_name: str
    # A set of splits to assign to the import (as dataset tags).
    # The keys are the names of the splits (tag name), and the values are the
    # proportions of the dataset to include in each split (should sum to 1).
    tag_splits: Dict[str, float] | None = None

    def validate_tag_splits(self) -> None:
        """Check that ``tag_splits``, when set, is a usable set of proportions.

        ``None`` or an empty dict is accepted (no splits).

        Raises:
            ValueError: if any proportion is negative, or if the proportions do
                not sum to 1 within a small tolerance (a NaN proportion also
                fails this check).
        """
        if not self.tag_splits:
            return

        negative = {k: v for k, v in self.tag_splits.items() if v < 0}
        if negative:
            raise ValueError(
                "Splits must not be negative. The following splits are negative: "
                + ", ".join(f"{k}: {v}" for k, v in negative.items())
            )

        EPSILON = 0.001  # Allow for small floating point errors
        # Written as "not <=" so a NaN total (every NaN comparison is False)
        # is rejected instead of silently passing.
        if not abs(sum(self.tag_splits.values()) - 1) <= EPSILON:
            raise ValueError(
                "Splits must sum to 1. The following splits do not: "
                + ", ".join(f"{k}: {v}" for k, v in self.tag_splits.items())
            )
45
+
46
+
23
47
  class Importer(Protocol):
24
48
  """Protocol for dataset importers"""
25
49
 
26
50
  def __call__(
27
51
  self,
28
52
  task: Task,
29
- dataset_path: str,
30
- dataset_name: str,
53
+ config: ImportConfig,
31
54
  ) -> int: ...
32
55
 
33
56
 
@@ -90,6 +113,44 @@ def without_none_values(d: dict) -> dict:
90
113
  return {k: v for k, v in d.items() if v is not None}
91
114
 
92
115
 
116
def add_tag_splits(runs: list[TaskRun], tag_splits: Dict[str, float] | None) -> None:
    """Assign split tags to runs according to configured proportions.

    Args:
        runs: List of TaskRun objects to assign tags to; one tag is appended to
            each run's ``tags`` list in place.
        tag_splits: Dictionary mapping tag names to their desired proportions.
            ``None`` or an empty dict leaves the runs untouched.

    Per-split counts use the largest-remainder method: every split first gets
    the floor of its exact share (proportion * len(runs)), then the runs left
    over go one each to the splits whose share was truncated the most, so the
    counts match the proportions as closely as possible. Which individual run
    receives which tag is random.
    """
    if not tag_splits:
        return

    total_runs = len(runs)
    exact_shares = {
        tag: proportion * total_runs for tag, proportion in tag_splits.items()
    }
    split_counts = {tag: int(share) for tag, share in exact_shares.items()}

    remaining = total_runs - sum(split_counts.values())
    if remaining > 0:
        # Largest fractional remainder first; sorted() stays stable with
        # reverse=True, so ties keep the configured dict order.
        by_remainder = sorted(
            split_counts,
            key=lambda tag: exact_shares[tag] - split_counts[tag],
            reverse=True,
        )
        for tag in by_remainder[:remaining]:
            split_counts[tag] += 1
        remaining = total_runs - sum(split_counts.values())

    if remaining != 0:
        # Only left over when the proportions don't quite sum to 1 (validation
        # allows a small tolerance): absorb the difference in the largest split.
        largest_tag = max(split_counts, key=lambda tag: split_counts[tag])
        split_counts[largest_tag] += remaining

    # Create a list of tags with the correct counts, then shuffle to randomize
    tags_to_assign = [
        tag for tag, count in split_counts.items() for _ in range(count)
    ]
    random.shuffle(tags_to_assign)

    for run, tag in zip(runs, tags_to_assign):
        run.tags.append(tag)
152
+
153
+
93
154
  def create_task_run_from_csv_row(
94
155
  task: Task,
95
156
  row: dict[str, str],
@@ -143,12 +204,18 @@ def create_task_run_from_csv_row(
143
204
  return run
144
205
 
145
206
 
146
- def import_csv(task: Task, dataset_path: str, dataset_name: str) -> int:
207
+ def import_csv(
208
+ task: Task,
209
+ config: ImportConfig,
210
+ ) -> int:
147
211
  """Import a CSV dataset.
148
212
 
149
213
  All rows are validated before any are persisted to files to avoid partial imports."""
150
214
 
151
215
  session_id = str(int(time.time()))
216
+ dataset_path = config.dataset_path
217
+ dataset_name = config.dataset_name
218
+ tag_splits = config.tag_splits
152
219
 
153
220
  required_headers = {"input", "output"} # minimum required headers
154
221
  optional_headers = {"reasoning", "tags", "chain_of_thought"} # optional headers
@@ -197,6 +264,8 @@ def import_csv(task: Task, dataset_path: str, dataset_name: str) -> int:
197
264
  ) from e
198
265
  rows.append(run)
199
266
 
267
+ add_tag_splits(rows, tag_splits)
268
+
200
269
  # now that we know all rows are valid, we can save them
201
270
  for run in rows:
202
271
  run.save_to_file()
@@ -209,24 +278,17 @@ DATASET_IMPORTERS: Dict[DatasetImportFormat, Importer] = {
209
278
  }
210
279
 
211
280
 
212
- @dataclass
213
- class ImportConfig:
214
- """Configuration for importing a dataset"""
215
-
216
- dataset_type: DatasetImportFormat
217
- dataset_path: str
218
- dataset_name: str
219
-
220
-
221
281
  class DatasetFileImporter:
222
282
  """Import a dataset from a file"""
223
283
 
224
284
  def __init__(self, task: Task, config: ImportConfig):
225
285
  self.task = task
226
- self.dataset_type = config.dataset_type
227
- self.dataset_path = config.dataset_path
228
- self.dataset_name = config.dataset_name
286
+ config.validate_tag_splits()
287
+ self.config = config
229
288
 
230
289
  def create_runs_from_file(self) -> int:
231
- fn = DATASET_IMPORTERS[self.dataset_type]
232
- return fn(self.task, self.dataset_path, self.dataset_name)
290
+ fn = DATASET_IMPORTERS[self.config.dataset_type]
291
+ return fn(
292
+ self.task,
293
+ self.config,
294
+ )
@@ -0,0 +1,199 @@
1
+ from typing import List
2
+ from unittest.mock import AsyncMock, patch
3
+
4
+ import pytest
5
+
6
+ from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress
7
+
8
+
9
@pytest.mark.parametrize("concurrency", [0, -1, -25])
def test_invalid_concurrency_raises(concurrency):
    """A runner needs at least one worker; anything lower is rejected."""
    pytest.raises(ValueError, AsyncJobRunner, concurrency=concurrency)
13
+
14
+
15
# Test with and without concurrency
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_job_runner_status_updates(concurrency):
    """All jobs succeed: one in-order update per job, plus the initial one."""
    total_jobs = 50
    jobs = [{"id": job_id} for job_id in range(total_jobs)]
    always_succeeds = AsyncMock(return_value=True)
    runner = AsyncJobRunner(concurrency=concurrency)

    seen = 0
    async for progress in runner.run(jobs, always_succeeds):
        # successes arrive one at a time, so the count equals the update index
        assert progress.complete == seen
        assert progress.errors == 0
        assert progress.total == total_jobs
        seen += 1

    # initial status + one update per job
    assert seen == total_jobs + 1

    # each job handed to run_job exactly once
    assert always_succeeds.call_count == total_jobs
    for job in jobs:
        always_succeeds.assert_any_await(job)
44
+
45
+
46
# Test with and without concurrency
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_job_runner_status_updates_empty_job_list(concurrency):
    """With no jobs the runner emits only the initial all-zero status."""
    never_called = AsyncMock(return_value=True)
    runner = AsyncJobRunner(concurrency=concurrency)

    updates: List[Progress] = [p async for p in runner.run([], never_called)]

    assert updates == [Progress(complete=0, total=0, errors=0)]
    # no jobs means run_job is never invoked
    assert never_called.call_count == 0
70
+
71
+
72
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_job_runner_all_failures(concurrency):
    """Every job reports failure: only the error count advances, one per update."""
    total_jobs = 50
    jobs = [{"id": job_id} for job_id in range(total_jobs)]
    always_fails = AsyncMock(return_value=False)
    runner = AsyncJobRunner(concurrency=concurrency)

    updates = [p async for p in runner.run(jobs, always_fails)]

    # initial status + one update per job
    assert len(updates) == total_jobs + 1
    for expected_errors, progress in enumerate(updates):
        assert progress.complete == 0
        assert progress.errors == expected_errors
        assert progress.total == total_jobs

    # each job handed to run_job exactly once
    assert always_fails.call_count == total_jobs
    for job in jobs:
        always_fails.assert_any_await(job)
100
+
101
+
102
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_job_runner_partial_failures(concurrency):
    """Mixed results: the final update splits jobs into successes and failures."""
    total_jobs = 50
    jobs = [{"id": job_id} for job_id in range(total_jobs)]
    failing_ids = {0, 2, 4, 6, 8, 20, 25}

    # reports success unless the job id is listed in failing_ids
    sometimes_fails = AsyncMock(side_effect=lambda job: job["id"] not in failing_ids)
    runner = AsyncJobRunner(concurrency=concurrency)

    last = None
    async for last in runner.run(jobs, sometimes_fails):
        assert last.total == total_jobs

    # the final update accounts for every job
    assert last.errors == len(failing_ids)
    assert last.complete == total_jobs - len(failing_ids)

    # each job handed to run_job exactly once
    assert sometimes_fails.call_count == total_jobs
    for job in jobs:
        sometimes_fails.assert_any_await(job)
135
+
136
+
137
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_job_runner_partial_raises(concurrency):
    """Jobs that raise are counted as errors and do not stop the run."""
    total_jobs = 50
    jobs = [{"id": job_id} for job_id in range(total_jobs)]
    raising_ids = {10, 25}

    def succeed_or_raise(job):
        if job["id"] not in raising_ids:
            return True
        raise Exception("job failed unexpectedly")

    runner = AsyncJobRunner(concurrency=concurrency)
    flaky_job = AsyncMock(side_effect=succeed_or_raise)

    updates: List[Progress] = [p async for p in runner.run(jobs, flaky_job)]

    assert all(u.total == total_jobs for u in updates)

    # everything except the raising jobs completed
    final = updates[-1]
    assert final.complete == total_jobs - len(raising_ids)

    # every intermediate count from 0 up to the final value shows up
    assert {u.complete for u in updates} == set(
        range(total_jobs - len(raising_ids) + 1)
    )
    assert {u.errors for u in updates} == set(range(len(raising_ids) + 1))

    # one update per job, plus the initial status update
    assert len(updates) == total_jobs + 1
183
+
184
+
185
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_job_runner_cancelled(concurrency):
    """An exception escaping a worker task surfaces to the consumer of run()."""
    jobs = [{"id": job_id} for job_id in range(10)]
    runner = AsyncJobRunner(concurrency=concurrency)
    worker_error = Exception("run_worker raised an exception")

    with patch.object(runner, "_run_worker", side_effect=worker_error), pytest.raises(
        Exception, match="run_worker raised an exception"
    ):
        async for _ in runner.run(jobs, AsyncMock(return_value=True)):
            pass