kiln-ai 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/eval/base_eval.py +7 -2
- kiln_ai/adapters/eval/eval_runner.py +5 -64
- kiln_ai/adapters/eval/g_eval.py +3 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +6 -3
- kiln_ai/adapters/fine_tune/dataset_formatter.py +128 -38
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +2 -1
- kiln_ai/adapters/fine_tune/test_base_finetune.py +7 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +267 -10
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +586 -0
- kiln_ai/adapters/fine_tune/vertex_finetune.py +217 -0
- kiln_ai/adapters/ml_model_list.py +817 -62
- kiln_ai/adapters/model_adapters/base_adapter.py +33 -10
- kiln_ai/adapters/model_adapters/litellm_adapter.py +51 -12
- kiln_ai/adapters/model_adapters/test_base_adapter.py +74 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +65 -1
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +3 -2
- kiln_ai/adapters/model_adapters/test_structured_output.py +4 -6
- kiln_ai/adapters/parsers/base_parser.py +0 -3
- kiln_ai/adapters/parsers/parser_registry.py +5 -3
- kiln_ai/adapters/parsers/r1_parser.py +17 -2
- kiln_ai/adapters/parsers/request_formatters.py +40 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
- kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
- kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
- kiln_ai/adapters/prompt_builders.py +14 -1
- kiln_ai/adapters/provider_tools.py +25 -1
- kiln_ai/adapters/repair/test_repair_task.py +3 -2
- kiln_ai/adapters/test_prompt_builders.py +24 -3
- kiln_ai/adapters/test_provider_tools.py +86 -1
- kiln_ai/datamodel/__init__.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +14 -0
- kiln_ai/datamodel/dataset_filters.py +69 -1
- kiln_ai/datamodel/dataset_split.py +4 -0
- kiln_ai/datamodel/eval.py +8 -0
- kiln_ai/datamodel/finetune.py +1 -0
- kiln_ai/datamodel/json_schema.py +24 -7
- kiln_ai/datamodel/prompt_id.py +1 -0
- kiln_ai/datamodel/task_output.py +10 -6
- kiln_ai/datamodel/task_run.py +68 -12
- kiln_ai/datamodel/test_basemodel.py +3 -7
- kiln_ai/datamodel/test_dataset_filters.py +82 -0
- kiln_ai/datamodel/test_dataset_split.py +2 -0
- kiln_ai/datamodel/test_example_models.py +158 -3
- kiln_ai/datamodel/test_json_schema.py +22 -3
- kiln_ai/datamodel/test_model_perf.py +3 -2
- kiln_ai/datamodel/test_models.py +50 -2
- kiln_ai/utils/async_job_runner.py +106 -0
- kiln_ai/utils/dataset_import.py +80 -18
- kiln_ai/utils/test_async_job_runner.py +199 -0
- kiln_ai/utils/test_dataset_import.py +242 -10
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/METADATA +3 -2
- kiln_ai-0.16.0.dist-info/RECORD +108 -0
- kiln_ai/adapters/test_generate_docs.py +0 -69
- kiln_ai-0.14.0.dist-info/RECORD +0 -103
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger(__name__)

T = TypeVar("T")


@dataclass
class Progress:
    """A progress snapshot yielded by ``AsyncJobRunner.run``."""

    # Jobs whose run_job call returned a truthy value.
    complete: int
    # Number of jobs passed to run().
    total: int
    # Jobs whose run_job call returned a falsy value or raised an Exception.
    errors: int


