kiln-ai 0.15.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. kiln_ai/adapters/eval/eval_runner.py +5 -64
  2. kiln_ai/adapters/eval/g_eval.py +3 -3
  3. kiln_ai/adapters/fine_tune/dataset_formatter.py +124 -34
  4. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +264 -7
  5. kiln_ai/adapters/ml_model_list.py +478 -4
  6. kiln_ai/adapters/model_adapters/base_adapter.py +26 -8
  7. kiln_ai/adapters/model_adapters/litellm_adapter.py +41 -7
  8. kiln_ai/adapters/model_adapters/test_base_adapter.py +74 -2
  9. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +65 -1
  10. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +3 -2
  11. kiln_ai/adapters/model_adapters/test_structured_output.py +4 -6
  12. kiln_ai/adapters/parsers/base_parser.py +0 -3
  13. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  14. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  15. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  16. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  17. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  18. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  19. kiln_ai/adapters/prompt_builders.py +14 -1
  20. kiln_ai/adapters/provider_tools.py +18 -1
  21. kiln_ai/adapters/repair/test_repair_task.py +3 -2
  22. kiln_ai/adapters/test_prompt_builders.py +24 -3
  23. kiln_ai/adapters/test_provider_tools.py +70 -1
  24. kiln_ai/datamodel/__init__.py +2 -0
  25. kiln_ai/datamodel/datamodel_enums.py +14 -0
  26. kiln_ai/datamodel/dataset_filters.py +69 -1
  27. kiln_ai/datamodel/dataset_split.py +4 -0
  28. kiln_ai/datamodel/eval.py +8 -0
  29. kiln_ai/datamodel/finetune.py +1 -0
  30. kiln_ai/datamodel/prompt_id.py +1 -0
  31. kiln_ai/datamodel/task_output.py +1 -1
  32. kiln_ai/datamodel/task_run.py +39 -7
  33. kiln_ai/datamodel/test_basemodel.py +3 -7
  34. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  35. kiln_ai/datamodel/test_dataset_split.py +2 -0
  36. kiln_ai/datamodel/test_example_models.py +54 -0
  37. kiln_ai/datamodel/test_models.py +50 -2
  38. kiln_ai/utils/async_job_runner.py +106 -0
  39. kiln_ai/utils/dataset_import.py +80 -18
  40. kiln_ai/utils/test_async_job_runner.py +199 -0
  41. kiln_ai/utils/test_dataset_import.py +242 -10
  42. {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/METADATA +1 -1
  43. {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/RECORD +45 -41
  44. {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/WHEEL +0 -0
  45. {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -547,20 +547,34 @@ def test_prompt_parent_task():
             False,
             None,
         ),
-        # Test 3: Invalid case - thinking instructions with final_only
+        # Test 3: Valid case - no thinking instructions with final_and_intermediate_r1_compatible
+        (
+            None,
+            FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+            False,
+            None,
+        ),
+        # Test 4: Invalid case - thinking instructions with final_only
         (
             "Think step by step",
             FinetuneDataStrategy.final_only,
             True,
             "Thinking instructions can only be used when data_strategy is final_and_intermediate",
         ),
-        # Test 4: Invalid case - no thinking instructions with final_and_intermediate
+        # Test 5: Invalid case - no thinking instructions with final_and_intermediate
         (
             None,
             FinetuneDataStrategy.final_and_intermediate,
             True,
             "Thinking instructions are required when data_strategy is final_and_intermediate",
         ),
+        # Test 6: Invalid case - thinking instructions with final_and_intermediate_r1_compatible
+        (
+            "Think step by step",
+            FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+            True,
+            "Thinking instructions can only be used when data_strategy is final_and_intermediate",
+        ),
     ],
 )
 def test_finetune_thinking_instructions_validation(
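The new cases extend the parametrized matrix to the `final_and_intermediate_r1_compatible` strategy: thinking instructions remain forbidden for it (as for `final_only`) and are still required only for plain `final_and_intermediate`. Below is a hypothetical standalone helper that mirrors the rule these cases assert; the real check lives in the Finetune datamodel's validators, and the enum import path is assumed from `kiln_ai/datamodel/datamodel_enums.py`.

```python
from kiln_ai.datamodel.datamodel_enums import FinetuneDataStrategy  # assumed import path


def check_thinking_instructions(
    thinking_instructions: str | None,
    data_strategy: FinetuneDataStrategy,
) -> None:
    # Hypothetical helper, not the library's validator: instructions are required
    # only for final_and_intermediate, and forbidden for final_only and r1_compatible.
    if data_strategy == FinetuneDataStrategy.final_and_intermediate:
        if thinking_instructions is None:
            raise ValueError(
                "Thinking instructions are required when data_strategy is final_and_intermediate"
            )
    elif thinking_instructions is not None:
        raise ValueError(
            "Thinking instructions can only be used when data_strategy is final_and_intermediate"
        )
```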
@@ -617,3 +631,37 @@ def test_task_run_has_thinking_training_data(intermediate_outputs, expected):
         intermediate_outputs=intermediate_outputs,
     )
     assert task_run.has_thinking_training_data() == expected
+
+
+@pytest.mark.parametrize(
+    "intermediate_outputs,expected",
+    [
+        # No intermediate outputs
+        (None, None),
+        # Empty intermediate outputs
+        ({}, None),
+        # Only chain_of_thought
+        ({"chain_of_thought": "thinking process"}, "thinking process"),
+        # Only reasoning
+        ({"reasoning": "reasoning process"}, "reasoning process"),
+        # Both chain_of_thought and reasoning (should return reasoning as it's checked first)
+        (
+            {"chain_of_thought": "thinking process", "reasoning": "reasoning process"},
+            "reasoning process",
+        ),
+        # Other intermediate outputs but no thinking data
+        ({"other_output": "some data"}, None),
+        # Mixed other outputs with thinking data
+        (
+            {"chain_of_thought": "thinking process", "other_output": "some data"},
+            "thinking process",
+        ),
+    ],
+)
+def test_task_run_thinking_training_data(intermediate_outputs, expected):
+    task_run = TaskRun(
+        input="test input",
+        output=TaskOutput(output="test output"),
+        intermediate_outputs=intermediate_outputs,
+    )
+    assert task_run.thinking_training_data() == expected
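The new `thinking_training_data()` accessor reads `intermediate_outputs` and prefers `reasoning` over `chain_of_thought` when both are present. A minimal usage sketch, using only the constructor arguments the test above already uses:

```python
from kiln_ai.datamodel import TaskOutput, TaskRun

run = TaskRun(
    input="test input",
    output=TaskOutput(output="test output"),
    intermediate_outputs={
        "chain_of_thought": "thinking process",
        "reasoning": "reasoning process",
    },
)

# "reasoning" wins when both keys are present, per the parametrized cases above
assert run.thinking_training_data() == "reasoning process"
assert run.has_thinking_training_data() is True
```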
@@ -0,0 +1,106 @@
+import asyncio
+import logging
+from dataclasses import dataclass
+from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+@dataclass
+class Progress:
+    complete: int
+    total: int
+    errors: int
+
+
+class AsyncJobRunner:
+    def __init__(self, concurrency: int = 1):
+        if concurrency < 1:
+            raise ValueError("concurrency must be ≥ 1")
+        self.concurrency = concurrency
+
+    async def run(
+        self,
+        jobs: List[T],
+        run_job: Callable[[T], Awaitable[bool]],
+    ) -> AsyncGenerator[Progress, None]:
+        """
+        Runs the jobs with parallel workers and yields progress updates.
+        """
+        complete = 0
+        errors = 0
+        total = len(jobs)
+
+        # Send initial status
+        yield Progress(complete=complete, total=total, errors=errors)
+
+        worker_queue: asyncio.Queue[T] = asyncio.Queue()
+        for job in jobs:
+            worker_queue.put_nowait(job)
+
+        # simple status queue to return progress. True=success, False=error
+        status_queue: asyncio.Queue[bool] = asyncio.Queue()
+
+        workers = []
+        for _ in range(self.concurrency):
+            task = asyncio.create_task(
+                self._run_worker(worker_queue, status_queue, run_job),
+            )
+            workers.append(task)
+
+        try:
+            # Send status updates until workers are done, and they are all sent
+            while not status_queue.empty() or not all(
+                worker.done() for worker in workers
+            ):
+                try:
+                    # Use timeout to prevent hanging if all workers complete
+                    # between our while condition check and get()
+                    success = await asyncio.wait_for(status_queue.get(), timeout=0.1)
+                    if success:
+                        complete += 1
+                    else:
+                        errors += 1
+
+                    yield Progress(complete=complete, total=total, errors=errors)
+                except asyncio.TimeoutError:
+                    # Timeout is expected, just continue to recheck worker status
+                    # Don't love this but beats sentinels for reliability
+                    continue
+        finally:
+            # Cancel outstanding workers on early exit or error
+            for w in workers:
+                w.cancel()
+
+        # These are redundant, but keeping them will catch async errors
+        await asyncio.gather(*workers)
+        await worker_queue.join()
+
+    async def _run_worker(
+        self,
+        worker_queue: asyncio.Queue[T],
+        status_queue: asyncio.Queue[bool],
+        run_job: Callable[[T], Awaitable[bool]],
+    ):
+        while True:
+            try:
+                job = worker_queue.get_nowait()
+            except asyncio.QueueEmpty:
+                # worker can end when the queue is empty
+                break
+
+            try:
+                success = await run_job(job)
+            except Exception:
+                logger.error("Job failed to complete", exc_info=True)
+                success = False
+
+            try:
+                await status_queue.put(success)
+            except Exception:
+                logger.error("Failed to enqueue status for job", exc_info=True)
+            finally:
+                # Always mark the dequeued task as done, even on exceptions
+                worker_queue.task_done()
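`AsyncJobRunner.run` is an async generator: it yields an initial `Progress(complete=0, total=len(jobs), errors=0)` and then one update per finished job, while `concurrency` workers drain a shared queue. A minimal consumption sketch; the `do_work` coroutine and the job payloads are made up for illustration:

```python
import asyncio

from kiln_ai.utils.async_job_runner import AsyncJobRunner


async def do_work(job: dict) -> bool:
    # Hypothetical job: return True on success, False (or raise) on failure
    await asyncio.sleep(0.01)
    return job["id"] % 2 == 0


async def main() -> None:
    jobs = [{"id": i} for i in range(10)]
    runner = AsyncJobRunner(concurrency=4)
    async for progress in runner.run(jobs, do_work):
        print(f"{progress.complete}/{progress.total} done, {progress.errors} errors")


asyncio.run(main())
```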
@@ -1,11 +1,12 @@
 import csv
 import logging
+import random
 import time
 from dataclasses import dataclass
 from enum import Enum
 from typing import Dict, Protocol
 
-from pydantic import BaseModel, Field, ValidationError, field_validator
+from pydantic import BaseModel, Field, ValidationError
 
 from kiln_ai.datamodel import DataSource, DataSourceType, Task, TaskOutput, TaskRun
 
@@ -20,14 +21,36 @@ class DatasetImportFormat(str, Enum):
     CSV = "csv"
 
 
+@dataclass
+class ImportConfig:
+    """Configuration for importing a dataset"""
+
+    dataset_type: DatasetImportFormat
+    dataset_path: str
+    dataset_name: str
+    """
+    A set of splits to assign to the import (as dataset tags).
+    The keys are the names of the splits (tag name), and the values are the proportions of the dataset to include in each split (should sum to 1).
+    """
+    tag_splits: Dict[str, float] | None = None
+
+    def validate_tag_splits(self) -> None:
+        if self.tag_splits:
+            EPSILON = 0.001  # Allow for small floating point errors
+            if abs(sum(self.tag_splits.values()) - 1) > EPSILON:
+                raise ValueError(
+                    "Splits must sum to 1. The following splits do not: "
+                    + ", ".join(f"{k}: {v}" for k, v in self.tag_splits.items())
+                )
+
+
 class Importer(Protocol):
     """Protocol for dataset importers"""
 
     def __call__(
         self,
         task: Task,
-        dataset_path: str,
-        dataset_name: str,
+        config: ImportConfig,
     ) -> int: ...
 
 
@@ -90,6 +113,44 @@ def without_none_values(d: dict) -> dict:
     return {k: v for k, v in d.items() if v is not None}
 
 
+def add_tag_splits(runs: list[TaskRun], tag_splits: Dict[str, float] | None) -> None:
+    """Assign split tags to runs according to configured proportions.
+
+    Args:
+        runs: List of TaskRun objects to assign tags to
+        tag_splits: Dictionary mapping tag names to their desired proportions
+
+    The assignment is random but ensures the proportions match the configured splits
+    as closely as possible given the number of runs.
+    """
+    if not tag_splits:
+        return
+
+    # Calculate exact number of runs for each split
+    total_runs = len(runs)
+    split_counts = {
+        tag: int(proportion * total_runs) for tag, proportion in tag_splits.items()
+    }
+
+    # Handle rounding errors by adjusting the largest split
+    remaining = total_runs - sum(split_counts.values())
+    if remaining != 0:
+        largest_split = max(split_counts.items(), key=lambda x: x[1])
+        split_counts[largest_split[0]] += remaining
+
+    # Create a list of tags with the correct counts
+    tags_to_assign = []
+    for tag, count in split_counts.items():
+        tags_to_assign.extend([tag] * count)
+
+    # Shuffle the tags to randomize assignment
+    random.shuffle(tags_to_assign)
+
+    # Assign tags to runs
+    for run, tag in zip(runs, tags_to_assign):
+        run.tags.append(tag)
+
+
 def create_task_run_from_csv_row(
     task: Task,
@@ -143,12 +204,18 @@ def create_task_run_from_csv_row(
     return run
 
 
-def import_csv(task: Task, dataset_path: str, dataset_name: str) -> int:
+def import_csv(
+    task: Task,
+    config: ImportConfig,
+) -> int:
     """Import a CSV dataset.
 
     All rows are validated before any are persisted to files to avoid partial imports."""
 
     session_id = str(int(time.time()))
+    dataset_path = config.dataset_path
+    dataset_name = config.dataset_name
+    tag_splits = config.tag_splits
 
     required_headers = {"input", "output"}  # minimum required headers
     optional_headers = {"reasoning", "tags", "chain_of_thought"}  # optional headers
@@ -197,6 +264,8 @@ def import_csv(task: Task, dataset_path: str, dataset_name: str) -> int:
             ) from e
         rows.append(run)
 
+    add_tag_splits(rows, tag_splits)
+
     # now that we know all rows are valid, we can save them
     for run in rows:
         run.save_to_file()
@@ -209,24 +278,17 @@ DATASET_IMPORTERS: Dict[DatasetImportFormat, Importer] = {
 }
 
 
-@dataclass
-class ImportConfig:
-    """Configuration for importing a dataset"""
-
-    dataset_type: DatasetImportFormat
-    dataset_path: str
-    dataset_name: str
-
-
 class DatasetFileImporter:
     """Import a dataset from a file"""
 
     def __init__(self, task: Task, config: ImportConfig):
         self.task = task
-        self.dataset_type = config.dataset_type
-        self.dataset_path = config.dataset_path
-        self.dataset_name = config.dataset_name
+        config.validate_tag_splits()
+        self.config = config
 
     def create_runs_from_file(self) -> int:
-        fn = DATASET_IMPORTERS[self.dataset_type]
-        return fn(self.task, self.dataset_path, self.dataset_name)
+        fn = DATASET_IMPORTERS[self.config.dataset_type]
+        return fn(
+            self.task,
+            self.config,
+        )
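`DatasetFileImporter` now validates the config up front and hands the whole `ImportConfig` to the registered importer, so callers build a config rather than passing path and name separately. A hedged end-to-end sketch; the task, CSV path, and names are placeholders (a real task would be loaded from an existing Kiln project, and the Task constructor arguments shown are assumptions):

```python
from kiln_ai.datamodel import Task
from kiln_ai.utils.dataset_import import (
    DatasetFileImporter,
    DatasetImportFormat,
    ImportConfig,
)

# Placeholder task for illustration only.
task = Task(name="example task", instruction="Answer the question.")

importer = DatasetFileImporter(
    task,
    ImportConfig(
        dataset_type=DatasetImportFormat.CSV,
        dataset_path="/tmp/my_dataset.csv",  # placeholder path
        dataset_name="my_dataset",
        tag_splits={"train": 0.9, "test": 0.1},
    ),
)
imported_count = importer.create_runs_from_file()
```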
@@ -0,0 +1,199 @@
+from typing import List
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress
+
+
+@pytest.mark.parametrize("concurrency", [0, -1, -25])
+def test_invalid_concurrency_raises(concurrency):
+    with pytest.raises(ValueError):
+        AsyncJobRunner(concurrency=concurrency)
+
+
+# Test with and without concurrency
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_status_updates(concurrency):
+    job_count = 50
+    jobs = [{"id": i} for i in range(job_count)]
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    # fake run_job that succeeds
+    mock_run_job_success = AsyncMock(return_value=True)
+
+    # Expect the status updates in order, and 1 for each job
+    expected_completed_count = 0
+    async for progress in runner.run(jobs, mock_run_job_success):
+        assert progress.complete == expected_completed_count
+        expected_completed_count += 1
+        assert progress.errors == 0
+        assert progress.total == job_count
+
+    # Verify last status update was complete
+    assert expected_completed_count == job_count + 1
+
+    # Verify run_job was called for each job
+    assert mock_run_job_success.call_count == job_count
+
+    # Verify run_job was called with the correct arguments
+    for i in range(job_count):
+        mock_run_job_success.assert_any_await(jobs[i])
+
+
+# Test with and without concurrency
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_status_updates_empty_job_list(concurrency):
+    empty_job_list = []
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    # fake run_job that succeeds
+    mock_run_job_success = AsyncMock(return_value=True)
+
+    updates: List[Progress] = []
+    async for progress in runner.run(empty_job_list, mock_run_job_success):
+        updates.append(progress)
+
+    # Verify last status update was complete
+    assert len(updates) == 1
+
+    assert updates[0].complete == 0
+    assert updates[0].errors == 0
+    assert updates[0].total == 0
+
+    # Verify run_job was called for each job
+    assert mock_run_job_success.call_count == 0
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_all_failures(concurrency):
+    job_count = 50
+    jobs = [{"id": i} for i in range(job_count)]
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    # fake run_job that fails
+    mock_run_job_failure = AsyncMock(return_value=False)
+
+    # Expect the status updates in order, and 1 for each job
+    expected_error_count = 0
+    async for progress in runner.run(jobs, mock_run_job_failure):
+        assert progress.complete == 0
+        assert progress.errors == expected_error_count
+        expected_error_count += 1
+        assert progress.total == job_count
+
+    # Verify last status update was complete
+    assert expected_error_count == job_count + 1
+
+    # Verify run_job was called for each job
+    assert mock_run_job_failure.call_count == job_count
+
+    # Verify run_job was called with the correct arguments
+    for i in range(job_count):
+        mock_run_job_failure.assert_any_await(jobs[i])
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_partial_failures(concurrency):
+    job_count = 50
+    jobs = [{"id": i} for i in range(job_count)]
+
+    # we want to fail on some jobs and succeed on others
+    jobs_to_fail = set([0, 2, 4, 6, 8, 20, 25])
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    # fake run_job that fails
+    mock_run_job_partial_success = AsyncMock(
+        # return True for jobs that should succeed
+        side_effect=lambda job: job["id"] not in jobs_to_fail
+    )
+
+    # Expect the status updates in order, and 1 for each job
+    async for progress in runner.run(jobs, mock_run_job_partial_success):
+        assert progress.total == job_count
+
+    # Verify last status update was complete
+    expected_error_count = len(jobs_to_fail)
+    expected_success_count = len(jobs) - expected_error_count
+    assert progress.errors == expected_error_count
+    assert progress.complete == expected_success_count
+
+    # Verify run_job was called for each job
+    assert mock_run_job_partial_success.call_count == job_count
+
+    # Verify run_job was called with the correct arguments
+    for i in range(job_count):
+        mock_run_job_partial_success.assert_any_await(jobs[i])
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_partial_raises(concurrency):
+    job_count = 50
+    jobs = [{"id": i} for i in range(job_count)]
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    ids_to_fail = set([10, 25])
+
+    def failure_fn(job):
+        if job["id"] in ids_to_fail:
+            raise Exception("job failed unexpectedly")
+        return True
+
+    # fake run_job that fails
+    mock_run_job_partial_success = AsyncMock(side_effect=failure_fn)
+
+    # generate all the values we expect to see in progress updates
+    complete_values_expected = set([i for i in range(job_count - len(ids_to_fail) + 1)])
+    errors_values_expected = set([i for i in range(len(ids_to_fail) + 1)])
+
+    # keep track of all the updates we see
+    updates: List[Progress] = []
+
+    # we keep track of the progress values we have actually seen
+    complete_values_actual = set()
+    errors_values_actual = set()
+
+    # Expect the status updates in order, and 1 for each job
+    async for progress in runner.run(jobs, mock_run_job_partial_success):
+        updates.append(progress)
+        complete_values_actual.add(progress.complete)
+        errors_values_actual.add(progress.errors)
+
+        assert progress.total == job_count
+
+    # complete values should be all the jobs, except for the ones that failed
+    assert progress.complete == job_count - len(ids_to_fail)
+
+    # check that the actual updates and expected updates are equivalent sets
+    assert complete_values_actual == complete_values_expected
+    assert errors_values_actual == errors_values_expected
+
+    # we should have seen one update for each job, plus one for the initial status update
+    assert len(updates) == job_count + 1
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_cancelled(concurrency):
+    runner = AsyncJobRunner(concurrency=concurrency)
+    jobs = [{"id": i} for i in range(10)]
+
+    with patch.object(
+        runner,
+        "_run_worker",
+        side_effect=Exception("run_worker raised an exception"),
+    ):
+        # if an exception is raised in the task, we should see it bubble up
+        with pytest.raises(Exception, match="run_worker raised an exception"):
+            async for _ in runner.run(jobs, AsyncMock(return_value=True)):
+                pass