kiln-ai 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

This version of kiln-ai might be problematic; see the registry listing for more details.

Files changed (72)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +22 -44
  3. kiln_ai/adapters/chat/__init__.py +8 -0
  4. kiln_ai/adapters/chat/chat_formatter.py +234 -0
  5. kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
  6. kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
  7. kiln_ai/adapters/eval/base_eval.py +8 -6
  8. kiln_ai/adapters/eval/eval_runner.py +9 -65
  9. kiln_ai/adapters/eval/g_eval.py +26 -8
  10. kiln_ai/adapters/eval/test_base_eval.py +166 -15
  11. kiln_ai/adapters/eval/test_eval_runner.py +3 -0
  12. kiln_ai/adapters/eval/test_g_eval.py +1 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +153 -197
  15. kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
  16. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +402 -211
  17. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
  18. kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
  19. kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
  20. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
  21. kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
  22. kiln_ai/adapters/ml_model_list.py +556 -45
  23. kiln_ai/adapters/model_adapters/base_adapter.py +100 -35
  24. kiln_ai/adapters/model_adapters/litellm_adapter.py +116 -100
  25. kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
  26. kiln_ai/adapters/model_adapters/test_base_adapter.py +299 -52
  27. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +121 -22
  28. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +44 -2
  29. kiln_ai/adapters/model_adapters/test_structured_output.py +48 -18
  30. kiln_ai/adapters/parsers/base_parser.py +0 -3
  31. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  32. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  33. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  34. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  35. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  36. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  37. kiln_ai/adapters/prompt_builders.py +14 -17
  38. kiln_ai/adapters/provider_tools.py +39 -4
  39. kiln_ai/adapters/repair/test_repair_task.py +27 -5
  40. kiln_ai/adapters/test_adapter_registry.py +88 -28
  41. kiln_ai/adapters/test_ml_model_list.py +158 -0
  42. kiln_ai/adapters/test_prompt_adaptors.py +17 -3
  43. kiln_ai/adapters/test_prompt_builders.py +27 -19
  44. kiln_ai/adapters/test_provider_tools.py +130 -12
  45. kiln_ai/datamodel/__init__.py +2 -2
  46. kiln_ai/datamodel/datamodel_enums.py +43 -4
  47. kiln_ai/datamodel/dataset_filters.py +69 -1
  48. kiln_ai/datamodel/dataset_split.py +4 -0
  49. kiln_ai/datamodel/eval.py +8 -0
  50. kiln_ai/datamodel/finetune.py +13 -7
  51. kiln_ai/datamodel/prompt_id.py +1 -0
  52. kiln_ai/datamodel/task.py +68 -7
  53. kiln_ai/datamodel/task_output.py +1 -1
  54. kiln_ai/datamodel/task_run.py +39 -7
  55. kiln_ai/datamodel/test_basemodel.py +5 -8
  56. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  57. kiln_ai/datamodel/test_dataset_split.py +2 -8
  58. kiln_ai/datamodel/test_example_models.py +54 -0
  59. kiln_ai/datamodel/test_models.py +80 -9
  60. kiln_ai/datamodel/test_task.py +168 -2
  61. kiln_ai/utils/async_job_runner.py +106 -0
  62. kiln_ai/utils/config.py +3 -2
  63. kiln_ai/utils/dataset_import.py +81 -19
  64. kiln_ai/utils/logging.py +165 -0
  65. kiln_ai/utils/test_async_job_runner.py +199 -0
  66. kiln_ai/utils/test_config.py +23 -0
  67. kiln_ai/utils/test_dataset_import.py +272 -10
  68. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
  69. kiln_ai-0.17.0.dist-info/RECORD +113 -0
  70. kiln_ai-0.15.0.dist-info/RECORD +0 -104
  71. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
  72. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/utils/async_job_runner.py ADDED
@@ -0,0 +1,106 @@
+import asyncio
+import logging
+from dataclasses import dataclass
+from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+@dataclass
+class Progress:
+    complete: int
+    total: int
+    errors: int
+
+
+class AsyncJobRunner:
+    def __init__(self, concurrency: int = 1):
+        if concurrency < 1:
+            raise ValueError("concurrency must be ≥ 1")
+        self.concurrency = concurrency
+
+    async def run(
+        self,
+        jobs: List[T],
+        run_job: Callable[[T], Awaitable[bool]],
+    ) -> AsyncGenerator[Progress, None]:
+        """
+        Runs the jobs with parallel workers and yields progress updates.
+        """
+        complete = 0
+        errors = 0
+        total = len(jobs)
+
+        # Send initial status
+        yield Progress(complete=complete, total=total, errors=errors)
+
+        worker_queue: asyncio.Queue[T] = asyncio.Queue()
+        for job in jobs:
+            worker_queue.put_nowait(job)
+
+        # simple status queue to return progress. True=success, False=error
+        status_queue: asyncio.Queue[bool] = asyncio.Queue()
+
+        workers = []
+        for _ in range(self.concurrency):
+            task = asyncio.create_task(
+                self._run_worker(worker_queue, status_queue, run_job),
+            )
+            workers.append(task)
+
+        try:
+            # Send status updates until workers are done, and they are all sent
+            while not status_queue.empty() or not all(
+                worker.done() for worker in workers
+            ):
+                try:
+                    # Use timeout to prevent hanging if all workers complete
+                    # between our while condition check and get()
+                    success = await asyncio.wait_for(status_queue.get(), timeout=0.1)
+                    if success:
+                        complete += 1
+                    else:
+                        errors += 1
+
+                    yield Progress(complete=complete, total=total, errors=errors)
+                except asyncio.TimeoutError:
+                    # Timeout is expected, just continue to recheck worker status
+                    # Don't love this but beats sentinels for reliability
+                    continue
+        finally:
+            # Cancel outstanding workers on early exit or error
+            for w in workers:
+                w.cancel()
+
+        # These are redundant, but keeping them will catch async errors
+        await asyncio.gather(*workers)
+        await worker_queue.join()
+
+    async def _run_worker(
+        self,
+        worker_queue: asyncio.Queue[T],
+        status_queue: asyncio.Queue[bool],
+        run_job: Callable[[T], Awaitable[bool]],
+    ):
+        while True:
+            try:
+                job = worker_queue.get_nowait()
+            except asyncio.QueueEmpty:
+                # worker can end when the queue is empty
+                break
+
+            try:
+                success = await run_job(job)
+            except Exception:
+                logger.error("Job failed to complete", exc_info=True)
+                success = False
+
+            try:
+                await status_queue.put(success)
+            except Exception:
+                logger.error("Failed to enqueue status for job", exc_info=True)
+            finally:
+                # Always mark the dequeued task as done, even on exceptions
+                worker_queue.task_done()
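
For context, a minimal consumer sketch of the new runner (not part of the package diff): run() is an async generator that yields an initial Progress and then one update per finished job. The process_item callable below is hypothetical; any async function returning True on success and False on error satisfies the Callable[[T], Awaitable[bool]] contract.

import asyncio

from kiln_ai.utils.async_job_runner import AsyncJobRunner


async def process_item(item: dict) -> bool:
    # Hypothetical job: pretend to do work, report False to count as an error.
    await asyncio.sleep(0.01)
    return item.get("ok", True)


async def main() -> None:
    jobs = [{"id": i, "ok": i % 7 != 0} for i in range(20)]
    runner = AsyncJobRunner(concurrency=5)
    async for progress in runner.run(jobs, process_item):
        print(f"{progress.complete}/{progress.total} done, {progress.errors} errors")


if __name__ == "__main__":
    asyncio.run(main())
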
kiln_ai/utils/config.py CHANGED
@@ -138,6 +138,7 @@ class Config:
                 sensitive_keys=["api_key"],
             ),
         }
+        self._lock = threading.Lock()
         self._settings = self.load_settings()

     @classmethod
@@ -180,7 +181,7 @@ class Config:
         return None if value is None else property_config.type(value)

     def __setattr__(self, name, value):
-        if name in ("_properties", "_settings"):
+        if name in ("_properties", "_settings", "_lock"):
            super().__setattr__(name, value)
        elif name in self._properties:
            self.update_settings({name: value})
@@ -234,7 +235,7 @@ class Config:

     def update_settings(self, new_settings: Dict[str, Any]):
        # Lock to prevent race conditions in multi-threaded scenarios
-       with threading.Lock():
+       with self._lock:
            # Fresh load to avoid clobbering changes from other instances
            current_settings = self.load_settings()
            current_settings.update(new_settings)
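
Note on the update_settings change: with threading.Lock(): constructs a brand-new lock on every call, so no two threads ever contend on the same lock and the critical section is effectively unguarded; the new per-instance self._lock is created once and shared by all callers. A small illustrative sketch of the difference (not code from the package):

import threading

counter = 0
shared_lock = threading.Lock()  # created once, shared by every caller


def broken_increment() -> None:
    global counter
    with threading.Lock():  # new lock each call: never blocks another thread
        counter += 1        # read-modify-write is still racy


def fixed_increment() -> None:
    global counter
    with shared_lock:       # all threads serialize on the same lock
        counter += 1        # safe
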
kiln_ai/utils/dataset_import.py CHANGED
@@ -1,11 +1,12 @@
 import csv
 import logging
+import random
 import time
 from dataclasses import dataclass
 from enum import Enum
 from typing import Dict, Protocol

-from pydantic import BaseModel, Field, ValidationError, field_validator
+from pydantic import BaseModel, Field, ValidationError

 from kiln_ai.datamodel import DataSource, DataSourceType, Task, TaskOutput, TaskRun

@@ -20,14 +21,36 @@ class DatasetImportFormat(str, Enum):
     CSV = "csv"


+@dataclass
+class ImportConfig:
+    """Configuration for importing a dataset"""
+
+    dataset_type: DatasetImportFormat
+    dataset_path: str
+    dataset_name: str
+    """
+    A set of splits to assign to the import (as dataset tags).
+    The keys are the names of the splits (tag name), and the values are the proportions of the dataset to include in each split (should sum to 1).
+    """
+    tag_splits: Dict[str, float] | None = None
+
+    def validate_tag_splits(self) -> None:
+        if self.tag_splits:
+            EPSILON = 0.001  # Allow for small floating point errors
+            if abs(sum(self.tag_splits.values()) - 1) > EPSILON:
+                raise ValueError(
+                    "Splits must sum to 1. The following splits do not: "
+                    + ", ".join(f"{k}: {v}" for k, v in self.tag_splits.items())
+                )
+
+
 class Importer(Protocol):
     """Protocol for dataset importers"""

     def __call__(
         self,
         task: Task,
-        dataset_path: str,
-        dataset_name: str,
+        config: ImportConfig,
     ) -> int: ...


@@ -90,6 +113,44 @@ def without_none_values(d: dict) -> dict:
     return {k: v for k, v in d.items() if v is not None}


+def add_tag_splits(runs: list[TaskRun], tag_splits: Dict[str, float] | None) -> None:
+    """Assign split tags to runs according to configured proportions.
+
+    Args:
+        runs: List of TaskRun objects to assign tags to
+        tag_splits: Dictionary mapping tag names to their desired proportions
+
+    The assignment is random but ensures the proportions match the configured splits
+    as closely as possible given the number of runs.
+    """
+    if not tag_splits:
+        return
+
+    # Calculate exact number of runs for each split
+    total_runs = len(runs)
+    split_counts = {
+        tag: int(proportion * total_runs) for tag, proportion in tag_splits.items()
+    }
+
+    # Handle rounding errors by adjusting the largest split
+    remaining = total_runs - sum(split_counts.values())
+    if remaining != 0:
+        largest_split = max(split_counts.items(), key=lambda x: x[1])
+        split_counts[largest_split[0]] += remaining
+
+    # Create a list of tags with the correct counts
+    tags_to_assign = []
+    for tag, count in split_counts.items():
+        tags_to_assign.extend([tag] * count)
+
+    # Shuffle the tags to randomize assignment
+    random.shuffle(tags_to_assign)
+
+    # Assign tags to runs
+    for run, tag in zip(runs, tags_to_assign):
+        run.tags.append(tag)
+
+
 def create_task_run_from_csv_row(
     task: Task,
     row: dict[str, str],
@@ -143,18 +204,24 @@ def create_task_run_from_csv_row(
     return run


-def import_csv(task: Task, dataset_path: str, dataset_name: str) -> int:
+def import_csv(
+    task: Task,
+    config: ImportConfig,
+) -> int:
     """Import a CSV dataset.

     All rows are validated before any are persisted to files to avoid partial imports."""

     session_id = str(int(time.time()))
+    dataset_path = config.dataset_path
+    dataset_name = config.dataset_name
+    tag_splits = config.tag_splits

     required_headers = {"input", "output"}  # minimum required headers
     optional_headers = {"reasoning", "tags", "chain_of_thought"}  # optional headers

     rows: list[TaskRun] = []
-    with open(dataset_path, "r", newline="") as csvfile:
+    with open(dataset_path, "r", newline="", encoding="utf-8") as csvfile:
        reader = csv.DictReader(csvfile)

        # Check if we have headers
@@ -197,6 +264,8 @@ def import_csv(task: Task, dataset_path: str, dataset_name: str) -> int:
                ) from e
            rows.append(run)

+    add_tag_splits(rows, tag_splits)
+
    # now that we know all rows are valid, we can save them
    for run in rows:
        run.save_to_file()
@@ -209,24 +278,17 @@ DATASET_IMPORTERS: Dict[DatasetImportFormat, Importer] = {
 }


-@dataclass
-class ImportConfig:
-    """Configuration for importing a dataset"""
-
-    dataset_type: DatasetImportFormat
-    dataset_path: str
-    dataset_name: str
-
-
 class DatasetFileImporter:
     """Import a dataset from a file"""

     def __init__(self, task: Task, config: ImportConfig):
         self.task = task
-        self.dataset_type = config.dataset_type
-        self.dataset_path = config.dataset_path
-        self.dataset_name = config.dataset_name
+        config.validate_tag_splits()
+        self.config = config

     def create_runs_from_file(self) -> int:
-        fn = DATASET_IMPORTERS[self.dataset_type]
-        return fn(self.task, self.dataset_path, self.dataset_name)
+        fn = DATASET_IMPORTERS[self.config.dataset_type]
+        return fn(
+            self.task,
+            self.config,
+        )
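
A sketch of the new tag-split import path, using only names from this hunk; the file path and dataset name are hypothetical, and the Task is assumed to already exist. Proportions must sum to 1 (within the 0.001 tolerance checked by validate_tag_splits), and rounding leftovers go to the largest split: with 10 rows and splits of 0.85/0.15, int() truncation yields 8 and 1, and the leftover row is added to the larger split for final counts of 9 and 1.

from kiln_ai.datamodel import Task
from kiln_ai.utils.dataset_import import (
    DatasetFileImporter,
    DatasetImportFormat,
    ImportConfig,
)


def import_csv_with_splits(task: Task) -> int:
    # 80/20 split; each key becomes a tag appended to the imported runs.
    config = ImportConfig(
        dataset_type=DatasetImportFormat.CSV,
        dataset_path="data/examples.csv",  # hypothetical path
        dataset_name="examples-v1",  # hypothetical name
        tag_splits={"split_train": 0.8, "split_test": 0.2},
    )
    # DatasetFileImporter validates tag_splits in __init__, before any file I/O.
    return DatasetFileImporter(task, config).create_runs_from_file()
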
kiln_ai/utils/logging.py ADDED
@@ -0,0 +1,165 @@
+import datetime
+import json
+import logging
+import logging.handlers
+import os
+
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.litellm_logging import Logging
+
+from kiln_ai.utils.config import Config
+
+
+def get_default_formatter() -> str:
+    return "%(asctime)s.%(msecs)03d - %(levelname)s - %(name)s - %(message)s"
+
+
+def get_log_file_path(filename: str) -> str:
+    """Get the path to the log file, using environment override if specified.
+
+    Returns:
+        str: The path to the log file
+    """
+    log_path_default = os.path.join(Config.settings_dir(), "logs", filename)
+    log_path = os.getenv("KILN_LOG_FILE", log_path_default)
+
+    # Ensure the log directory exists
+    os.makedirs(os.path.dirname(log_path), exist_ok=True)
+    return log_path
+
+
+class CustomLiteLLMLogger(CustomLogger):
+    def __init__(self, logger: logging.Logger):
+        self.logger = logger
+
+    def log_pre_api_call(self, model, messages, kwargs):
+        api_base = kwargs.get("litellm_params", {}).get("api_base", "")
+        headers = kwargs.get("additional_args", {}).get("headers", {})
+        data = kwargs.get("additional_args", {}).get("complete_input_dict", {})
+
+        try:
+            # Print the curl command for the request
+            logger = Logging(
+                model=model,
+                messages=messages,
+                stream=False,
+                call_type="completion",
+                start_time=datetime.datetime.now(),
+                litellm_call_id="",
+                function_id="na",
+                kwargs=kwargs,
+            )
+            curl_command = logger._get_request_curl_command(
+                api_base=api_base,
+                headers=headers,
+                additional_args=kwargs,
+                data=data,
+            )
+            self.logger.info(f"{curl_command}")
+        except Exception as e:
+            self.logger.info(f"Curl Command: Could not print {e}")
+
+        # Print the formatted input data for the request in API format, pretty print
+        try:
+            self.logger.info(
+                f"Formatted Input Data (API):\n{json.dumps(data, indent=2)}"
+            )
+        except Exception as e:
+            self.logger.info(f"Formatted Input Data (API): Could not print {e}")
+
+        # Print the messages for the request in LiteLLM Message list, pretty print
+        try:
+            json_messages = json.dumps(messages, indent=2)
+            self.logger.info(f"Messages:\n{json_messages}")
+        except Exception as e:
+            self.logger.info(f"Messages: Could not print {e}")
+
+    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
+        # No op
+        pass
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        litellm_logger = logging.getLogger("LiteLLM")
+        litellm_logger.error(
+            "Used a sync call in Litellm. Kiln should use async calls."
+        )
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        # This logging method is supposed to be called by Litellm in synchronous error cases (Kiln should use async calls instead)
+        # but it appears to also be getting called in async calls that fail early (e.g. UnsupportedParamsError).
+        litellm_logger = logging.getLogger("LiteLLM")
+        litellm_logger.error(
+            "LiteLLM logged a synchronous failure event. This may result from a sync call, or from an async call failing early (e.g. invalid parameters). Make sure you are using async calls.",
+        )

+    #### ASYNC #### - for acompletion/aembeddings
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if len(response_obj.choices) == 1:
+                if response_obj.choices[0].message.tool_calls:
+                    for tool_call in response_obj.choices[0].message.tool_calls:
+                        try:
+                            args = tool_call.function.arguments
+                            function_name = tool_call.function.name
+                            self.logger.info(
+                                f"Model Response Tool Call Arguments [{function_name}]:\n{args}"
+                            )
+                        except Exception:
+                            self.logger.info(f"Model Response Tool Call:\n{tool_call}")
+
+                content = response_obj.choices[0].message.content
+                if content:
+                    try:
+                        # JSON format logs if possible
+                        json_content = json.loads(content)
+                        self.logger.info(
+                            f"Model Response Content:\n{json.dumps(json_content, indent=2)}"
+                        )
+                    except Exception:
+                        self.logger.info(f"Model Response Content:\n{content}")
+            elif len(response_obj.choices) > 1:
+                self.logger.info(
+                    f"Model Response (multiple choices):\n{response_obj.choices}"
+                )
+            else:
+                self.logger.info("Model Response: No choices returned")
+
+        except Exception as e:
+            self.logger.info(f"Model Response: Could not print {e}")
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        self.logger.info(f"LiteLLM Failure: {response_obj}")
+
+
+def setup_litellm_logging(filename: str = "model_calls.log"):
+    # Check if we already have a custom litellm logger
+    for callback in litellm.callbacks or []:
+        if isinstance(callback, CustomLiteLLMLogger):
+            return  # We already have a custom litellm logger
+
+    # If we don't have a custom litellm logger, create one
+    # Disable the default litellm logger except for errors. It's ugly, hard to use, and we don't want it to mix with kiln logs.
+    litellm_logger = logging.getLogger("LiteLLM")
+    litellm_logger.setLevel(logging.ERROR)
+
+    # Create a logger that logs to files, with a max size of 5MB and 3 backup files
+    handler = logging.handlers.RotatingFileHandler(
+        get_log_file_path(filename),
+        maxBytes=5 * 1024 * 1024,  # 5MB
+        backupCount=3,
+    )
+
+    # Set formatter to match the default formatting
+    formatter = logging.Formatter(get_default_formatter())
+    handler.setFormatter(formatter)
+
+    # Create a new logger for model calls
+    model_calls_logger = logging.getLogger("ModelCalls")
+    model_calls_logger.setLevel(logging.INFO)
+    model_calls_logger.propagate = False  # Only log to file
+    model_calls_logger.addHandler(handler)
+
+    # Tell litellm to use our custom logger
+    litellm.callbacks = [CustomLiteLLMLogger(model_calls_logger)]
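
A sketch of enabling the new model-call logging; the filename default and the KILN_LOG_FILE override come from the hunk above, while the /tmp path below is illustrative. setup_litellm_logging() is idempotent: it returns early if a CustomLiteLLMLogger is already registered with litellm.

import os

from kiln_ai.utils.logging import setup_litellm_logging

# Optional override; the default is <settings_dir>/logs/model_calls.log.
os.environ["KILN_LOG_FILE"] = "/tmp/kiln_model_calls.log"

# Registers a CustomLiteLLMLogger writing to a rotating file (5MB, 3 backups).
setup_litellm_logging()
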
kiln_ai/utils/test_async_job_runner.py ADDED
@@ -0,0 +1,199 @@
+from typing import List
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress
+
+
+@pytest.mark.parametrize("concurrency", [0, -1, -25])
+def test_invalid_concurrency_raises(concurrency):
+    with pytest.raises(ValueError):
+        AsyncJobRunner(concurrency=concurrency)
+
+
+# Test with and without concurrency
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_status_updates(concurrency):
+    job_count = 50
+    jobs = [{"id": i} for i in range(job_count)]
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    # fake run_job that succeeds
+    mock_run_job_success = AsyncMock(return_value=True)
+
+    # Expect the status updates in order, and 1 for each job
+    expected_completed_count = 0
+    async for progress in runner.run(jobs, mock_run_job_success):
+        assert progress.complete == expected_completed_count
+        expected_completed_count += 1
+        assert progress.errors == 0
+        assert progress.total == job_count
+
+    # Verify last status update was complete
+    assert expected_completed_count == job_count + 1
+
+    # Verify run_job was called for each job
+    assert mock_run_job_success.call_count == job_count
+
+    # Verify run_job was called with the correct arguments
+    for i in range(job_count):
+        mock_run_job_success.assert_any_await(jobs[i])
+
+
+# Test with and without concurrency
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_status_updates_empty_job_list(concurrency):
+    empty_job_list = []
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    # fake run_job that succeeds
+    mock_run_job_success = AsyncMock(return_value=True)
+
+    updates: List[Progress] = []
+    async for progress in runner.run(empty_job_list, mock_run_job_success):
+        updates.append(progress)
+
+    # Verify last status update was complete
+    assert len(updates) == 1
+
+    assert updates[0].complete == 0
+    assert updates[0].errors == 0
+    assert updates[0].total == 0
+
+    # Verify run_job was called for each job
+    assert mock_run_job_success.call_count == 0
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_all_failures(concurrency):
+    job_count = 50
+    jobs = [{"id": i} for i in range(job_count)]
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    # fake run_job that fails
+    mock_run_job_failure = AsyncMock(return_value=False)
+
+    # Expect the status updates in order, and 1 for each job
+    expected_error_count = 0
+    async for progress in runner.run(jobs, mock_run_job_failure):
+        assert progress.complete == 0
+        assert progress.errors == expected_error_count
+        expected_error_count += 1
+        assert progress.total == job_count
+
+    # Verify last status update was complete
+    assert expected_error_count == job_count + 1
+
+    # Verify run_job was called for each job
+    assert mock_run_job_failure.call_count == job_count
+
+    # Verify run_job was called with the correct arguments
+    for i in range(job_count):
+        mock_run_job_failure.assert_any_await(jobs[i])
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_partial_failures(concurrency):
+    job_count = 50
+    jobs = [{"id": i} for i in range(job_count)]
+
+    # we want to fail on some jobs and succeed on others
+    jobs_to_fail = set([0, 2, 4, 6, 8, 20, 25])
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    # fake run_job that fails
+    mock_run_job_partial_success = AsyncMock(
+        # return True for jobs that should succeed
+        side_effect=lambda job: job["id"] not in jobs_to_fail
+    )
+
+    # Expect the status updates in order, and 1 for each job
+    async for progress in runner.run(jobs, mock_run_job_partial_success):
+        assert progress.total == job_count
+
+    # Verify last status update was complete
+    expected_error_count = len(jobs_to_fail)
+    expected_success_count = len(jobs) - expected_error_count
+    assert progress.errors == expected_error_count
+    assert progress.complete == expected_success_count
+
+    # Verify run_job was called for each job
+    assert mock_run_job_partial_success.call_count == job_count
+
+    # Verify run_job was called with the correct arguments
+    for i in range(job_count):
+        mock_run_job_partial_success.assert_any_await(jobs[i])
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_partial_raises(concurrency):
+    job_count = 50
+    jobs = [{"id": i} for i in range(job_count)]
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    ids_to_fail = set([10, 25])
+
+    def failure_fn(job):
+        if job["id"] in ids_to_fail:
+            raise Exception("job failed unexpectedly")
+        return True
+
+    # fake run_job that fails
+    mock_run_job_partial_success = AsyncMock(side_effect=failure_fn)
+
+    # generate all the values we expect to see in progress updates
+    complete_values_expected = set([i for i in range(job_count - len(ids_to_fail) + 1)])
+    errors_values_expected = set([i for i in range(len(ids_to_fail) + 1)])
+
+    # keep track of all the updates we see
+    updates: List[Progress] = []
+
+    # we keep track of the progress values we have actually seen
+    complete_values_actual = set()
+    errors_values_actual = set()
+
+    # Expect the status updates in order, and 1 for each job
+    async for progress in runner.run(jobs, mock_run_job_partial_success):
+        updates.append(progress)
+        complete_values_actual.add(progress.complete)
+        errors_values_actual.add(progress.errors)
+
+        assert progress.total == job_count
+
+    # complete values should be all the jobs, except for the ones that failed
+    assert progress.complete == job_count - len(ids_to_fail)
+
+    # check that the actual updates and expected updates are equivalent sets
+    assert complete_values_actual == complete_values_expected
+    assert errors_values_actual == errors_values_expected
+
+    # we should have seen one update for each job, plus one for the initial status update
+    assert len(updates) == job_count + 1
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_cancelled(concurrency):
+    runner = AsyncJobRunner(concurrency=concurrency)
+    jobs = [{"id": i} for i in range(10)]
+
+    with patch.object(
+        runner,
+        "_run_worker",
+        side_effect=Exception("run_worker raised an exception"),
+    ):
+        # if an exception is raised in the task, we should see it bubble up
+        with pytest.raises(Exception, match="run_worker raised an exception"):
+            async for _ in runner.run(jobs, AsyncMock(return_value=True)):
+                pass
kiln_ai/utils/test_config.py CHANGED
@@ -1,5 +1,6 @@
 import getpass
 import os
+import threading
 from unittest.mock import patch

 import pytest
@@ -299,3 +300,25 @@ def test_yaml_persistence_structured_data(config_with_yaml, mock_yaml_file):
     with open(mock_yaml_file, "r") as f:
         saved_settings = yaml.safe_load(f)
     assert saved_settings["list_of_objects"] == new_settings
+
+
+def test_update_settings_thread_safety(config_with_yaml):
+    config = config_with_yaml
+
+    exceptions = []
+
+    def update(val):
+        try:
+            config.update_settings({"int_property": val})
+        except Exception as e:
+            exceptions.append(e)
+
+    threads = [threading.Thread(target=update, args=(i,)) for i in range(5)]
+
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+
+    assert not exceptions
+    assert config.int_property in range(5)