class AsyncJobRunner:
    """Run async jobs on a fixed pool of worker tasks and stream progress updates."""

    def __init__(self, concurrency: int = 1):
        """
        Args:
            concurrency: number of worker tasks used to process jobs in parallel.

        Raises:
            ValueError: if concurrency is less than 1.
        """
        if concurrency < 1:
            raise ValueError("concurrency must be ≥ 1")
        self.concurrency = concurrency

    async def run(
        self,
        jobs: List[T],
        run_job: Callable[[T], Awaitable[bool]],
    ) -> AsyncGenerator[Progress, None]:
        """
        Runs the jobs with parallel workers and yields progress updates.

        An initial all-zero Progress is yielded first, then one Progress per
        finished job. A truthy return from run_job counts as complete; a falsy
        return or a raised Exception counts as an error.

        If the consumer stops early (e.g. breaks and the generator is closed),
        still-running workers are cancelled and awaited without leaking their
        CancelledError. An unexpected exception escaping a worker task (not
        from run_job, which is caught) is re-raised to the consumer.
        """
        complete = 0
        errors = 0
        total = len(jobs)

        # Send initial status
        yield Progress(complete=complete, total=total, errors=errors)

        worker_queue: asyncio.Queue[T] = asyncio.Queue()
        for job in jobs:
            worker_queue.put_nowait(job)

        # simple status queue to return progress. True=success, False=error
        status_queue: asyncio.Queue[bool] = asyncio.Queue()

        workers = [
            asyncio.create_task(self._run_worker(worker_queue, status_queue, run_job))
            for _ in range(self.concurrency)
        ]

        try:
            # Send status updates until workers are done, and they are all sent
            while not status_queue.empty() or not all(w.done() for w in workers):
                try:
                    # Use timeout to prevent hanging if all workers complete
                    # between our while condition check and get()
                    success = await asyncio.wait_for(status_queue.get(), timeout=0.1)
                except asyncio.TimeoutError:
                    # Expected: loop around and re-check worker status.
                    continue

                if success:
                    complete += 1
                else:
                    errors += 1

                # Kept outside the try above so an exception thrown into the
                # generator at this point is never mistaken for our poll timeout.
                yield Progress(complete=complete, total=total, errors=errors)
        finally:
            # Cancel outstanding workers on early exit or error (no-op for
            # workers that already finished).
            for w in workers:
                w.cancel()

            # return_exceptions=True: a worker we just cancelled must not turn
            # into a CancelledError raised out of the consumer's aclose().
            outcomes = await asyncio.gather(*workers, return_exceptions=True)
            for outcome in outcomes:
                if isinstance(outcome, BaseException) and not isinstance(
                    outcome, asyncio.CancelledError
                ):
                    # A genuine worker crash: surface it to the consumer.
                    raise outcome
            # No worker_queue.join() here: after an early exit, jobs that were
            # never dequeued would make join() wait forever.

    async def _run_worker(
        self,
        worker_queue: asyncio.Queue[T],
        status_queue: asyncio.Queue[bool],
        run_job: Callable[[T], Awaitable[bool]],
    ):
        """Pull jobs until the queue is empty, reporting each outcome on status_queue."""
        while True:
            try:
                job = worker_queue.get_nowait()
            except asyncio.QueueEmpty:
                # worker can end when the queue is empty
                break

            try:
                try:
                    success = await run_job(job)
                except Exception:
                    logger.error("Job failed to complete", exc_info=True)
                    success = False

                try:
                    await status_queue.put(success)
                except Exception:
                    logger.error("Failed to enqueue status for job", exc_info=True)
            finally:
                # Always mark the dequeued job as done, including when the
                # worker is cancelled while awaiting run_job.
                worker_queue.task_done()
