kiln-ai 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai has been flagged as potentially problematic.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +22 -44
- kiln_ai/adapters/chat/__init__.py +8 -0
- kiln_ai/adapters/chat/chat_formatter.py +234 -0
- kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
- kiln_ai/adapters/eval/base_eval.py +8 -6
- kiln_ai/adapters/eval/eval_runner.py +9 -65
- kiln_ai/adapters/eval/g_eval.py +26 -8
- kiln_ai/adapters/eval/test_base_eval.py +166 -15
- kiln_ai/adapters/eval/test_eval_runner.py +3 -0
- kiln_ai/adapters/eval/test_g_eval.py +1 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
- kiln_ai/adapters/fine_tune/dataset_formatter.py +153 -197
- kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +402 -211
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
- kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
- kiln_ai/adapters/ml_model_list.py +556 -45
- kiln_ai/adapters/model_adapters/base_adapter.py +100 -35
- kiln_ai/adapters/model_adapters/litellm_adapter.py +116 -100
- kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
- kiln_ai/adapters/model_adapters/test_base_adapter.py +299 -52
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +121 -22
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +44 -2
- kiln_ai/adapters/model_adapters/test_structured_output.py +48 -18
- kiln_ai/adapters/parsers/base_parser.py +0 -3
- kiln_ai/adapters/parsers/parser_registry.py +5 -3
- kiln_ai/adapters/parsers/r1_parser.py +17 -2
- kiln_ai/adapters/parsers/request_formatters.py +40 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
- kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
- kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
- kiln_ai/adapters/prompt_builders.py +14 -17
- kiln_ai/adapters/provider_tools.py +39 -4
- kiln_ai/adapters/repair/test_repair_task.py +27 -5
- kiln_ai/adapters/test_adapter_registry.py +88 -28
- kiln_ai/adapters/test_ml_model_list.py +158 -0
- kiln_ai/adapters/test_prompt_adaptors.py +17 -3
- kiln_ai/adapters/test_prompt_builders.py +27 -19
- kiln_ai/adapters/test_provider_tools.py +130 -12
- kiln_ai/datamodel/__init__.py +2 -2
- kiln_ai/datamodel/datamodel_enums.py +43 -4
- kiln_ai/datamodel/dataset_filters.py +69 -1
- kiln_ai/datamodel/dataset_split.py +4 -0
- kiln_ai/datamodel/eval.py +8 -0
- kiln_ai/datamodel/finetune.py +13 -7
- kiln_ai/datamodel/prompt_id.py +1 -0
- kiln_ai/datamodel/task.py +68 -7
- kiln_ai/datamodel/task_output.py +1 -1
- kiln_ai/datamodel/task_run.py +39 -7
- kiln_ai/datamodel/test_basemodel.py +5 -8
- kiln_ai/datamodel/test_dataset_filters.py +82 -0
- kiln_ai/datamodel/test_dataset_split.py +2 -8
- kiln_ai/datamodel/test_example_models.py +54 -0
- kiln_ai/datamodel/test_models.py +80 -9
- kiln_ai/datamodel/test_task.py +168 -2
- kiln_ai/utils/async_job_runner.py +106 -0
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/dataset_import.py +81 -19
- kiln_ai/utils/logging.py +165 -0
- kiln_ai/utils/test_async_job_runner.py +199 -0
- kiln_ai/utils/test_config.py +23 -0
- kiln_ai/utils/test_dataset_import.py +272 -10
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
- kiln_ai-0.17.0.dist-info/RECORD +113 -0
- kiln_ai-0.15.0.dist-info/RECORD +0 -104
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/utils/async_job_runner.py
ADDED
@@ -0,0 +1,106 @@
+import asyncio
+import logging
+from dataclasses import dataclass
+from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+@dataclass
+class Progress:
+    complete: int
+    total: int
+    errors: int
+
+
+class AsyncJobRunner:
+    def __init__(self, concurrency: int = 1):
+        if concurrency < 1:
+            raise ValueError("concurrency must be ≥ 1")
+        self.concurrency = concurrency
+
+    async def run(
+        self,
+        jobs: List[T],
+        run_job: Callable[[T], Awaitable[bool]],
+    ) -> AsyncGenerator[Progress, None]:
+        """
+        Runs the jobs with parallel workers and yields progress updates.
+        """
+        complete = 0
+        errors = 0
+        total = len(jobs)
+
+        # Send initial status
+        yield Progress(complete=complete, total=total, errors=errors)
+
+        worker_queue: asyncio.Queue[T] = asyncio.Queue()
+        for job in jobs:
+            worker_queue.put_nowait(job)
+
+        # simple status queue to return progress. True=success, False=error
+        status_queue: asyncio.Queue[bool] = asyncio.Queue()
+
+        workers = []
+        for _ in range(self.concurrency):
+            task = asyncio.create_task(
+                self._run_worker(worker_queue, status_queue, run_job),
+            )
+            workers.append(task)
+
+        try:
+            # Send status updates until workers are done, and they are all sent
+            while not status_queue.empty() or not all(
+                worker.done() for worker in workers
+            ):
+                try:
+                    # Use timeout to prevent hanging if all workers complete
+                    # between our while condition check and get()
+                    success = await asyncio.wait_for(status_queue.get(), timeout=0.1)
+                    if success:
+                        complete += 1
+                    else:
+                        errors += 1
+
+                    yield Progress(complete=complete, total=total, errors=errors)
+                except asyncio.TimeoutError:
+                    # Timeout is expected, just continue to recheck worker status
+                    # Don't love this but beats sentinels for reliability
+                    continue
+        finally:
+            # Cancel outstanding workers on early exit or error
+            for w in workers:
+                w.cancel()
+
+            # These are redundant, but keeping them will catch async errors
+            await asyncio.gather(*workers)
+            await worker_queue.join()
+
+    async def _run_worker(
+        self,
+        worker_queue: asyncio.Queue[T],
+        status_queue: asyncio.Queue[bool],
+        run_job: Callable[[T], Awaitable[bool]],
+    ):
+        while True:
+            try:
+                job = worker_queue.get_nowait()
+            except asyncio.QueueEmpty:
+                # worker can end when the queue is empty
+                break
+
+            try:
+                success = await run_job(job)
+            except Exception:
+                logger.error("Job failed to complete", exc_info=True)
+                success = False
+
+            try:
+                await status_queue.put(success)
+            except Exception:
+                logger.error("Failed to enqueue status for job", exc_info=True)
+            finally:
+                # Always mark the dequeued task as done, even on exceptions
+                worker_queue.task_done()
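The new AsyncJobRunner is consumed as an async generator: callers iterate over Progress updates while a pool of workers drains the job queue, and exceptions raised by run_job are counted as errors rather than aborting the run. A minimal usage sketch (the job list and run_job coroutine below are hypothetical; only AsyncJobRunner, Progress, and the run() signature come from the file above):

import asyncio

from kiln_ai.utils.async_job_runner import AsyncJobRunner


async def main():
    # Hypothetical jobs: any list of items your run_job coroutine accepts
    jobs = list(range(10))

    async def run_job(job: int) -> bool:
        # Return True on success, False on failure; raised exceptions count as errors
        await asyncio.sleep(0.01)
        return job % 2 == 0

    runner = AsyncJobRunner(concurrency=4)
    async for progress in runner.run(jobs, run_job):
        print(f"{progress.complete}/{progress.total} complete, {progress.errors} errors")


asyncio.run(main())

Each yielded Progress carries complete, errors, and total, and an initial update is emitted before any job runs.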
kiln_ai/utils/config.py
CHANGED
@@ -138,6 +138,7 @@ class Config:
                 sensitive_keys=["api_key"],
             ),
         }
+        self._lock = threading.Lock()
         self._settings = self.load_settings()

     @classmethod
@@ -180,7 +181,7 @@ class Config:
         return None if value is None else property_config.type(value)

     def __setattr__(self, name, value):
-        if name in ("_properties", "_settings"):
+        if name in ("_properties", "_settings", "_lock"):
             super().__setattr__(name, value)
         elif name in self._properties:
             self.update_settings({name: value})
@@ -234,7 +235,7 @@ class Config:

     def update_settings(self, new_settings: Dict[str, Any]):
         # Lock to prevent race conditions in multi-threaded scenarios
-        with
+        with self._lock:
             # Fresh load to avoid clobbering changes from other instances
             current_settings = self.load_settings()
             current_settings.update(new_settings)
kiln_ai/utils/dataset_import.py
CHANGED
@@ -1,11 +1,12 @@
 import csv
 import logging
+import random
 import time
 from dataclasses import dataclass
 from enum import Enum
 from typing import Dict, Protocol

-from pydantic import BaseModel, Field, ValidationError
+from pydantic import BaseModel, Field, ValidationError

 from kiln_ai.datamodel import DataSource, DataSourceType, Task, TaskOutput, TaskRun

@@ -20,14 +21,36 @@ class DatasetImportFormat(str, Enum):
     CSV = "csv"


+@dataclass
+class ImportConfig:
+    """Configuration for importing a dataset"""
+
+    dataset_type: DatasetImportFormat
+    dataset_path: str
+    dataset_name: str
+    """
+    A set of splits to assign to the import (as dataset tags).
+    The keys are the names of the splits (tag name), and the values are the proportions of the dataset to include in each split (should sum to 1).
+    """
+    tag_splits: Dict[str, float] | None = None
+
+    def validate_tag_splits(self) -> None:
+        if self.tag_splits:
+            EPSILON = 0.001  # Allow for small floating point errors
+            if abs(sum(self.tag_splits.values()) - 1) > EPSILON:
+                raise ValueError(
+                    "Splits must sum to 1. The following splits do not: "
+                    + ", ".join(f"{k}: {v}" for k, v in self.tag_splits.items())
+                )
+
+
 class Importer(Protocol):
     """Protocol for dataset importers"""

     def __call__(
         self,
         task: Task,
-
-        dataset_name: str,
+        config: ImportConfig,
     ) -> int: ...


@@ -90,6 +113,44 @@ def without_none_values(d: dict) -> dict:
     return {k: v for k, v in d.items() if v is not None}


+def add_tag_splits(runs: list[TaskRun], tag_splits: Dict[str, float] | None) -> None:
+    """Assign split tags to runs according to configured proportions.
+
+    Args:
+        runs: List of TaskRun objects to assign tags to
+        tag_splits: Dictionary mapping tag names to their desired proportions
+
+    The assignment is random but ensures the proportions match the configured splits
+    as closely as possible given the number of runs.
+    """
+    if not tag_splits:
+        return
+
+    # Calculate exact number of runs for each split
+    total_runs = len(runs)
+    split_counts = {
+        tag: int(proportion * total_runs) for tag, proportion in tag_splits.items()
+    }
+
+    # Handle rounding errors by adjusting the largest split
+    remaining = total_runs - sum(split_counts.values())
+    if remaining != 0:
+        largest_split = max(split_counts.items(), key=lambda x: x[1])
+        split_counts[largest_split[0]] += remaining
+
+    # Create a list of tags with the correct counts
+    tags_to_assign = []
+    for tag, count in split_counts.items():
+        tags_to_assign.extend([tag] * count)
+
+    # Shuffle the tags to randomize assignment
+    random.shuffle(tags_to_assign)
+
+    # Assign tags to runs
+    for run, tag in zip(runs, tags_to_assign):
+        run.tags.append(tag)
+
+
 def create_task_run_from_csv_row(
     task: Task,
     row: dict[str, str],
@@ -143,18 +204,24 @@ def create_task_run_from_csv_row(
     return run


-def import_csv(
+def import_csv(
+    task: Task,
+    config: ImportConfig,
+) -> int:
     """Import a CSV dataset.

     All rows are validated before any are persisted to files to avoid partial imports."""

     session_id = str(int(time.time()))
+    dataset_path = config.dataset_path
+    dataset_name = config.dataset_name
+    tag_splits = config.tag_splits

     required_headers = {"input", "output"}  # minimum required headers
     optional_headers = {"reasoning", "tags", "chain_of_thought"}  # optional headers

     rows: list[TaskRun] = []
-    with open(dataset_path, "r", newline="") as csvfile:
+    with open(dataset_path, "r", newline="", encoding="utf-8") as csvfile:
         reader = csv.DictReader(csvfile)

         # Check if we have headers
@@ -197,6 +264,8 @@ def import_csv(task: Task, dataset_path: str, dataset_name: str) -> int:
             ) from e
         rows.append(run)

+    add_tag_splits(rows, tag_splits)
+
     # now that we know all rows are valid, we can save them
     for run in rows:
         run.save_to_file()
@@ -209,24 +278,17 @@ DATASET_IMPORTERS: Dict[DatasetImportFormat, Importer] = {
 }


-@dataclass
-class ImportConfig:
-    """Configuration for importing a dataset"""
-
-    dataset_type: DatasetImportFormat
-    dataset_path: str
-    dataset_name: str
-
-
 class DatasetFileImporter:
     """Import a dataset from a file"""

     def __init__(self, task: Task, config: ImportConfig):
         self.task = task
-
-        self.
-        self.dataset_name = config.dataset_name
+        config.validate_tag_splits()
+        self.config = config

     def create_runs_from_file(self) -> int:
-        fn = DATASET_IMPORTERS[self.dataset_type]
-        return fn(
+        fn = DATASET_IMPORTERS[self.config.dataset_type]
+        return fn(
+            self.task,
+            self.config,
+        )
kiln_ai/utils/logging.py
ADDED
@@ -0,0 +1,165 @@
+import datetime
+import json
+import logging
+import logging.handlers
+import os
+
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.litellm_logging import Logging
+
+from kiln_ai.utils.config import Config
+
+
+def get_default_formatter() -> str:
+    return "%(asctime)s.%(msecs)03d - %(levelname)s - %(name)s - %(message)s"
+
+
+def get_log_file_path(filename: str) -> str:
+    """Get the path to the log file, using environment override if specified.
+
+    Returns:
+        str: The path to the log file
+    """
+    log_path_default = os.path.join(Config.settings_dir(), "logs", filename)
+    log_path = os.getenv("KILN_LOG_FILE", log_path_default)
+
+    # Ensure the log directory exists
+    os.makedirs(os.path.dirname(log_path), exist_ok=True)
+    return log_path
+
+
+class CustomLiteLLMLogger(CustomLogger):
+    def __init__(self, logger: logging.Logger):
+        self.logger = logger
+
+    def log_pre_api_call(self, model, messages, kwargs):
+        api_base = kwargs.get("litellm_params", {}).get("api_base", "")
+        headers = kwargs.get("additional_args", {}).get("headers", {})
+        data = kwargs.get("additional_args", {}).get("complete_input_dict", {})
+
+        try:
+            # Print the curl command for the request
+            logger = Logging(
+                model=model,
+                messages=messages,
+                stream=False,
+                call_type="completion",
+                start_time=datetime.datetime.now(),
+                litellm_call_id="",
+                function_id="na",
+                kwargs=kwargs,
+            )
+            curl_command = logger._get_request_curl_command(
+                api_base=api_base,
+                headers=headers,
+                additional_args=kwargs,
+                data=data,
+            )
+            self.logger.info(f"{curl_command}")
+        except Exception as e:
+            self.logger.info(f"Curl Command: Could not print {e}")
+
+        # Print the formatted input data for the request in API format, pretty print
+        try:
+            self.logger.info(
+                f"Formatted Input Data (API):\n{json.dumps(data, indent=2)}"
+            )
+        except Exception as e:
+            self.logger.info(f"Formatted Input Data (API): Could not print {e}")
+
+        # Print the messages for the request in LiteLLM Message list, pretty print
+        try:
+            json_messages = json.dumps(messages, indent=2)
+            self.logger.info(f"Messages:\n{json_messages}")
+        except Exception as e:
+            self.logger.info(f"Messages: Could not print {e}")
+
+    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
+        # No op
+        pass
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        litellm_logger = logging.getLogger("LiteLLM")
+        litellm_logger.error(
+            "Used a sync call in Litellm. Kiln should use async calls."
+        )
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        # This logging method is supposed to be called by Litellm in synchronous error cases (Kiln should use async calls instead)
+        # but it appears to also be getting called in async calls that fail early (e.g. UnsupportedParamsError).
+        litellm_logger = logging.getLogger("LiteLLM")
+        litellm_logger.error(
+            "LiteLLM logged a synchronous failure event. This may result from a sync call, or from an async call failing early (e.g. invalid parameters). Make sure you are using async calls.",
+        )

+    #### ASYNC #### - for acompletion/aembeddings
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if len(response_obj.choices) == 1:
+                if response_obj.choices[0].message.tool_calls:
+                    for tool_call in response_obj.choices[0].message.tool_calls:
+                        try:
+                            args = tool_call.function.arguments
+                            function_name = tool_call.function.name
+                            self.logger.info(
+                                f"Model Response Tool Call Arguments [{function_name}]:\n{args}"
+                            )
+                        except Exception:
+                            self.logger.info(f"Model Response Tool Call:\n{tool_call}")
+
+                content = response_obj.choices[0].message.content
+                if content:
+                    try:
+                        # JSON format logs if possible
+                        json_content = json.loads(content)
+                        self.logger.info(
+                            f"Model Response Content:\n{json.dumps(json_content, indent=2)}"
+                        )
+                    except Exception:
+                        self.logger.info(f"Model Response Content:\n{content}")
+            elif len(response_obj.choices) > 1:
+                self.logger.info(
+                    f"Model Response (multiple choices):\n{response_obj.choices}"
+                )
+            else:
+                self.logger.info("Model Response: No choices returned")
+
+        except Exception as e:
+            self.logger.info(f"Model Response: Could not print {e}")
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        self.logger.info(f"LiteLLM Failure: {response_obj}")
+
+
+def setup_litellm_logging(filename: str = "model_calls.log"):
+    # Check if we already have a custom litellm logger
+    for callback in litellm.callbacks or []:
+        if isinstance(callback, CustomLiteLLMLogger):
+            return  # We already have a custom litellm logger
+
+    # If we don't have a custom litellm logger, create one
+    # Disable the default litellm logger except for errors. It's ugly, hard to use, and we don't want it to mix with kiln logs.
+    litellm_logger = logging.getLogger("LiteLLM")
+    litellm_logger.setLevel(logging.ERROR)
+
+    # Create a logger that logs to files, with a max size of 5MB and 3 backup files
+    handler = logging.handlers.RotatingFileHandler(
+        get_log_file_path(filename),
+        maxBytes=5 * 1024 * 1024,  # 5MB
+        backupCount=3,
+    )
+
+    # Set formatter to match the default formatting
+    formatter = logging.Formatter(get_default_formatter())
+    handler.setFormatter(formatter)
+
+    # Create a new logger for model calls
+    model_calls_logger = logging.getLogger("ModelCalls")
+    model_calls_logger.setLevel(logging.INFO)
+    model_calls_logger.propagate = False  # Only log to file
+    model_calls_logger.addHandler(handler)
+
+    # Tell litellm to use our custom logger
+    litellm.callbacks = [CustomLiteLLMLogger(model_calls_logger)]
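setup_litellm_logging() is idempotent: it returns early if a CustomLiteLLMLogger callback is already registered, otherwise it routes model-call logs to a rotating file under the Kiln settings directory, honoring a KILN_LOG_FILE override. A short sketch of wiring it up (the override path is illustrative):

import os

from kiln_ai.utils.logging import setup_litellm_logging

# Optional: override the log file location before the first call (illustrative path)
os.environ.setdefault("KILN_LOG_FILE", "/tmp/kiln_model_calls.log")

setup_litellm_logging()  # registers the CustomLiteLLMLogger and the "ModelCalls" file logger
setup_litellm_logging()  # no-op: the callback is already registered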
kiln_ai/utils/test_async_job_runner.py
ADDED
@@ -0,0 +1,199 @@
+from typing import List
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress
+
+
+@pytest.mark.parametrize("concurrency", [0, -1, -25])
+def test_invalid_concurrency_raises(concurrency):
+    with pytest.raises(ValueError):
+        AsyncJobRunner(concurrency=concurrency)
+
+
+# Test with and without concurrency
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_status_updates(concurrency):
+    job_count = 50
+    jobs = [{"id": i} for i in range(job_count)]
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    # fake run_job that succeeds
+    mock_run_job_success = AsyncMock(return_value=True)
+
+    # Expect the status updates in order, and 1 for each job
+    expected_completed_count = 0
+    async for progress in runner.run(jobs, mock_run_job_success):
+        assert progress.complete == expected_completed_count
+        expected_completed_count += 1
+        assert progress.errors == 0
+        assert progress.total == job_count
+
+    # Verify last status update was complete
+    assert expected_completed_count == job_count + 1
+
+    # Verify run_job was called for each job
+    assert mock_run_job_success.call_count == job_count
+
+    # Verify run_job was called with the correct arguments
+    for i in range(job_count):
+        mock_run_job_success.assert_any_await(jobs[i])
+
+
+# Test with and without concurrency
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_status_updates_empty_job_list(concurrency):
+    empty_job_list = []
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    # fake run_job that succeeds
+    mock_run_job_success = AsyncMock(return_value=True)
+
+    updates: List[Progress] = []
+    async for progress in runner.run(empty_job_list, mock_run_job_success):
+        updates.append(progress)
+
+    # Verify last status update was complete
+    assert len(updates) == 1
+
+    assert updates[0].complete == 0
+    assert updates[0].errors == 0
+    assert updates[0].total == 0
+
+    # Verify run_job was called for each job
+    assert mock_run_job_success.call_count == 0
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_all_failures(concurrency):
+    job_count = 50
+    jobs = [{"id": i} for i in range(job_count)]
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    # fake run_job that fails
+    mock_run_job_failure = AsyncMock(return_value=False)
+
+    # Expect the status updates in order, and 1 for each job
+    expected_error_count = 0
+    async for progress in runner.run(jobs, mock_run_job_failure):
+        assert progress.complete == 0
+        assert progress.errors == expected_error_count
+        expected_error_count += 1
+        assert progress.total == job_count
+
+    # Verify last status update was complete
+    assert expected_error_count == job_count + 1
+
+    # Verify run_job was called for each job
+    assert mock_run_job_failure.call_count == job_count
+
+    # Verify run_job was called with the correct arguments
+    for i in range(job_count):
+        mock_run_job_failure.assert_any_await(jobs[i])
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_partial_failures(concurrency):
+    job_count = 50
+    jobs = [{"id": i} for i in range(job_count)]
+
+    # we want to fail on some jobs and succeed on others
+    jobs_to_fail = set([0, 2, 4, 6, 8, 20, 25])
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    # fake run_job that fails
+    mock_run_job_partial_success = AsyncMock(
+        # return True for jobs that should succeed
+        side_effect=lambda job: job["id"] not in jobs_to_fail
+    )
+
+    # Expect the status updates in order, and 1 for each job
+    async for progress in runner.run(jobs, mock_run_job_partial_success):
+        assert progress.total == job_count
+
+    # Verify last status update was complete
+    expected_error_count = len(jobs_to_fail)
+    expected_success_count = len(jobs) - expected_error_count
+    assert progress.errors == expected_error_count
+    assert progress.complete == expected_success_count
+
+    # Verify run_job was called for each job
+    assert mock_run_job_partial_success.call_count == job_count
+
+    # Verify run_job was called with the correct arguments
+    for i in range(job_count):
+        mock_run_job_partial_success.assert_any_await(jobs[i])
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_partial_raises(concurrency):
+    job_count = 50
+    jobs = [{"id": i} for i in range(job_count)]
+
+    runner = AsyncJobRunner(concurrency=concurrency)
+
+    ids_to_fail = set([10, 25])
+
+    def failure_fn(job):
+        if job["id"] in ids_to_fail:
+            raise Exception("job failed unexpectedly")
+        return True
+
+    # fake run_job that fails
+    mock_run_job_partial_success = AsyncMock(side_effect=failure_fn)
+
+    # generate all the values we expect to see in progress updates
+    complete_values_expected = set([i for i in range(job_count - len(ids_to_fail) + 1)])
+    errors_values_expected = set([i for i in range(len(ids_to_fail) + 1)])
+
+    # keep track of all the updates we see
+    updates: List[Progress] = []
+
+    # we keep track of the progress values we have actually seen
+    complete_values_actual = set()
+    errors_values_actual = set()
+
+    # Expect the status updates in order, and 1 for each job
+    async for progress in runner.run(jobs, mock_run_job_partial_success):
+        updates.append(progress)
+        complete_values_actual.add(progress.complete)
+        errors_values_actual.add(progress.errors)
+
+        assert progress.total == job_count
+
+    # complete values should be all the jobs, except for the ones that failed
+    assert progress.complete == job_count - len(ids_to_fail)
+
+    # check that the actual updates and expected updates are equivalent sets
+    assert complete_values_actual == complete_values_expected
+    assert errors_values_actual == errors_values_expected
+
+    # we should have seen one update for each job, plus one for the initial status update
+    assert len(updates) == job_count + 1
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_cancelled(concurrency):
+    runner = AsyncJobRunner(concurrency=concurrency)
+    jobs = [{"id": i} for i in range(10)]
+
+    with patch.object(
+        runner,
+        "_run_worker",
+        side_effect=Exception("run_worker raised an exception"),
+    ):
+        # if an exception is raised in the task, we should see it bubble up
+        with pytest.raises(Exception, match="run_worker raised an exception"):
+            async for _ in runner.run(jobs, AsyncMock(return_value=True)):
+                pass
kiln_ai/utils/test_config.py
CHANGED
@@ -1,5 +1,6 @@
 import getpass
 import os
+import threading
 from unittest.mock import patch

 import pytest
@@ -299,3 +300,25 @@ def test_yaml_persistence_structured_data(config_with_yaml, mock_yaml_file):
     with open(mock_yaml_file, "r") as f:
         saved_settings = yaml.safe_load(f)
     assert saved_settings["list_of_objects"] == new_settings
+
+
+def test_update_settings_thread_safety(config_with_yaml):
+    config = config_with_yaml
+
+    exceptions = []
+
+    def update(val):
+        try:
+            config.update_settings({"int_property": val})
+        except Exception as e:
+            exceptions.append(e)
+
+    threads = [threading.Thread(target=update, args=(i,)) for i in range(5)]
+
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+
+    assert not exceptions
+    assert config.int_property in range(5)
|