kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -3,30 +3,51 @@ from unittest.mock import AsyncMock, patch
|
|
|
3
3
|
|
|
4
4
|
import pytest
|
|
5
5
|
|
|
6
|
-
from kiln_ai.utils.async_job_runner import
|
|
6
|
+
from kiln_ai.utils.async_job_runner import (
|
|
7
|
+
AsyncJobRunner,
|
|
8
|
+
AsyncJobRunnerObserver,
|
|
9
|
+
Progress,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@pytest.fixture
|
|
14
|
+
def mock_async_run_job_fn_success():
|
|
15
|
+
return AsyncMock(return_value=True)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@pytest.fixture
|
|
19
|
+
def mock_async_run_job_fn_failure():
|
|
20
|
+
return AsyncMock(return_value=False)
|
|
7
21
|
|
|
8
22
|
|
|
9
23
|
@pytest.mark.parametrize("concurrency", [0, -1, -25])
|
|
10
|
-
def test_invalid_concurrency_raises(concurrency):
|
|
24
|
+
def test_invalid_concurrency_raises(concurrency, mock_async_run_job_fn_success):
|
|
11
25
|
with pytest.raises(ValueError):
|
|
12
|
-
AsyncJobRunner(
|
|
26
|
+
AsyncJobRunner(
|
|
27
|
+
concurrency=concurrency,
|
|
28
|
+
jobs=[],
|
|
29
|
+
run_job_fn=mock_async_run_job_fn_success,
|
|
30
|
+
)
|
|
13
31
|
|
|
14
32
|
|
|
15
33
|
# Test with and without concurrency
|
|
16
34
|
@pytest.mark.parametrize("concurrency", [1, 25])
|
|
17
35
|
@pytest.mark.asyncio
|
|
18
|
-
async def test_async_job_runner_status_updates(
|
|
36
|
+
async def test_async_job_runner_status_updates(
|
|
37
|
+
concurrency, mock_async_run_job_fn_success
|
|
38
|
+
):
|
|
19
39
|
job_count = 50
|
|
20
40
|
jobs = [{"id": i} for i in range(job_count)]
|
|
21
41
|
|
|
22
|
-
runner = AsyncJobRunner(
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
42
|
+
runner = AsyncJobRunner(
|
|
43
|
+
concurrency=concurrency,
|
|
44
|
+
jobs=jobs,
|
|
45
|
+
run_job_fn=mock_async_run_job_fn_success,
|
|
46
|
+
)
|
|
26
47
|
|
|
27
48
|
# Expect the status updates in order, and 1 for each job
|
|
28
49
|
expected_completed_count = 0
|
|
29
|
-
async for progress in runner.run(
|
|
50
|
+
async for progress in runner.run():
|
|
30
51
|
assert progress.complete == expected_completed_count
|
|
31
52
|
expected_completed_count += 1
|
|
32
53
|
assert progress.errors == 0
|
|
@@ -36,26 +57,29 @@ async def test_async_job_runner_status_updates(concurrency):
|
|
|
36
57
|
assert expected_completed_count == job_count + 1
|
|
37
58
|
|
|
38
59
|
# Verify run_job was called for each job
|
|
39
|
-
assert
|
|
60
|
+
assert mock_async_run_job_fn_success.call_count == job_count
|
|
40
61
|
|
|
41
62
|
# Verify run_job was called with the correct arguments
|
|
42
63
|
for i in range(job_count):
|
|
43
|
-
|
|
64
|
+
mock_async_run_job_fn_success.assert_any_await(jobs[i])
|
|
44
65
|
|
|
45
66
|
|
|
46
67
|
# Test with and without concurrency
|
|
47
68
|
@pytest.mark.parametrize("concurrency", [1, 25])
|
|
48
69
|
@pytest.mark.asyncio
|
|
49
|
-
async def test_async_job_runner_status_updates_empty_job_list(
|
|
70
|
+
async def test_async_job_runner_status_updates_empty_job_list(
|
|
71
|
+
concurrency, mock_async_run_job_fn_success
|
|
72
|
+
):
|
|
50
73
|
empty_job_list = []
|
|
51
74
|
|
|
52
|
-
runner = AsyncJobRunner(
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
75
|
+
runner = AsyncJobRunner(
|
|
76
|
+
concurrency=concurrency,
|
|
77
|
+
jobs=empty_job_list,
|
|
78
|
+
run_job_fn=mock_async_run_job_fn_success,
|
|
79
|
+
)
|
|
56
80
|
|
|
57
81
|
updates: List[Progress] = []
|
|
58
|
-
async for progress in runner.run(
|
|
82
|
+
async for progress in runner.run():
|
|
59
83
|
updates.append(progress)
|
|
60
84
|
|
|
61
85
|
# Verify last status update was complete
|
|
@@ -66,23 +90,26 @@ async def test_async_job_runner_status_updates_empty_job_list(concurrency):
|
|
|
66
90
|
assert updates[0].total == 0
|
|
67
91
|
|
|
68
92
|
# Verify run_job was called for each job
|
|
69
|
-
assert
|
|
93
|
+
assert mock_async_run_job_fn_success.call_count == 0
|
|
70
94
|
|
|
71
95
|
|
|
72
96
|
@pytest.mark.parametrize("concurrency", [1, 25])
|
|
73
97
|
@pytest.mark.asyncio
|
|
74
|
-
async def test_async_job_runner_all_failures(
|
|
98
|
+
async def test_async_job_runner_all_failures(
|
|
99
|
+
concurrency, mock_async_run_job_fn_failure
|
|
100
|
+
):
|
|
75
101
|
job_count = 50
|
|
76
102
|
jobs = [{"id": i} for i in range(job_count)]
|
|
77
103
|
|
|
78
|
-
runner = AsyncJobRunner(
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
104
|
+
runner = AsyncJobRunner(
|
|
105
|
+
concurrency=concurrency,
|
|
106
|
+
jobs=jobs,
|
|
107
|
+
run_job_fn=mock_async_run_job_fn_failure,
|
|
108
|
+
)
|
|
82
109
|
|
|
83
110
|
# Expect the status updates in order, and 1 for each job
|
|
84
111
|
expected_error_count = 0
|
|
85
|
-
async for progress in runner.run(
|
|
112
|
+
async for progress in runner.run():
|
|
86
113
|
assert progress.complete == 0
|
|
87
114
|
assert progress.errors == expected_error_count
|
|
88
115
|
expected_error_count += 1
|
|
@@ -92,11 +119,11 @@ async def test_async_job_runner_all_failures(concurrency):
|
|
|
92
119
|
assert expected_error_count == job_count + 1
|
|
93
120
|
|
|
94
121
|
# Verify run_job was called for each job
|
|
95
|
-
assert
|
|
122
|
+
assert mock_async_run_job_fn_failure.call_count == job_count
|
|
96
123
|
|
|
97
124
|
# Verify run_job was called with the correct arguments
|
|
98
125
|
for i in range(job_count):
|
|
99
|
-
|
|
126
|
+
mock_async_run_job_fn_failure.assert_any_await(jobs[i])
|
|
100
127
|
|
|
101
128
|
|
|
102
129
|
@pytest.mark.parametrize("concurrency", [1, 25])
|
|
@@ -108,16 +135,20 @@ async def test_async_job_runner_partial_failures(concurrency):
|
|
|
108
135
|
# we want to fail on some jobs and succeed on others
|
|
109
136
|
jobs_to_fail = set([0, 2, 4, 6, 8, 20, 25])
|
|
110
137
|
|
|
111
|
-
runner = AsyncJobRunner(concurrency=concurrency)
|
|
112
|
-
|
|
113
138
|
# fake run_job that fails
|
|
114
139
|
mock_run_job_partial_success = AsyncMock(
|
|
115
140
|
# return True for jobs that should succeed
|
|
116
141
|
side_effect=lambda job: job["id"] not in jobs_to_fail
|
|
117
142
|
)
|
|
118
143
|
|
|
144
|
+
runner = AsyncJobRunner(
|
|
145
|
+
concurrency=concurrency,
|
|
146
|
+
jobs=jobs,
|
|
147
|
+
run_job_fn=mock_run_job_partial_success,
|
|
148
|
+
)
|
|
149
|
+
|
|
119
150
|
# Expect the status updates in order, and 1 for each job
|
|
120
|
-
async for progress in runner.run(
|
|
151
|
+
async for progress in runner.run():
|
|
121
152
|
assert progress.total == job_count
|
|
122
153
|
|
|
123
154
|
# Verify last status update was complete
|
|
@@ -140,8 +171,6 @@ async def test_async_job_runner_partial_raises(concurrency):
|
|
|
140
171
|
job_count = 50
|
|
141
172
|
jobs = [{"id": i} for i in range(job_count)]
|
|
142
173
|
|
|
143
|
-
runner = AsyncJobRunner(concurrency=concurrency)
|
|
144
|
-
|
|
145
174
|
ids_to_fail = set([10, 25])
|
|
146
175
|
|
|
147
176
|
def failure_fn(job):
|
|
@@ -152,6 +181,12 @@ async def test_async_job_runner_partial_raises(concurrency):
|
|
|
152
181
|
# fake run_job that fails
|
|
153
182
|
mock_run_job_partial_success = AsyncMock(side_effect=failure_fn)
|
|
154
183
|
|
|
184
|
+
runner = AsyncJobRunner(
|
|
185
|
+
concurrency=concurrency,
|
|
186
|
+
jobs=jobs,
|
|
187
|
+
run_job_fn=mock_run_job_partial_success,
|
|
188
|
+
)
|
|
189
|
+
|
|
155
190
|
# generate all the values we expect to see in progress updates
|
|
156
191
|
complete_values_expected = set([i for i in range(job_count - len(ids_to_fail) + 1)])
|
|
157
192
|
errors_values_expected = set([i for i in range(len(ids_to_fail) + 1)])
|
|
@@ -164,7 +199,7 @@ async def test_async_job_runner_partial_raises(concurrency):
|
|
|
164
199
|
errors_values_actual = set()
|
|
165
200
|
|
|
166
201
|
# Expect the status updates in order, and 1 for each job
|
|
167
|
-
async for progress in runner.run(
|
|
202
|
+
async for progress in runner.run():
|
|
168
203
|
updates.append(progress)
|
|
169
204
|
complete_values_actual.add(progress.complete)
|
|
170
205
|
errors_values_actual.add(progress.errors)
|
|
@@ -184,9 +219,13 @@ async def test_async_job_runner_partial_raises(concurrency):
|
|
|
184
219
|
|
|
185
220
|
@pytest.mark.parametrize("concurrency", [1, 25])
|
|
186
221
|
@pytest.mark.asyncio
|
|
187
|
-
async def test_async_job_runner_cancelled(concurrency):
|
|
188
|
-
runner = AsyncJobRunner(concurrency=concurrency)
|
|
222
|
+
async def test_async_job_runner_cancelled(concurrency, mock_async_run_job_fn_success):
|
|
189
223
|
jobs = [{"id": i} for i in range(10)]
|
|
224
|
+
runner = AsyncJobRunner(
|
|
225
|
+
concurrency=concurrency,
|
|
226
|
+
jobs=jobs,
|
|
227
|
+
run_job_fn=mock_async_run_job_fn_success,
|
|
228
|
+
)
|
|
190
229
|
|
|
191
230
|
with patch.object(
|
|
192
231
|
runner,
|
|
@@ -195,5 +234,82 @@ async def test_async_job_runner_cancelled(concurrency):
|
|
|
195
234
|
):
|
|
196
235
|
# if an exception is raised in the task, we should see it bubble up
|
|
197
236
|
with pytest.raises(Exception, match="run_worker raised an exception"):
|
|
198
|
-
async for _ in runner.run(
|
|
237
|
+
async for _ in runner.run():
|
|
199
238
|
pass
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
@pytest.mark.parametrize("concurrency", [1, 25])
|
|
242
|
+
@pytest.mark.asyncio
|
|
243
|
+
async def test_async_job_runner_observers(concurrency):
|
|
244
|
+
class MockAsyncJobRunnerObserver(AsyncJobRunnerObserver[dict[str, int]]):
|
|
245
|
+
def __init__(self):
|
|
246
|
+
self.on_error_calls = []
|
|
247
|
+
self.on_success_calls = []
|
|
248
|
+
|
|
249
|
+
async def on_error(self, job: dict[str, int], error: Exception):
|
|
250
|
+
self.on_error_calls.append((job, error))
|
|
251
|
+
|
|
252
|
+
async def on_success(self, job: dict[str, int]):
|
|
253
|
+
self.on_success_calls.append(job)
|
|
254
|
+
|
|
255
|
+
mock_observer_a = MockAsyncJobRunnerObserver()
|
|
256
|
+
mock_observer_b = MockAsyncJobRunnerObserver()
|
|
257
|
+
|
|
258
|
+
jobs = [{"id": i} for i in range(10)]
|
|
259
|
+
|
|
260
|
+
async def run_job_fn(job: dict[str, int]) -> bool:
|
|
261
|
+
# we simulate the job 5 and 6 crashing, which should trigger the observers on_error handlers
|
|
262
|
+
if job["id"] == 5 or job["id"] == 6:
|
|
263
|
+
raise ValueError(f"job failed unexpectedly {job['id']}")
|
|
264
|
+
return True
|
|
265
|
+
|
|
266
|
+
runner = AsyncJobRunner(
|
|
267
|
+
concurrency=concurrency,
|
|
268
|
+
jobs=jobs,
|
|
269
|
+
run_job_fn=run_job_fn,
|
|
270
|
+
observers=[mock_observer_a, mock_observer_b],
|
|
271
|
+
)
|
|
272
|
+
|
|
273
|
+
async for _ in runner.run():
|
|
274
|
+
pass
|
|
275
|
+
|
|
276
|
+
assert len(mock_observer_a.on_error_calls) == 2
|
|
277
|
+
assert len(mock_observer_b.on_error_calls) == 2
|
|
278
|
+
|
|
279
|
+
# not necessarily in order, but we should have seen both 5 and 6
|
|
280
|
+
assert len(mock_observer_a.on_success_calls) == 8
|
|
281
|
+
assert len(mock_observer_b.on_success_calls) == 8
|
|
282
|
+
|
|
283
|
+
# check that 5 and 6 are in the error calls
|
|
284
|
+
for job_idx in [5, 6]:
|
|
285
|
+
# check that 5 and 6 are in the error calls for both observers
|
|
286
|
+
assert any(call[0] == jobs[job_idx] for call in mock_observer_a.on_error_calls)
|
|
287
|
+
assert any(call[0] == jobs[job_idx] for call in mock_observer_b.on_error_calls)
|
|
288
|
+
|
|
289
|
+
# check that the error is the correct exception
|
|
290
|
+
assert (
|
|
291
|
+
str(mock_observer_a.on_error_calls[0][1]) == "job failed unexpectedly 5"
|
|
292
|
+
or str(mock_observer_a.on_error_calls[1][1]) == "job failed unexpectedly 6"
|
|
293
|
+
)
|
|
294
|
+
assert (
|
|
295
|
+
str(mock_observer_b.on_error_calls[0][1]) == "job failed unexpectedly 5"
|
|
296
|
+
or str(mock_observer_b.on_error_calls[1][1]) == "job failed unexpectedly 6"
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
# check that 5 and 6 are not in the success calls for both observers
|
|
300
|
+
assert not any(
|
|
301
|
+
call == jobs[job_idx] for call in mock_observer_a.on_success_calls
|
|
302
|
+
)
|
|
303
|
+
assert not any(
|
|
304
|
+
call == jobs[job_idx] for call in mock_observer_b.on_success_calls
|
|
305
|
+
)
|
|
306
|
+
|
|
307
|
+
# check that the other jobs are in the success calls for both observers
|
|
308
|
+
for job_idx in range(10):
|
|
309
|
+
if job_idx not in [5, 6]:
|
|
310
|
+
assert any(
|
|
311
|
+
call == jobs[job_idx] for call in mock_observer_a.on_success_calls
|
|
312
|
+
)
|
|
313
|
+
assert any(
|
|
314
|
+
call == jobs[job_idx] for call in mock_observer_b.on_success_calls
|
|
315
|
+
)
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from kiln_ai.utils.env import temporary_env
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TestTemporaryEnv:
|
|
9
|
+
def test_set_new_env_var(self):
|
|
10
|
+
"""Test setting a new environment variable that doesn't exist."""
|
|
11
|
+
var_name = "TEST_NEW_VAR"
|
|
12
|
+
test_value = "test_value"
|
|
13
|
+
|
|
14
|
+
# Ensure the variable doesn't exist initially
|
|
15
|
+
assert var_name not in os.environ
|
|
16
|
+
|
|
17
|
+
with temporary_env(var_name, test_value):
|
|
18
|
+
assert os.environ[var_name] == test_value
|
|
19
|
+
|
|
20
|
+
# Verify it's removed after context
|
|
21
|
+
assert var_name not in os.environ
|
|
22
|
+
|
|
23
|
+
def test_modify_existing_env_var(self):
|
|
24
|
+
"""Test modifying an existing environment variable."""
|
|
25
|
+
var_name = "TEST_EXISTING_VAR"
|
|
26
|
+
original_value = "original_value"
|
|
27
|
+
new_value = "new_value"
|
|
28
|
+
|
|
29
|
+
# Set up initial state
|
|
30
|
+
os.environ[var_name] = original_value
|
|
31
|
+
|
|
32
|
+
with temporary_env(var_name, new_value):
|
|
33
|
+
assert os.environ[var_name] == new_value
|
|
34
|
+
|
|
35
|
+
# Verify original value is restored
|
|
36
|
+
assert os.environ[var_name] == original_value
|
|
37
|
+
|
|
38
|
+
def test_restore_nonexistent_var(self):
|
|
39
|
+
"""Test that a variable that didn't exist is properly removed."""
|
|
40
|
+
var_name = "TEST_NONEXISTENT_VAR"
|
|
41
|
+
test_value = "test_value"
|
|
42
|
+
|
|
43
|
+
# Ensure the variable doesn't exist initially
|
|
44
|
+
if var_name in os.environ:
|
|
45
|
+
del os.environ[var_name]
|
|
46
|
+
|
|
47
|
+
with temporary_env(var_name, test_value):
|
|
48
|
+
assert os.environ[var_name] == test_value
|
|
49
|
+
|
|
50
|
+
# Verify it's removed after context
|
|
51
|
+
assert var_name not in os.environ
|
|
52
|
+
|
|
53
|
+
def test_exception_handling(self):
|
|
54
|
+
"""Test that environment is restored even when an exception occurs."""
|
|
55
|
+
var_name = "TEST_EXCEPTION_VAR"
|
|
56
|
+
original_value = "original_value"
|
|
57
|
+
new_value = "new_value"
|
|
58
|
+
|
|
59
|
+
# Set up initial state
|
|
60
|
+
os.environ[var_name] = original_value
|
|
61
|
+
|
|
62
|
+
with pytest.raises(ValueError):
|
|
63
|
+
with temporary_env(var_name, new_value):
|
|
64
|
+
assert os.environ[var_name] == new_value
|
|
65
|
+
raise ValueError("Test exception")
|
|
66
|
+
|
|
67
|
+
# Verify original value is restored even after exception
|
|
68
|
+
assert os.environ[var_name] == original_value
|
|
69
|
+
|
|
70
|
+
def test_exception_handling_new_var(self):
|
|
71
|
+
"""Test that new variable is removed even when an exception occurs."""
|
|
72
|
+
var_name = "TEST_EXCEPTION_NEW_VAR"
|
|
73
|
+
test_value = "test_value"
|
|
74
|
+
|
|
75
|
+
# Ensure the variable doesn't exist initially
|
|
76
|
+
if var_name in os.environ:
|
|
77
|
+
del os.environ[var_name]
|
|
78
|
+
|
|
79
|
+
with pytest.raises(RuntimeError):
|
|
80
|
+
with temporary_env(var_name, test_value):
|
|
81
|
+
assert os.environ[var_name] == test_value
|
|
82
|
+
raise RuntimeError("Test exception")
|
|
83
|
+
|
|
84
|
+
# Verify variable is removed even after exception
|
|
85
|
+
assert var_name not in os.environ
|
|
86
|
+
|
|
87
|
+
def test_nested_context_managers(self):
|
|
88
|
+
"""Test using multiple temporary_env context managers."""
|
|
89
|
+
var1 = "TEST_NESTED_VAR1"
|
|
90
|
+
var2 = "TEST_NESTED_VAR2"
|
|
91
|
+
value1 = "value1"
|
|
92
|
+
value2 = "value2"
|
|
93
|
+
|
|
94
|
+
# Set up initial state
|
|
95
|
+
os.environ[var1] = "original1"
|
|
96
|
+
if var2 in os.environ:
|
|
97
|
+
del os.environ[var2]
|
|
98
|
+
|
|
99
|
+
with temporary_env(var1, value1):
|
|
100
|
+
assert os.environ[var1] == value1
|
|
101
|
+
|
|
102
|
+
with temporary_env(var2, value2):
|
|
103
|
+
assert os.environ[var1] == value1
|
|
104
|
+
assert os.environ[var2] == value2
|
|
105
|
+
|
|
106
|
+
# Inner context should be cleaned up
|
|
107
|
+
assert var2 not in os.environ
|
|
108
|
+
assert os.environ[var1] == value1
|
|
109
|
+
|
|
110
|
+
# Both contexts should be cleaned up
|
|
111
|
+
assert os.environ[var1] == "original1"
|
|
112
|
+
assert var2 not in os.environ
|
|
113
|
+
|
|
114
|
+
def test_empty_string_value(self):
|
|
115
|
+
"""Test setting an empty string value."""
|
|
116
|
+
var_name = "TEST_EMPTY_VAR"
|
|
117
|
+
test_value = ""
|
|
118
|
+
|
|
119
|
+
with temporary_env(var_name, test_value):
|
|
120
|
+
assert os.environ[var_name] == test_value
|
|
121
|
+
|
|
122
|
+
assert var_name not in os.environ
|
|
123
|
+
|
|
124
|
+
def test_none_value_handling(self):
|
|
125
|
+
"""Test that None values are handled properly."""
|
|
126
|
+
var_name = "TEST_NONE_VAR"
|
|
127
|
+
test_value = "test_value"
|
|
128
|
+
|
|
129
|
+
with temporary_env(var_name, test_value):
|
|
130
|
+
assert os.environ[var_name] == test_value
|
|
131
|
+
|
|
132
|
+
assert var_name not in os.environ
|
|
133
|
+
|
|
134
|
+
def test_unicode_value(self):
|
|
135
|
+
"""Test setting unicode values."""
|
|
136
|
+
var_name = "TEST_UNICODE_VAR"
|
|
137
|
+
test_value = "测试值 🚀"
|
|
138
|
+
|
|
139
|
+
with temporary_env(var_name, test_value):
|
|
140
|
+
assert os.environ[var_name] == test_value
|
|
141
|
+
|
|
142
|
+
assert var_name not in os.environ
|