|
kiln_ai/utils/dataset_import.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import csv
|
|
2
2
|
import logging
|
|
3
|
+
import random
|
|
3
4
|
import time
|
|
4
5
|
from dataclasses import dataclass
|
|
5
6
|
from enum import Enum
|
|
6
7
|
from typing import Dict, Protocol
|
|
7
8
|
|
|
8
|
-
from pydantic import BaseModel, Field, ValidationError
|
|
9
|
+
from pydantic import BaseModel, Field, ValidationError
|
|
9
10
|
|
|
10
11
|
from kiln_ai.datamodel import DataSource, DataSourceType, Task, TaskOutput, TaskRun
|
|
11
12
|
|
|
@@ -20,14 +21,36 @@ class DatasetImportFormat(str, Enum):
|
|
|
20
21
|
CSV = "csv"
|
|
21
22
|
|
|
22
23
|
|
|
24
|
+
@dataclass
class ImportConfig:
    """Configuration for importing a dataset"""

    dataset_type: DatasetImportFormat
    dataset_path: str
    dataset_name: str
    # Optional splits to assign to the imported runs (as dataset tags). Keys are
    # the split/tag names; values are the proportion of the dataset to put in
    # each split, and should sum to 1.
    tag_splits: Dict[str, float] | None = None

    def validate_tag_splits(self) -> None:
        """Raise ValueError if ``tag_splits`` is set but is not a valid set of proportions.

        Valid means the values sum to 1 (within a small tolerance for
        floating point error) and every value is a non-negative number.
        ``None`` or an empty mapping is always accepted.

        Raises:
            ValueError: if the proportions do not sum to 1, or any proportion
                is negative or NaN.
        """
        if not self.tag_splits:
            return

        EPSILON = 0.001  # Allow for small floating point errors
        if abs(sum(self.tag_splits.values()) - 1) > EPSILON:
            raise ValueError(
                "Splits must sum to 1. The following splits do not: "
                + ", ".join(f"{k}: {v}" for k, v in self.tag_splits.items())
            )

        # A sum of 1 can still hide negative entries (e.g. 1.5 and -0.5), and a
        # NaN slips past the sum check because every comparison with NaN is
        # False. `not 0 <= v` rejects both cases.
        invalid = {k: v for k, v in self.tag_splits.items() if not 0 <= v}
        if invalid:
            raise ValueError(
                "Split proportions must be non-negative numbers. Invalid splits: "
                + ", ".join(f"{k}: {v}" for k, v in invalid.items())
            )
|
|
45
|
+
|
|
46
|
+
|
|
23
47
|
class Importer(Protocol):
    """Structural type for a dataset import function.

    Any callable accepting a task and an ImportConfig and returning an int
    satisfies this protocol.
    """

    def __call__(self, task: Task, config: ImportConfig) -> int:
        """Import the dataset described by ``config`` into ``task``.

        Returns an int (presumably the number of runs imported — confirm
        against the registered implementations).
        """
