kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +8 -2
- kiln_ai/adapters/adapter_registry.py +43 -208
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +9 -3
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +2 -4
- kiln_ai/adapters/eval/test_g_eval.py +4 -5
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +761 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +205 -47
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +657 -85
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +251 -1
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +254 -8
- kiln_ai/adapters/test_remote_config.py +651 -58
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +39 -34
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +28 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +40 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +4 -4
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +320 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +105 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +246 -0
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +562 -0
- kiln_ai/tools/tool_registry.py +85 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +24 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +131 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.19.0.dist-info/RECORD +0 -115
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/utils/test_async_job_runner.py
CHANGED

@@ -3,30 +3,51 @@ from unittest.mock import AsyncMock, patch
 
 import pytest
 
-from kiln_ai.utils.async_job_runner import
+from kiln_ai.utils.async_job_runner import (
+    AsyncJobRunner,
+    AsyncJobRunnerObserver,
+    Progress,
+)
+
+
+@pytest.fixture
+def mock_async_run_job_fn_success():
+    return AsyncMock(return_value=True)
+
+
+@pytest.fixture
+def mock_async_run_job_fn_failure():
+    return AsyncMock(return_value=False)
 
 
 @pytest.mark.parametrize("concurrency", [0, -1, -25])
-def test_invalid_concurrency_raises(concurrency):
+def test_invalid_concurrency_raises(concurrency, mock_async_run_job_fn_success):
     with pytest.raises(ValueError):
-        AsyncJobRunner(
+        AsyncJobRunner(
+            concurrency=concurrency,
+            jobs=[],
+            run_job_fn=mock_async_run_job_fn_success,
+        )
 
 
 # Test with and without concurrency
 @pytest.mark.parametrize("concurrency", [1, 25])
 @pytest.mark.asyncio
-async def test_async_job_runner_status_updates(
+async def test_async_job_runner_status_updates(
+    concurrency, mock_async_run_job_fn_success
+):
     job_count = 50
     jobs = [{"id": i} for i in range(job_count)]
 
-    runner = AsyncJobRunner(
-
-
-
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_async_run_job_fn_success,
+    )
 
     # Expect the status updates in order, and 1 for each job
     expected_completed_count = 0
-    async for progress in runner.run(
+    async for progress in runner.run():
         assert progress.complete == expected_completed_count
         expected_completed_count += 1
         assert progress.errors == 0
@@ -36,26 +57,29 @@ async def test_async_job_runner_status_updates(concurrency):
     assert expected_completed_count == job_count + 1
 
     # Verify run_job was called for each job
-    assert
+    assert mock_async_run_job_fn_success.call_count == job_count
 
     # Verify run_job was called with the correct arguments
     for i in range(job_count):
-
+        mock_async_run_job_fn_success.assert_any_await(jobs[i])
 
 
 # Test with and without concurrency
 @pytest.mark.parametrize("concurrency", [1, 25])
 @pytest.mark.asyncio
-async def test_async_job_runner_status_updates_empty_job_list(
+async def test_async_job_runner_status_updates_empty_job_list(
+    concurrency, mock_async_run_job_fn_success
+):
     empty_job_list = []
 
-    runner = AsyncJobRunner(
-
-
-
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=empty_job_list,
+        run_job_fn=mock_async_run_job_fn_success,
+    )
 
     updates: List[Progress] = []
-    async for progress in runner.run(
+    async for progress in runner.run():
         updates.append(progress)
 
     # Verify last status update was complete
@@ -66,23 +90,26 @@ async def test_async_job_runner_status_updates_empty_job_list(concurrency):
     assert updates[0].total == 0
 
     # Verify run_job was called for each job
-    assert
+    assert mock_async_run_job_fn_success.call_count == 0
 
 
 @pytest.mark.parametrize("concurrency", [1, 25])
 @pytest.mark.asyncio
-async def test_async_job_runner_all_failures(
+async def test_async_job_runner_all_failures(
+    concurrency, mock_async_run_job_fn_failure
+):
     job_count = 50
     jobs = [{"id": i} for i in range(job_count)]
 
-    runner = AsyncJobRunner(
-
-
-
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_async_run_job_fn_failure,
+    )
 
     # Expect the status updates in order, and 1 for each job
     expected_error_count = 0
-    async for progress in runner.run(
+    async for progress in runner.run():
         assert progress.complete == 0
         assert progress.errors == expected_error_count
         expected_error_count += 1
@@ -92,11 +119,11 @@ async def test_async_job_runner_all_failures(concurrency):
     assert expected_error_count == job_count + 1
 
     # Verify run_job was called for each job
-    assert
+    assert mock_async_run_job_fn_failure.call_count == job_count
 
     # Verify run_job was called with the correct arguments
     for i in range(job_count):
-
+        mock_async_run_job_fn_failure.assert_any_await(jobs[i])
 
 
 @pytest.mark.parametrize("concurrency", [1, 25])
@@ -108,16 +135,20 @@ async def test_async_job_runner_partial_failures(concurrency):
     # we want to fail on some jobs and succeed on others
     jobs_to_fail = set([0, 2, 4, 6, 8, 20, 25])
 
-    runner = AsyncJobRunner(concurrency=concurrency)
-
     # fake run_job that fails
     mock_run_job_partial_success = AsyncMock(
         # return True for jobs that should succeed
         side_effect=lambda job: job["id"] not in jobs_to_fail
     )
 
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_run_job_partial_success,
+    )
+
     # Expect the status updates in order, and 1 for each job
-    async for progress in runner.run(
+    async for progress in runner.run():
         assert progress.total == job_count
 
     # Verify last status update was complete
@@ -140,8 +171,6 @@ async def test_async_job_runner_partial_raises(concurrency):
     job_count = 50
     jobs = [{"id": i} for i in range(job_count)]
 
-    runner = AsyncJobRunner(concurrency=concurrency)
-
     ids_to_fail = set([10, 25])
 
     def failure_fn(job):
@@ -152,6 +181,12 @@ async def test_async_job_runner_partial_raises(concurrency):
     # fake run_job that fails
     mock_run_job_partial_success = AsyncMock(side_effect=failure_fn)
 
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_run_job_partial_success,
+    )
+
     # generate all the values we expect to see in progress updates
     complete_values_expected = set([i for i in range(job_count - len(ids_to_fail) + 1)])
     errors_values_expected = set([i for i in range(len(ids_to_fail) + 1)])
@@ -164,7 +199,7 @@ async def test_async_job_runner_partial_raises(concurrency):
     errors_values_actual = set()
 
     # Expect the status updates in order, and 1 for each job
-    async for progress in runner.run(
+    async for progress in runner.run():
         updates.append(progress)
         complete_values_actual.add(progress.complete)
         errors_values_actual.add(progress.errors)
@@ -184,9 +219,13 @@ async def test_async_job_runner_partial_raises(concurrency):
 
 @pytest.mark.parametrize("concurrency", [1, 25])
 @pytest.mark.asyncio
-async def test_async_job_runner_cancelled(concurrency):
-    runner = AsyncJobRunner(concurrency=concurrency)
+async def test_async_job_runner_cancelled(concurrency, mock_async_run_job_fn_success):
     jobs = [{"id": i} for i in range(10)]
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_async_run_job_fn_success,
+    )
 
     with patch.object(
         runner,
@@ -195,5 +234,82 @@ async def test_async_job_runner_cancelled(concurrency):
     ):
         # if an exception is raised in the task, we should see it bubble up
         with pytest.raises(Exception, match="run_worker raised an exception"):
-            async for _ in runner.run(
+            async for _ in runner.run():
                 pass
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_observers(concurrency):
+    class MockAsyncJobRunnerObserver(AsyncJobRunnerObserver[dict[str, int]]):
+        def __init__(self):
+            self.on_error_calls = []
+            self.on_success_calls = []
+
+        async def on_error(self, job: dict[str, int], error: Exception):
+            self.on_error_calls.append((job, error))
+
+        async def on_success(self, job: dict[str, int]):
+            self.on_success_calls.append(job)
+
+    mock_observer_a = MockAsyncJobRunnerObserver()
+    mock_observer_b = MockAsyncJobRunnerObserver()
+
+    jobs = [{"id": i} for i in range(10)]
+
+    async def run_job_fn(job: dict[str, int]) -> bool:
+        # we simulate the job 5 and 6 crashing, which should trigger the observers on_error handlers
+        if job["id"] == 5 or job["id"] == 6:
+            raise ValueError(f"job failed unexpectedly {job['id']}")
+        return True
+
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=run_job_fn,
+        observers=[mock_observer_a, mock_observer_b],
+    )
+
+    async for _ in runner.run():
+        pass
+
+    assert len(mock_observer_a.on_error_calls) == 2
+    assert len(mock_observer_b.on_error_calls) == 2
+
+    # not necessarily in order, but we should have seen both 5 and 6
+    assert len(mock_observer_a.on_success_calls) == 8
+    assert len(mock_observer_b.on_success_calls) == 8
+
+    # check that 5 and 6 are in the error calls
+    for job_idx in [5, 6]:
+        # check that 5 and 6 are in the error calls for both observers
+        assert any(call[0] == jobs[job_idx] for call in mock_observer_a.on_error_calls)
+        assert any(call[0] == jobs[job_idx] for call in mock_observer_b.on_error_calls)
+
+        # check that the error is the correct exception
+        assert (
+            str(mock_observer_a.on_error_calls[0][1]) == "job failed unexpectedly 5"
+            or str(mock_observer_a.on_error_calls[1][1]) == "job failed unexpectedly 6"
+        )
+        assert (
+            str(mock_observer_b.on_error_calls[0][1]) == "job failed unexpectedly 5"
+            or str(mock_observer_b.on_error_calls[1][1]) == "job failed unexpectedly 6"
+        )
+
+        # check that 5 and 6 are not in the success calls for both observers
+        assert not any(
+            call == jobs[job_idx] for call in mock_observer_a.on_success_calls
+        )
+        assert not any(
+            call == jobs[job_idx] for call in mock_observer_b.on_success_calls
+        )
+
+    # check that the other jobs are in the success calls for both observers
+    for job_idx in range(10):
+        if job_idx not in [5, 6]:
+            assert any(
+                call == jobs[job_idx] for call in mock_observer_a.on_success_calls
+            )
+            assert any(
+                call == jobs[job_idx] for call in mock_observer_b.on_success_calls
+            )
kiln_ai/utils/test_config.py
CHANGED
@@ -6,7 +6,7 @@ from unittest.mock import patch
 import pytest
 import yaml
 
-from kiln_ai.utils.config import Config, ConfigProperty, _get_user_id
+from kiln_ai.utils.config import MCP_SECRETS_KEY, Config, ConfigProperty, _get_user_id
 
 
 @pytest.fixture
@@ -322,3 +322,140 @@ def test_update_settings_thread_safety(config_with_yaml):
 
     assert not exceptions
     assert config.int_property in range(5)
+
+
+def test_mcp_secrets_property():
+    """Test mcp_secrets configuration property"""
+    config = Config.shared()
+
+    # Initially should be None/empty
+    assert config.mcp_secrets is None
+
+    # Set some secrets
+    secrets = {
+        "server1::Authorization": "Bearer token123",
+        "server1::X-API-Key": "api-key-456",
+        "server2::Token": "secret-token",
+    }
+    config.mcp_secrets = secrets
+
+    # Verify they are stored correctly
+    assert config.mcp_secrets == secrets
+    assert config.mcp_secrets["server1::Authorization"] == "Bearer token123"
+    assert config.mcp_secrets["server1::X-API-Key"] == "api-key-456"
+    assert config.mcp_secrets["server2::Token"] == "secret-token"
+
+
+def test_mcp_secrets_sensitive_hiding():
+    """Test that mcp_secrets are hidden when hide_sensitive=True"""
+    config = Config.shared()
+
+    # Set some secrets
+    secrets = {
+        "server1::Authorization": "Bearer secret123",
+        "server2::X-API-Key": "secret-key",
+    }
+    config.mcp_secrets = secrets
+
+    # Test without hiding sensitive data
+    visible_settings = config.settings(hide_sensitive=False)
+    assert MCP_SECRETS_KEY in visible_settings
+    assert visible_settings[MCP_SECRETS_KEY] == secrets
+
+    # Test with hiding sensitive data
+    hidden_settings = config.settings(hide_sensitive=True)
+    assert MCP_SECRETS_KEY in hidden_settings
+    assert hidden_settings[MCP_SECRETS_KEY] == "[hidden]"
+
+
+def test_mcp_secrets_persistence(mock_yaml_file):
+    """Test that mcp_secrets are persisted to YAML correctly"""
+    with patch(
+        "kiln_ai.utils.config.Config.settings_path",
+        return_value=mock_yaml_file,
+    ):
+        config = Config()
+
+        # Set some secrets
+        secrets = {
+            "server1::Authorization": "Bearer persist123",
+            "server2::Token": "persist-token",
+        }
+        config.mcp_secrets = secrets
+
+        # Check that the value was saved to the YAML file
+        with open(mock_yaml_file, "r") as f:
+            saved_settings = yaml.safe_load(f)
+            assert saved_settings[MCP_SECRETS_KEY] == secrets
+
+        # Create a new config instance to test loading from YAML
+        new_config = Config()
+
+        # Check that the value is loaded from YAML
+        assert new_config.mcp_secrets == secrets
+
+
+def test_mcp_secrets_get_value():
+    """Test that mcp_secrets can be retrieved using get_value method"""
+    config = Config.shared()
+
+    # Initially should be None
+    assert config.get_value(MCP_SECRETS_KEY) is None
+
+    # Set some secrets
+    secrets = {"server::key": "value"}
+    config.mcp_secrets = secrets
+
+    # Should be retrievable via get_value
+    assert config.get_value(MCP_SECRETS_KEY) == secrets
+
+
+def test_mcp_secrets_update_settings():
+    """Test updating mcp_secrets using update_settings method"""
+    config = Config.shared()
+
+    # Set initial secrets
+    initial_secrets = {"server1::key1": "value1"}
+    config.update_settings({MCP_SECRETS_KEY: initial_secrets})
+    assert config.mcp_secrets == initial_secrets
+
+    # Update with new secrets (should replace, not merge)
+    new_secrets = {
+        "server1::key1": "updated_value1",
+        "server2::key2": "value2",
+    }
+    config.update_settings({MCP_SECRETS_KEY: new_secrets})
+    assert config.mcp_secrets == new_secrets
+    assert config.mcp_secrets["server1::key1"] == "updated_value1"
+    assert config.mcp_secrets["server2::key2"] == "value2"
+
+
+def test_mcp_secrets_empty_dict():
+    """Test mcp_secrets with empty dict"""
+    config = Config.shared()
+
+    # Set empty dict
+    config.mcp_secrets = {}
+    assert config.mcp_secrets == {}
+
+    # Should still be dict type, not None
+    assert isinstance(config.mcp_secrets, dict)
+
+
+def test_mcp_secrets_type_validation():
+    """Test that mcp_secrets enforces dict[str, str] type"""
+    config = Config.shared()
+
+    # Valid dict[str, str]
+    valid_secrets = {"server::key": "value"}
+    config.mcp_secrets = valid_secrets
+    assert config.mcp_secrets == valid_secrets
+
+    # The config system applies type conversion when retrieving values
+    mixed_types = {"server::key": 123}  # int value
+    config.mcp_secrets = mixed_types
+    # The type conversion happens when the value is retrieved, not when set
+    # So the underlying storage may preserve the original type
+    assert config.mcp_secrets == mixed_types or config.mcp_secrets == {
+        "server::key": "123"
+    }
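These test_config.py additions point to a new mcp_secrets property on Config and an MCP_SECRETS_KEY settings key, with the whole block masked when settings(hide_sensitive=True) is used. A small usage sketch based only on the behavior the tests above exercise (names and semantics are taken from those tests, not independently verified):

# Minimal sketch of the mcp_secrets config surface, assuming it behaves as the tests above show.
from kiln_ai.utils.config import MCP_SECRETS_KEY, Config

config = Config.shared()

# Secrets are stored as a flat dict keyed by "<server>::<header>".
config.mcp_secrets = {
    "server1::Authorization": "Bearer token123",
    "server1::X-API-Key": "api-key-456",
}

# Readable directly, or via the generic accessor.
assert config.get_value(MCP_SECRETS_KEY) == config.mcp_secrets

# settings(hide_sensitive=True) masks the whole block as "[hidden]".
print(config.settings(hide_sensitive=True)[MCP_SECRETS_KEY])  # "[hidden]"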
kiln_ai/utils/test_env.py
ADDED

@@ -0,0 +1,142 @@
+import os
+
+import pytest
+
+from kiln_ai.utils.env import temporary_env
+
+
+class TestTemporaryEnv:
+    def test_set_new_env_var(self):
+        """Test setting a new environment variable that doesn't exist."""
+        var_name = "TEST_NEW_VAR"
+        test_value = "test_value"
+
+        # Ensure the variable doesn't exist initially
+        assert var_name not in os.environ
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        # Verify it's removed after context
+        assert var_name not in os.environ
+
+    def test_modify_existing_env_var(self):
+        """Test modifying an existing environment variable."""
+        var_name = "TEST_EXISTING_VAR"
+        original_value = "original_value"
+        new_value = "new_value"
+
+        # Set up initial state
+        os.environ[var_name] = original_value
+
+        with temporary_env(var_name, new_value):
+            assert os.environ[var_name] == new_value
+
+        # Verify original value is restored
+        assert os.environ[var_name] == original_value
+
+    def test_restore_nonexistent_var(self):
+        """Test that a variable that didn't exist is properly removed."""
+        var_name = "TEST_NONEXISTENT_VAR"
+        test_value = "test_value"
+
+        # Ensure the variable doesn't exist initially
+        if var_name in os.environ:
+            del os.environ[var_name]
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        # Verify it's removed after context
+        assert var_name not in os.environ
+
+    def test_exception_handling(self):
+        """Test that environment is restored even when an exception occurs."""
+        var_name = "TEST_EXCEPTION_VAR"
+        original_value = "original_value"
+        new_value = "new_value"
+
+        # Set up initial state
+        os.environ[var_name] = original_value
+
+        with pytest.raises(ValueError):
+            with temporary_env(var_name, new_value):
+                assert os.environ[var_name] == new_value
+                raise ValueError("Test exception")
+
+        # Verify original value is restored even after exception
+        assert os.environ[var_name] == original_value
+
+    def test_exception_handling_new_var(self):
+        """Test that new variable is removed even when an exception occurs."""
+        var_name = "TEST_EXCEPTION_NEW_VAR"
+        test_value = "test_value"
+
+        # Ensure the variable doesn't exist initially
+        if var_name in os.environ:
+            del os.environ[var_name]
+
+        with pytest.raises(RuntimeError):
+            with temporary_env(var_name, test_value):
+                assert os.environ[var_name] == test_value
+                raise RuntimeError("Test exception")
+
+        # Verify variable is removed even after exception
+        assert var_name not in os.environ
+
+    def test_nested_context_managers(self):
+        """Test using multiple temporary_env context managers."""
+        var1 = "TEST_NESTED_VAR1"
+        var2 = "TEST_NESTED_VAR2"
+        value1 = "value1"
+        value2 = "value2"
+
+        # Set up initial state
+        os.environ[var1] = "original1"
+        if var2 in os.environ:
+            del os.environ[var2]
+
+        with temporary_env(var1, value1):
+            assert os.environ[var1] == value1
+
+            with temporary_env(var2, value2):
+                assert os.environ[var1] == value1
+                assert os.environ[var2] == value2
+
+            # Inner context should be cleaned up
+            assert var2 not in os.environ
+            assert os.environ[var1] == value1
+
+        # Both contexts should be cleaned up
+        assert os.environ[var1] == "original1"
+        assert var2 not in os.environ
+
+    def test_empty_string_value(self):
+        """Test setting an empty string value."""
+        var_name = "TEST_EMPTY_VAR"
+        test_value = ""
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        assert var_name not in os.environ
+
+    def test_none_value_handling(self):
+        """Test that None values are handled properly."""
+        var_name = "TEST_NONE_VAR"
+        test_value = "test_value"
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        assert var_name not in os.environ
+
+    def test_unicode_value(self):
+        """Test setting unicode values."""
+        var_name = "TEST_UNICODE_VAR"
+        test_value = "测试值 🚀"
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        assert var_name not in os.environ