kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (158)
  1. kiln_ai/adapters/__init__.py +8 -2
  2. kiln_ai/adapters/adapter_registry.py +43 -208
  3. kiln_ai/adapters/chat/chat_formatter.py +8 -12
  4. kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
  5. kiln_ai/adapters/chunkers/__init__.py +13 -0
  6. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  7. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  8. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  9. kiln_ai/adapters/chunkers/helpers.py +23 -0
  10. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  11. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  12. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  13. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  14. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  15. kiln_ai/adapters/docker_model_runner_tools.py +119 -0
  16. kiln_ai/adapters/embedding/__init__.py +0 -0
  17. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  18. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  19. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  20. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  21. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  22. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  23. kiln_ai/adapters/eval/base_eval.py +2 -2
  24. kiln_ai/adapters/eval/eval_runner.py +9 -3
  25. kiln_ai/adapters/eval/g_eval.py +2 -2
  26. kiln_ai/adapters/eval/test_base_eval.py +2 -4
  27. kiln_ai/adapters/eval/test_g_eval.py +4 -5
  28. kiln_ai/adapters/extractors/__init__.py +18 -0
  29. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  30. kiln_ai/adapters/extractors/encoding.py +20 -0
  31. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  32. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  33. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  34. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  35. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  36. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  37. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  38. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  39. kiln_ai/adapters/fine_tune/__init__.py +1 -1
  40. kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
  41. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  42. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
  43. kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
  44. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  45. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  46. kiln_ai/adapters/ml_model_list.py +761 -37
  47. kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
  48. kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
  49. kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
  50. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
  51. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
  52. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
  53. kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
  54. kiln_ai/adapters/ollama_tools.py +69 -12
  55. kiln_ai/adapters/parsers/__init__.py +1 -1
  56. kiln_ai/adapters/provider_tools.py +205 -47
  57. kiln_ai/adapters/rag/deduplication.py +49 -0
  58. kiln_ai/adapters/rag/progress.py +252 -0
  59. kiln_ai/adapters/rag/rag_runners.py +844 -0
  60. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  61. kiln_ai/adapters/rag/test_progress.py +785 -0
  62. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  63. kiln_ai/adapters/remote_config.py +80 -8
  64. kiln_ai/adapters/repair/test_repair_task.py +12 -9
  65. kiln_ai/adapters/run_output.py +3 -0
  66. kiln_ai/adapters/test_adapter_registry.py +657 -85
  67. kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
  68. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  69. kiln_ai/adapters/test_ml_model_list.py +251 -1
  70. kiln_ai/adapters/test_ollama_tools.py +340 -1
  71. kiln_ai/adapters/test_prompt_adaptors.py +13 -6
  72. kiln_ai/adapters/test_prompt_builders.py +1 -1
  73. kiln_ai/adapters/test_provider_tools.py +254 -8
  74. kiln_ai/adapters/test_remote_config.py +651 -58
  75. kiln_ai/adapters/vector_store/__init__.py +1 -0
  76. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  77. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  78. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  79. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  80. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  81. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  82. kiln_ai/datamodel/__init__.py +39 -34
  83. kiln_ai/datamodel/basemodel.py +170 -1
  84. kiln_ai/datamodel/chunk.py +158 -0
  85. kiln_ai/datamodel/datamodel_enums.py +28 -0
  86. kiln_ai/datamodel/embedding.py +64 -0
  87. kiln_ai/datamodel/eval.py +1 -1
  88. kiln_ai/datamodel/external_tool_server.py +298 -0
  89. kiln_ai/datamodel/extraction.py +303 -0
  90. kiln_ai/datamodel/json_schema.py +25 -10
  91. kiln_ai/datamodel/project.py +40 -1
  92. kiln_ai/datamodel/rag.py +79 -0
  93. kiln_ai/datamodel/registry.py +0 -15
  94. kiln_ai/datamodel/run_config.py +62 -0
  95. kiln_ai/datamodel/task.py +2 -77
  96. kiln_ai/datamodel/task_output.py +6 -1
  97. kiln_ai/datamodel/task_run.py +41 -0
  98. kiln_ai/datamodel/test_attachment.py +649 -0
  99. kiln_ai/datamodel/test_basemodel.py +4 -4
  100. kiln_ai/datamodel/test_chunk_models.py +317 -0
  101. kiln_ai/datamodel/test_dataset_split.py +1 -1
  102. kiln_ai/datamodel/test_embedding_models.py +448 -0
  103. kiln_ai/datamodel/test_eval_model.py +6 -6
  104. kiln_ai/datamodel/test_example_models.py +175 -0
  105. kiln_ai/datamodel/test_external_tool_server.py +691 -0
  106. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  107. kiln_ai/datamodel/test_extraction_model.py +470 -0
  108. kiln_ai/datamodel/test_rag.py +641 -0
  109. kiln_ai/datamodel/test_registry.py +8 -3
  110. kiln_ai/datamodel/test_task.py +15 -47
  111. kiln_ai/datamodel/test_tool_id.py +320 -0
  112. kiln_ai/datamodel/test_vector_store.py +320 -0
  113. kiln_ai/datamodel/tool_id.py +105 -0
  114. kiln_ai/datamodel/vector_store.py +141 -0
  115. kiln_ai/tools/__init__.py +8 -0
  116. kiln_ai/tools/base_tool.py +82 -0
  117. kiln_ai/tools/built_in_tools/__init__.py +13 -0
  118. kiln_ai/tools/built_in_tools/math_tools.py +124 -0
  119. kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
  120. kiln_ai/tools/mcp_server_tool.py +95 -0
  121. kiln_ai/tools/mcp_session_manager.py +246 -0
  122. kiln_ai/tools/rag_tools.py +157 -0
  123. kiln_ai/tools/test_base_tools.py +199 -0
  124. kiln_ai/tools/test_mcp_server_tool.py +457 -0
  125. kiln_ai/tools/test_mcp_session_manager.py +1585 -0
  126. kiln_ai/tools/test_rag_tools.py +848 -0
  127. kiln_ai/tools/test_tool_registry.py +562 -0
  128. kiln_ai/tools/tool_registry.py +85 -0
  129. kiln_ai/utils/__init__.py +3 -0
  130. kiln_ai/utils/async_job_runner.py +62 -17
  131. kiln_ai/utils/config.py +24 -2
  132. kiln_ai/utils/env.py +15 -0
  133. kiln_ai/utils/filesystem.py +14 -0
  134. kiln_ai/utils/filesystem_cache.py +60 -0
  135. kiln_ai/utils/litellm.py +94 -0
  136. kiln_ai/utils/lock.py +100 -0
  137. kiln_ai/utils/mime_type.py +38 -0
  138. kiln_ai/utils/open_ai_types.py +94 -0
  139. kiln_ai/utils/pdf_utils.py +38 -0
  140. kiln_ai/utils/project_utils.py +17 -0
  141. kiln_ai/utils/test_async_job_runner.py +151 -35
  142. kiln_ai/utils/test_config.py +138 -1
  143. kiln_ai/utils/test_env.py +142 -0
  144. kiln_ai/utils/test_filesystem_cache.py +316 -0
  145. kiln_ai/utils/test_litellm.py +206 -0
  146. kiln_ai/utils/test_lock.py +185 -0
  147. kiln_ai/utils/test_mime_type.py +66 -0
  148. kiln_ai/utils/test_open_ai_types.py +131 -0
  149. kiln_ai/utils/test_pdf_utils.py +73 -0
  150. kiln_ai/utils/test_uuid.py +111 -0
  151. kiln_ai/utils/test_validation.py +524 -0
  152. kiln_ai/utils/uuid.py +9 -0
  153. kiln_ai/utils/validation.py +90 -0
  154. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
  155. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  156. kiln_ai-0.19.0.dist-info/RECORD +0 -115
  157. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  158. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
--- kiln_ai/utils/test_async_job_runner.py
+++ kiln_ai/utils/test_async_job_runner.py
@@ -3,30 +3,51 @@ from unittest.mock import AsyncMock, patch

 import pytest

-from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress
+from kiln_ai.utils.async_job_runner import (
+    AsyncJobRunner,
+    AsyncJobRunnerObserver,
+    Progress,
+)
+
+
+@pytest.fixture
+def mock_async_run_job_fn_success():
+    return AsyncMock(return_value=True)
+
+
+@pytest.fixture
+def mock_async_run_job_fn_failure():
+    return AsyncMock(return_value=False)


 @pytest.mark.parametrize("concurrency", [0, -1, -25])
-def test_invalid_concurrency_raises(concurrency):
+def test_invalid_concurrency_raises(concurrency, mock_async_run_job_fn_success):
     with pytest.raises(ValueError):
-        AsyncJobRunner(concurrency=concurrency)
+        AsyncJobRunner(
+            concurrency=concurrency,
+            jobs=[],
+            run_job_fn=mock_async_run_job_fn_success,
+        )


 # Test with and without concurrency
 @pytest.mark.parametrize("concurrency", [1, 25])
 @pytest.mark.asyncio
-async def test_async_job_runner_status_updates(concurrency):
+async def test_async_job_runner_status_updates(
+    concurrency, mock_async_run_job_fn_success
+):
     job_count = 50
     jobs = [{"id": i} for i in range(job_count)]

-    runner = AsyncJobRunner(concurrency=concurrency)
-
-    # fake run_job that succeeds
-    mock_run_job_success = AsyncMock(return_value=True)
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_async_run_job_fn_success,
+    )

     # Expect the status updates in order, and 1 for each job
     expected_completed_count = 0
-    async for progress in runner.run(jobs, mock_run_job_success):
+    async for progress in runner.run():
         assert progress.complete == expected_completed_count
         expected_completed_count += 1
         assert progress.errors == 0
@@ -36,26 +57,29 @@ async def test_async_job_runner_status_updates(concurrency):
     assert expected_completed_count == job_count + 1

     # Verify run_job was called for each job
-    assert mock_run_job_success.call_count == job_count
+    assert mock_async_run_job_fn_success.call_count == job_count

     # Verify run_job was called with the correct arguments
     for i in range(job_count):
-        mock_run_job_success.assert_any_await(jobs[i])
+        mock_async_run_job_fn_success.assert_any_await(jobs[i])


 # Test with and without concurrency
 @pytest.mark.parametrize("concurrency", [1, 25])
 @pytest.mark.asyncio
-async def test_async_job_runner_status_updates_empty_job_list(concurrency):
+async def test_async_job_runner_status_updates_empty_job_list(
+    concurrency, mock_async_run_job_fn_success
+):
     empty_job_list = []

-    runner = AsyncJobRunner(concurrency=concurrency)
-
-    # fake run_job that succeeds
-    mock_run_job_success = AsyncMock(return_value=True)
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=empty_job_list,
+        run_job_fn=mock_async_run_job_fn_success,
+    )

     updates: List[Progress] = []
-    async for progress in runner.run(empty_job_list, mock_run_job_success):
+    async for progress in runner.run():
         updates.append(progress)

     # Verify last status update was complete
@@ -66,23 +90,26 @@ async def test_async_job_runner_status_updates_empty_job_list(concurrency):
     assert updates[0].total == 0

     # Verify run_job was called for each job
-    assert mock_run_job_success.call_count == 0
+    assert mock_async_run_job_fn_success.call_count == 0


 @pytest.mark.parametrize("concurrency", [1, 25])
 @pytest.mark.asyncio
-async def test_async_job_runner_all_failures(concurrency):
+async def test_async_job_runner_all_failures(
+    concurrency, mock_async_run_job_fn_failure
+):
     job_count = 50
     jobs = [{"id": i} for i in range(job_count)]

-    runner = AsyncJobRunner(concurrency=concurrency)
-
-    # fake run_job that fails
-    mock_run_job_failure = AsyncMock(return_value=False)
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_async_run_job_fn_failure,
+    )

     # Expect the status updates in order, and 1 for each job
     expected_error_count = 0
-    async for progress in runner.run(jobs, mock_run_job_failure):
+    async for progress in runner.run():
         assert progress.complete == 0
         assert progress.errors == expected_error_count
         expected_error_count += 1
@@ -92,11 +119,11 @@ async def test_async_job_runner_all_failures(concurrency):
     assert expected_error_count == job_count + 1

     # Verify run_job was called for each job
-    assert mock_run_job_failure.call_count == job_count
+    assert mock_async_run_job_fn_failure.call_count == job_count

     # Verify run_job was called with the correct arguments
     for i in range(job_count):
-        mock_run_job_failure.assert_any_await(jobs[i])
+        mock_async_run_job_fn_failure.assert_any_await(jobs[i])


 @pytest.mark.parametrize("concurrency", [1, 25])
@@ -108,16 +135,20 @@ async def test_async_job_runner_partial_failures(concurrency):
     # we want to fail on some jobs and succeed on others
     jobs_to_fail = set([0, 2, 4, 6, 8, 20, 25])

-    runner = AsyncJobRunner(concurrency=concurrency)
-
     # fake run_job that fails
     mock_run_job_partial_success = AsyncMock(
         # return True for jobs that should succeed
         side_effect=lambda job: job["id"] not in jobs_to_fail
     )

+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_run_job_partial_success,
+    )
+
     # Expect the status updates in order, and 1 for each job
-    async for progress in runner.run(jobs, mock_run_job_partial_success):
+    async for progress in runner.run():
         assert progress.total == job_count

     # Verify last status update was complete
@@ -140,8 +171,6 @@ async def test_async_job_runner_partial_raises(concurrency):
     job_count = 50
     jobs = [{"id": i} for i in range(job_count)]

-    runner = AsyncJobRunner(concurrency=concurrency)
-
     ids_to_fail = set([10, 25])

     def failure_fn(job):
@@ -152,6 +181,12 @@ async def test_async_job_runner_partial_raises(concurrency):
     # fake run_job that fails
     mock_run_job_partial_success = AsyncMock(side_effect=failure_fn)

+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_run_job_partial_success,
+    )
+
     # generate all the values we expect to see in progress updates
     complete_values_expected = set([i for i in range(job_count - len(ids_to_fail) + 1)])
     errors_values_expected = set([i for i in range(len(ids_to_fail) + 1)])
@@ -164,7 +199,7 @@ async def test_async_job_runner_partial_raises(concurrency):
     errors_values_actual = set()

     # Expect the status updates in order, and 1 for each job
-    async for progress in runner.run(jobs, mock_run_job_partial_success):
+    async for progress in runner.run():
         updates.append(progress)
         complete_values_actual.add(progress.complete)
         errors_values_actual.add(progress.errors)
@@ -184,9 +219,13 @@ async def test_async_job_runner_partial_raises(concurrency):

 @pytest.mark.parametrize("concurrency", [1, 25])
 @pytest.mark.asyncio
-async def test_async_job_runner_cancelled(concurrency):
-    runner = AsyncJobRunner(concurrency=concurrency)
+async def test_async_job_runner_cancelled(concurrency, mock_async_run_job_fn_success):
     jobs = [{"id": i} for i in range(10)]
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_async_run_job_fn_success,
+    )

     with patch.object(
         runner,
@@ -195,5 +234,82 @@ async def test_async_job_runner_cancelled(concurrency):
     ):
         # if an exception is raised in the task, we should see it bubble up
         with pytest.raises(Exception, match="run_worker raised an exception"):
-            async for _ in runner.run(jobs, AsyncMock(return_value=True)):
+            async for _ in runner.run():
                 pass
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_observers(concurrency):
+    class MockAsyncJobRunnerObserver(AsyncJobRunnerObserver[dict[str, int]]):
+        def __init__(self):
+            self.on_error_calls = []
+            self.on_success_calls = []
+
+        async def on_error(self, job: dict[str, int], error: Exception):
+            self.on_error_calls.append((job, error))
+
+        async def on_success(self, job: dict[str, int]):
+            self.on_success_calls.append(job)
+
+    mock_observer_a = MockAsyncJobRunnerObserver()
+    mock_observer_b = MockAsyncJobRunnerObserver()
+
+    jobs = [{"id": i} for i in range(10)]
+
+    async def run_job_fn(job: dict[str, int]) -> bool:
+        # we simulate the job 5 and 6 crashing, which should trigger the observers on_error handlers
+        if job["id"] == 5 or job["id"] == 6:
+            raise ValueError(f"job failed unexpectedly {job['id']}")
+        return True
+
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=run_job_fn,
+        observers=[mock_observer_a, mock_observer_b],
+    )
+
+    async for _ in runner.run():
+        pass
+
+    assert len(mock_observer_a.on_error_calls) == 2
+    assert len(mock_observer_b.on_error_calls) == 2
+
+    # not necessarily in order, but we should have seen both 5 and 6
+    assert len(mock_observer_a.on_success_calls) == 8
+    assert len(mock_observer_b.on_success_calls) == 8
+
+    # check that 5 and 6 are in the error calls
+    for job_idx in [5, 6]:
+        # check that 5 and 6 are in the error calls for both observers
+        assert any(call[0] == jobs[job_idx] for call in mock_observer_a.on_error_calls)
+        assert any(call[0] == jobs[job_idx] for call in mock_observer_b.on_error_calls)
+
+        # check that the error is the correct exception
+        assert (
+            str(mock_observer_a.on_error_calls[0][1]) == "job failed unexpectedly 5"
+            or str(mock_observer_a.on_error_calls[1][1]) == "job failed unexpectedly 6"
+        )
+        assert (
+            str(mock_observer_b.on_error_calls[0][1]) == "job failed unexpectedly 5"
+            or str(mock_observer_b.on_error_calls[1][1]) == "job failed unexpectedly 6"
+        )
+
+        # check that 5 and 6 are not in the success calls for both observers
+        assert not any(
+            call == jobs[job_idx] for call in mock_observer_a.on_success_calls
+        )
+        assert not any(
+            call == jobs[job_idx] for call in mock_observer_b.on_success_calls
+        )
+
+    # check that the other jobs are in the success calls for both observers
+    for job_idx in range(10):
+        if job_idx not in [5, 6]:
+            assert any(
+                call == jobs[job_idx] for call in mock_observer_a.on_success_calls
+            )
+            assert any(
+                call == jobs[job_idx] for call in mock_observer_b.on_success_calls
+            )
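
For readers tracking the API change above: AsyncJobRunner now receives its job list and worker function at construction time, run() takes no arguments, and optional observers get async on_success/on_error callbacks. A minimal usage sketch inferred only from these test changes (the LoggingObserver and run_job names are illustrative, not from the package):

    import asyncio

    from kiln_ai.utils.async_job_runner import AsyncJobRunner, AsyncJobRunnerObserver


    class LoggingObserver(AsyncJobRunnerObserver[dict[str, int]]):
        # Hypothetical observer; the hook signatures match the mock observer in the test.
        async def on_success(self, job: dict[str, int]):
            print(f"job {job['id']} succeeded")

        async def on_error(self, job: dict[str, int], error: Exception):
            print(f"job {job['id']} failed: {error}")


    async def run_job(job: dict[str, int]) -> bool:
        # Per the tests: returning False counts toward progress.errors,
        # and raising routes the job to the observers' on_error hooks.
        return job["id"] % 2 == 0


    async def main():
        runner = AsyncJobRunner(
            concurrency=25,
            jobs=[{"id": i} for i in range(100)],
            run_job_fn=run_job,
            observers=[LoggingObserver()],
        )
        # run() now takes no arguments and yields Progress updates.
        async for progress in runner.run():
            print(f"{progress.complete}/{progress.total} done, {progress.errors} errors")


    asyncio.run(main())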
--- kiln_ai/utils/test_config.py
+++ kiln_ai/utils/test_config.py
@@ -6,7 +6,7 @@ from unittest.mock import patch
 import pytest
 import yaml

-from kiln_ai.utils.config import Config, ConfigProperty, _get_user_id
+from kiln_ai.utils.config import MCP_SECRETS_KEY, Config, ConfigProperty, _get_user_id


 @pytest.fixture
@@ -322,3 +322,140 @@ def test_update_settings_thread_safety(config_with_yaml):

     assert not exceptions
     assert config.int_property in range(5)
+
+
+def test_mcp_secrets_property():
+    """Test mcp_secrets configuration property"""
+    config = Config.shared()
+
+    # Initially should be None/empty
+    assert config.mcp_secrets is None
+
+    # Set some secrets
+    secrets = {
+        "server1::Authorization": "Bearer token123",
+        "server1::X-API-Key": "api-key-456",
+        "server2::Token": "secret-token",
+    }
+    config.mcp_secrets = secrets
+
+    # Verify they are stored correctly
+    assert config.mcp_secrets == secrets
+    assert config.mcp_secrets["server1::Authorization"] == "Bearer token123"
+    assert config.mcp_secrets["server1::X-API-Key"] == "api-key-456"
+    assert config.mcp_secrets["server2::Token"] == "secret-token"
+
+
+def test_mcp_secrets_sensitive_hiding():
+    """Test that mcp_secrets are hidden when hide_sensitive=True"""
+    config = Config.shared()
+
+    # Set some secrets
+    secrets = {
+        "server1::Authorization": "Bearer secret123",
+        "server2::X-API-Key": "secret-key",
+    }
+    config.mcp_secrets = secrets
+
+    # Test without hiding sensitive data
+    visible_settings = config.settings(hide_sensitive=False)
+    assert MCP_SECRETS_KEY in visible_settings
+    assert visible_settings[MCP_SECRETS_KEY] == secrets
+
+    # Test with hiding sensitive data
+    hidden_settings = config.settings(hide_sensitive=True)
+    assert MCP_SECRETS_KEY in hidden_settings
+    assert hidden_settings[MCP_SECRETS_KEY] == "[hidden]"
+
+
+def test_mcp_secrets_persistence(mock_yaml_file):
+    """Test that mcp_secrets are persisted to YAML correctly"""
+    with patch(
+        "kiln_ai.utils.config.Config.settings_path",
+        return_value=mock_yaml_file,
+    ):
+        config = Config()
+
+        # Set some secrets
+        secrets = {
+            "server1::Authorization": "Bearer persist123",
+            "server2::Token": "persist-token",
+        }
+        config.mcp_secrets = secrets
+
+        # Check that the value was saved to the YAML file
+        with open(mock_yaml_file, "r") as f:
+            saved_settings = yaml.safe_load(f)
+            assert saved_settings[MCP_SECRETS_KEY] == secrets
+
+        # Create a new config instance to test loading from YAML
+        new_config = Config()
+
+        # Check that the value is loaded from YAML
+        assert new_config.mcp_secrets == secrets
+
+
+def test_mcp_secrets_get_value():
+    """Test that mcp_secrets can be retrieved using get_value method"""
+    config = Config.shared()
+
+    # Initially should be None
+    assert config.get_value(MCP_SECRETS_KEY) is None
+
+    # Set some secrets
+    secrets = {"server::key": "value"}
+    config.mcp_secrets = secrets
+
+    # Should be retrievable via get_value
+    assert config.get_value(MCP_SECRETS_KEY) == secrets
+
+
+def test_mcp_secrets_update_settings():
+    """Test updating mcp_secrets using update_settings method"""
+    config = Config.shared()
+
+    # Set initial secrets
+    initial_secrets = {"server1::key1": "value1"}
+    config.update_settings({MCP_SECRETS_KEY: initial_secrets})
+    assert config.mcp_secrets == initial_secrets
+
+    # Update with new secrets (should replace, not merge)
+    new_secrets = {
+        "server1::key1": "updated_value1",
+        "server2::key2": "value2",
+    }
+    config.update_settings({MCP_SECRETS_KEY: new_secrets})
+    assert config.mcp_secrets == new_secrets
+    assert config.mcp_secrets["server1::key1"] == "updated_value1"
+    assert config.mcp_secrets["server2::key2"] == "value2"
+
+
+def test_mcp_secrets_empty_dict():
+    """Test mcp_secrets with empty dict"""
+    config = Config.shared()
+
+    # Set empty dict
+    config.mcp_secrets = {}
+    assert config.mcp_secrets == {}
+
+    # Should still be dict type, not None
+    assert isinstance(config.mcp_secrets, dict)
+
+
+def test_mcp_secrets_type_validation():
+    """Test that mcp_secrets enforces dict[str, str] type"""
+    config = Config.shared()
+
+    # Valid dict[str, str]
+    valid_secrets = {"server::key": "value"}
+    config.mcp_secrets = valid_secrets
+    assert config.mcp_secrets == valid_secrets
+
+    # The config system applies type conversion when retrieving values
+    mixed_types = {"server::key": 123}  # int value
+    config.mcp_secrets = mixed_types
+    # The type conversion happens when the value is retrieved, not when set
+    # So the underlying storage may preserve the original type
+    assert config.mcp_secrets == mixed_types or config.mcp_secrets == {
+        "server::key": "123"
+    }
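
The new mcp_secrets surface stores per-server header secrets under namespaced "server::header" keys. A short sketch of the configuration surface these tests exercise (the key and token values are illustrative):

    from kiln_ai.utils.config import MCP_SECRETS_KEY, Config

    config = Config.shared()

    # Keys are namespaced "<server>::<header>"; values are the secret strings.
    config.mcp_secrets = {"server1::Authorization": "Bearer token123"}

    # The same mapping is reachable through the generic accessor...
    assert config.get_value(MCP_SECRETS_KEY) == config.mcp_secrets

    # ...and is redacted when settings are read with hide_sensitive=True.
    assert config.settings(hide_sensitive=True)[MCP_SECRETS_KEY] == "[hidden]"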
--- /dev/null
+++ kiln_ai/utils/test_env.py
@@ -0,0 +1,142 @@
+import os
+
+import pytest
+
+from kiln_ai.utils.env import temporary_env
+
+
+class TestTemporaryEnv:
+    def test_set_new_env_var(self):
+        """Test setting a new environment variable that doesn't exist."""
+        var_name = "TEST_NEW_VAR"
+        test_value = "test_value"
+
+        # Ensure the variable doesn't exist initially
+        assert var_name not in os.environ
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        # Verify it's removed after context
+        assert var_name not in os.environ
+
+    def test_modify_existing_env_var(self):
+        """Test modifying an existing environment variable."""
+        var_name = "TEST_EXISTING_VAR"
+        original_value = "original_value"
+        new_value = "new_value"
+
+        # Set up initial state
+        os.environ[var_name] = original_value
+
+        with temporary_env(var_name, new_value):
+            assert os.environ[var_name] == new_value
+
+        # Verify original value is restored
+        assert os.environ[var_name] == original_value
+
+    def test_restore_nonexistent_var(self):
+        """Test that a variable that didn't exist is properly removed."""
+        var_name = "TEST_NONEXISTENT_VAR"
+        test_value = "test_value"
+
+        # Ensure the variable doesn't exist initially
+        if var_name in os.environ:
+            del os.environ[var_name]
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        # Verify it's removed after context
+        assert var_name not in os.environ
+
+    def test_exception_handling(self):
+        """Test that environment is restored even when an exception occurs."""
+        var_name = "TEST_EXCEPTION_VAR"
+        original_value = "original_value"
+        new_value = "new_value"
+
+        # Set up initial state
+        os.environ[var_name] = original_value
+
+        with pytest.raises(ValueError):
+            with temporary_env(var_name, new_value):
+                assert os.environ[var_name] == new_value
+                raise ValueError("Test exception")
+
+        # Verify original value is restored even after exception
+        assert os.environ[var_name] == original_value
+
+    def test_exception_handling_new_var(self):
+        """Test that new variable is removed even when an exception occurs."""
+        var_name = "TEST_EXCEPTION_NEW_VAR"
+        test_value = "test_value"
+
+        # Ensure the variable doesn't exist initially
+        if var_name in os.environ:
+            del os.environ[var_name]
+
+        with pytest.raises(RuntimeError):
+            with temporary_env(var_name, test_value):
+                assert os.environ[var_name] == test_value
+                raise RuntimeError("Test exception")
+
+        # Verify variable is removed even after exception
+        assert var_name not in os.environ
+
+    def test_nested_context_managers(self):
+        """Test using multiple temporary_env context managers."""
+        var1 = "TEST_NESTED_VAR1"
+        var2 = "TEST_NESTED_VAR2"
+        value1 = "value1"
+        value2 = "value2"
+
+        # Set up initial state
+        os.environ[var1] = "original1"
+        if var2 in os.environ:
+            del os.environ[var2]
+
+        with temporary_env(var1, value1):
+            assert os.environ[var1] == value1
+
+            with temporary_env(var2, value2):
+                assert os.environ[var1] == value1
+                assert os.environ[var2] == value2
+
+            # Inner context should be cleaned up
+            assert var2 not in os.environ
+            assert os.environ[var1] == value1
+
+        # Both contexts should be cleaned up
+        assert os.environ[var1] == "original1"
+        assert var2 not in os.environ
+
+    def test_empty_string_value(self):
+        """Test setting an empty string value."""
+        var_name = "TEST_EMPTY_VAR"
+        test_value = ""
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        assert var_name not in os.environ
+
+    def test_none_value_handling(self):
+        """Test that None values are handled properly."""
+        var_name = "TEST_NONE_VAR"
+        test_value = "test_value"
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        assert var_name not in os.environ
+
+    def test_unicode_value(self):
+        """Test setting unicode values."""
+        var_name = "TEST_UNICODE_VAR"
+        test_value = "测试值 🚀"
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        assert var_name not in os.environ
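
kiln_ai/utils/env.py itself is only 15 lines (file 132 in the list above). A plausible implementation consistent with these tests, not the package's actual source:

    import os
    from contextlib import contextmanager
    from typing import Iterator


    @contextmanager
    def temporary_env(name: str, value: str) -> Iterator[None]:
        previous = os.environ.get(name)  # None means the variable was unset
        os.environ[name] = value
        try:
            yield
        finally:
            # Runs on normal exit and when the body raises, restoring or
            # removing the variable exactly as the tests above expect.
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous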