|
|
32
55
|
|
|
33
56
|
|
|
@@ -90,6 +113,44 @@ def without_none_values(d: dict) -> dict:
|
|
|
90
113
|
return {k: v for k, v in d.items() if v is not None}
|
|
91
114
|
|
|
92
115
|
|
|
116
|
+
def add_tag_splits(runs: list[TaskRun], tag_splits: Dict[str, float] | None) -> None:
|
|
117
|
+
"""Assign split tags to runs according to configured proportions.
|
|
118
|
+
|
|
119
|
+
Args:
|
|
120
|
+
runs: List of TaskRun objects to assign tags to
|
|
121
|
+
tag_splits: Dictionary mapping tag names to their desired proportions
|
|
122
|
+
|
|
123
|
+
The assignment is random but ensures the proportions match the configured splits
|
|
124
|
+
as closely as possible given the number of runs.
|
|
125
|
+
"""
|
|
126
|
+
if not tag_splits:
|
|
127
|
+
return
|
|
128
|
+
|
|
129
|
+
# Calculate exact number of runs for each split
|
|
130
|
+
total_runs = len(runs)
|
|
131
|
+
split_counts = {
|
|
132
|
+
tag: int(proportion * total_runs) for tag, proportion in tag_splits.items()
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
# Handle rounding errors by adjusting the largest split
|
|
136
|
+
remaining = total_runs - sum(split_counts.values())
|
|
137
|
+
if remaining != 0:
|
|
138
|
+
largest_split = max(split_counts.items(), key=lambda x: x[1])
|
|
139
|
+
split_counts[largest_split[0]] += remaining
|
|
140
|
+
|
|
141
|
+
# Create a list of tags with the correct counts
|
|
142
|
+
tags_to_assign = []
|
|
143
|
+
for tag, count in split_counts.items():
|
|
144
|
+
tags_to_assign.extend([tag] * count)
|
|
145
|
+
|
|
146
|
+
# Shuffle the tags to randomize assignment
|
|
147
|
+
random.shuffle(tags_to_assign)
|
|
148
|
+
|
|
149
|
+
# Assign tags to runs
|
|
150
|
+
for run, tag in zip(runs, tags_to_assign):
|
|
151
|
+
run.tags.append(tag)
|
|
152
|
+
|
|
153
|
+
|
|
93
154
|
def create_task_run_from_csv_row(
|
|
94
155
|
task: Task,
|
|
95
156
|
row: dict[str, str],
|
|
@@ -143,12 +204,18 @@ def create_task_run_from_csv_row(
|
|
|
143
204
|
return run
|
|
144
205
|
|
|
145
206
|
|
|
146
|
-
def import_csv(
|
|
207
|
+
def import_csv(
|
|
208
|
+
task: Task,
|
|
209
|
+
config: ImportConfig,
|
|
210
|
+
) -> int:
|
|
147
211
|
"""Import a CSV dataset.
|
|
148
212
|
|
|
149
213
|
All rows are validated before any are persisted to files to avoid partial imports."""
|
|
150
214
|
|
|
151
215
|
session_id = str(int(time.time()))
|
|
216
|
+
dataset_path = config.dataset_path
|
|
217
|
+
dataset_name = config.dataset_name
|
|
218
|
+
tag_splits = config.tag_splits
|
|
152
219
|
|
|
153
220
|
required_headers = {"input", "output"} # minimum required headers
|
|
154
221
|
optional_headers = {"reasoning", "tags", "chain_of_thought"} # optional headers
|
|
@@ -197,6 +264,8 @@ def import_csv(task: Task, dataset_path: str, dataset_name: str) -> int:
|
|
|
197
264
|
) from e
|
|
198
265
|
rows.append(run)
|
|
199
266
|
|
|
267
|
+
add_tag_splits(rows, tag_splits)
|
|
268
|
+
|
|
200
269
|
# now that we know all rows are valid, we can save them
|
|
201
270
|
for run in rows:
|
|
202
271
|
run.save_to_file()
|
|
@@ -209,24 +278,17 @@ DATASET_IMPORTERS: Dict[DatasetImportFormat, Importer] = {
|
|
|
209
278
|
}
|
|
210
279
|
|
|
211
280
|
|
|
212
|
-
@dataclass
|
|
213
|
-
class ImportConfig:
|
|
214
|
-
"""Configuration for importing a dataset"""
|
|
215
|
-
|
|
216
|
-
dataset_type: DatasetImportFormat
|
|
217
|
-
dataset_path: str
|
|
218
|
-
dataset_name: str
|
|
219
|
-
|
|
220
|
-
|
|
221
281
|
class DatasetFileImporter:
    """Import a dataset from a file, dispatching on the configured dataset format."""

    def __init__(self, task: Task, config: ImportConfig):
        self.task = task
        # Validate up front so bad split proportions are rejected before the
        # importer function ever runs.
        config.validate_tag_splits()
        self.config = config

    def create_runs_from_file(self) -> int:
        """Invoke the importer registered for ``config.dataset_type`` and return its result."""
        dataset_type = self.config.dataset_type
        importer = DATASET_IMPORTERS[dataset_type]
        return importer(self.task, self.config)
|
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from unittest.mock import AsyncMock, patch
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@pytest.mark.parametrize("concurrency", [0, -1, -25])
def test_invalid_concurrency_raises(concurrency):
    """A runner with fewer than one worker must be rejected at construction."""
    pytest.raises(ValueError, AsyncJobRunner, concurrency=concurrency)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Run once serially (1 worker) and once with many concurrent workers
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_job_runner_status_updates(concurrency):
    """All jobs succeed: `complete` climbs by one per update and errors stay at 0."""
    num_jobs = 50
    job_list = [{"id": idx} for idx in range(num_jobs)]

    runner = AsyncJobRunner(concurrency=concurrency)
    always_succeeds = AsyncMock(return_value=True)

    seen = 0
    async for update in runner.run(job_list, always_succeeds):
        # Updates arrive strictly in order: 0, 1, ..., num_jobs completed.
        assert update.complete == seen
        assert update.errors == 0
        assert update.total == num_jobs
        seen += 1

    # One initial update plus one per job, ending fully complete.
    assert seen == num_jobs + 1

    # run_job was awaited exactly once for each job.
    assert always_succeeds.call_count == num_jobs
    for job in job_list:
        always_succeeds.assert_any_await(job)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Run once serially (1 worker) and once with many concurrent workers
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_job_runner_status_updates_empty_job_list(concurrency):
    """With no jobs, only the initial all-zero progress update is produced."""
    runner = AsyncJobRunner(concurrency=concurrency)
    never_needed = AsyncMock(return_value=True)

    updates: List[Progress] = [p async for p in runner.run([], never_needed)]

    # Exactly one update: the initial one, with every counter at zero.
    assert len(updates) == 1
    only = updates[0]
    assert only.complete == 0
    assert only.errors == 0
    assert only.total == 0

    # With nothing queued, run_job is never invoked.
    assert never_needed.call_count == 0
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_job_runner_all_failures(concurrency):
    """Every job reports failure: errors climb by one per update, complete stays 0."""
    num_jobs = 50
    job_list = [{"id": idx} for idx in range(num_jobs)]

    runner = AsyncJobRunner(concurrency=concurrency)
    always_fails = AsyncMock(return_value=False)

    seen = 0
    async for update in runner.run(job_list, always_fails):
        # Updates arrive strictly in order: 0, 1, ..., num_jobs errors.
        assert update.complete == 0
        assert update.errors == seen
        assert update.total == num_jobs
        seen += 1

    # One initial update plus one per job.
    assert seen == num_jobs + 1

    # run_job was awaited exactly once for each job.
    assert always_fails.call_count == num_jobs
    for job in job_list:
        always_fails.assert_any_await(job)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_job_runner_partial_failures(concurrency):
    """A fixed subset of jobs returns False; the final counts split accordingly."""
    num_jobs = 50
    job_list = [{"id": idx} for idx in range(num_jobs)]
    failing_ids = {0, 2, 4, 6, 8, 20, 25}

    runner = AsyncJobRunner(concurrency=concurrency)

    # Reports success (True) unless the job's id is one of failing_ids.
    sometimes_fails = AsyncMock(side_effect=lambda job: job["id"] not in failing_ids)

    async for update in runner.run(job_list, sometimes_fails):
        assert update.total == num_jobs

    # The last update accounts for every job.
    want_errors = len(failing_ids)
    assert update.errors == want_errors
    assert update.complete == len(job_list) - want_errors

    # run_job was awaited exactly once for each job.
    assert sometimes_fails.call_count == num_jobs
    for job in job_list:
        sometimes_fails.assert_any_await(job)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_job_runner_partial_raises(concurrency):
    """Jobs that raise are counted as errors without stopping the other jobs."""
    num_jobs = 50
    job_list = [{"id": idx} for idx in range(num_jobs)]

    runner = AsyncJobRunner(concurrency=concurrency)

    raising_ids = {10, 25}

    def succeed_or_raise(job):
        if job["id"] not in raising_ids:
            return True
        raise Exception("job failed unexpectedly")

    flaky_job = AsyncMock(side_effect=succeed_or_raise)

    # Every intermediate count should show up in at least one update.
    expected_complete = set(range(num_jobs - len(raising_ids) + 1))
    expected_errors = set(range(len(raising_ids) + 1))

    updates: List[Progress] = []
    async for update in runner.run(job_list, flaky_job):
        updates.append(update)
        assert update.total == num_jobs

    # Only the raising jobs are missing from the final complete count.
    assert update.complete == num_jobs - len(raising_ids)

    # The counts observed across all updates match the expected ranges.
    assert {u.complete for u in updates} == expected_complete
    assert {u.errors for u in updates} == expected_errors

    # One update per job, plus the initial one.
    assert len(updates) == num_jobs + 1
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@pytest.mark.parametrize("concurrency", [1, 25])
@pytest.mark.asyncio
async def test_async_job_runner_cancelled(concurrency):
    """An exception escaping a worker task propagates out of run()."""
    runner = AsyncJobRunner(concurrency=concurrency)
    job_list = [{"id": idx} for idx in range(10)]
    worker_error = Exception("run_worker raised an exception")

    with patch.object(runner, "_run_worker", side_effect=worker_error), pytest.raises(
        Exception, match="run_worker raised an exception"
    ):
        async for _ in runner.run(job_list, AsyncMock(return_value=True)):
            pass
|