kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (133)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +494 -0
  38. kiln_ai/adapters/ml_model_list.py +876 -18
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
  41. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  42. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  43. kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
  44. kiln_ai/adapters/ollama_tools.py +69 -12
  45. kiln_ai/adapters/provider_tools.py +190 -46
  46. kiln_ai/adapters/rag/deduplication.py +49 -0
  47. kiln_ai/adapters/rag/progress.py +252 -0
  48. kiln_ai/adapters/rag/rag_runners.py +844 -0
  49. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  50. kiln_ai/adapters/rag/test_progress.py +785 -0
  51. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  52. kiln_ai/adapters/remote_config.py +80 -8
  53. kiln_ai/adapters/test_adapter_registry.py +579 -86
  54. kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
  55. kiln_ai/adapters/test_ml_model_list.py +202 -0
  56. kiln_ai/adapters/test_ollama_tools.py +340 -1
  57. kiln_ai/adapters/test_prompt_builders.py +1 -1
  58. kiln_ai/adapters/test_provider_tools.py +199 -8
  59. kiln_ai/adapters/test_remote_config.py +551 -56
  60. kiln_ai/adapters/vector_store/__init__.py +1 -0
  61. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  62. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  63. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  64. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  65. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  66. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  67. kiln_ai/datamodel/__init__.py +16 -13
  68. kiln_ai/datamodel/basemodel.py +201 -4
  69. kiln_ai/datamodel/chunk.py +158 -0
  70. kiln_ai/datamodel/datamodel_enums.py +27 -0
  71. kiln_ai/datamodel/embedding.py +64 -0
  72. kiln_ai/datamodel/external_tool_server.py +206 -54
  73. kiln_ai/datamodel/extraction.py +317 -0
  74. kiln_ai/datamodel/project.py +33 -1
  75. kiln_ai/datamodel/rag.py +79 -0
  76. kiln_ai/datamodel/task.py +5 -0
  77. kiln_ai/datamodel/task_output.py +41 -11
  78. kiln_ai/datamodel/test_attachment.py +649 -0
  79. kiln_ai/datamodel/test_basemodel.py +270 -14
  80. kiln_ai/datamodel/test_chunk_models.py +317 -0
  81. kiln_ai/datamodel/test_dataset_split.py +1 -1
  82. kiln_ai/datamodel/test_datasource.py +50 -0
  83. kiln_ai/datamodel/test_embedding_models.py +448 -0
  84. kiln_ai/datamodel/test_eval_model.py +6 -6
  85. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  86. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  87. kiln_ai/datamodel/test_extraction_model.py +501 -0
  88. kiln_ai/datamodel/test_rag.py +641 -0
  89. kiln_ai/datamodel/test_task.py +35 -1
  90. kiln_ai/datamodel/test_tool_id.py +187 -1
  91. kiln_ai/datamodel/test_vector_store.py +320 -0
  92. kiln_ai/datamodel/tool_id.py +58 -0
  93. kiln_ai/datamodel/vector_store.py +141 -0
  94. kiln_ai/tools/base_tool.py +12 -3
  95. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  96. kiln_ai/tools/kiln_task_tool.py +158 -0
  97. kiln_ai/tools/mcp_server_tool.py +2 -2
  98. kiln_ai/tools/mcp_session_manager.py +51 -22
  99. kiln_ai/tools/rag_tools.py +164 -0
  100. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  101. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  102. kiln_ai/tools/test_mcp_session_manager.py +187 -227
  103. kiln_ai/tools/test_rag_tools.py +929 -0
  104. kiln_ai/tools/test_tool_registry.py +290 -7
  105. kiln_ai/tools/tool_registry.py +69 -16
  106. kiln_ai/utils/__init__.py +3 -0
  107. kiln_ai/utils/async_job_runner.py +62 -17
  108. kiln_ai/utils/config.py +2 -2
  109. kiln_ai/utils/env.py +15 -0
  110. kiln_ai/utils/filesystem.py +14 -0
  111. kiln_ai/utils/filesystem_cache.py +60 -0
  112. kiln_ai/utils/litellm.py +94 -0
  113. kiln_ai/utils/lock.py +100 -0
  114. kiln_ai/utils/mime_type.py +38 -0
  115. kiln_ai/utils/open_ai_types.py +19 -2
  116. kiln_ai/utils/pdf_utils.py +59 -0
  117. kiln_ai/utils/test_async_job_runner.py +151 -35
  118. kiln_ai/utils/test_env.py +142 -0
  119. kiln_ai/utils/test_filesystem_cache.py +316 -0
  120. kiln_ai/utils/test_litellm.py +206 -0
  121. kiln_ai/utils/test_lock.py +185 -0
  122. kiln_ai/utils/test_mime_type.py +66 -0
  123. kiln_ai/utils/test_open_ai_types.py +88 -12
  124. kiln_ai/utils/test_pdf_utils.py +86 -0
  125. kiln_ai/utils/test_uuid.py +111 -0
  126. kiln_ai/utils/test_validation.py +524 -0
  127. kiln_ai/utils/uuid.py +9 -0
  128. kiln_ai/utils/validation.py +90 -0
  129. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
  130. kiln_ai-0.22.0.dist-info/RECORD +213 -0
  131. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  132. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  133. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/utils/test_async_job_runner.py

@@ -3,30 +3,51 @@ from unittest.mock import AsyncMock, patch
 
 import pytest
 
-from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress
+from kiln_ai.utils.async_job_runner import (
+    AsyncJobRunner,
+    AsyncJobRunnerObserver,
+    Progress,
+)
+
+
+@pytest.fixture
+def mock_async_run_job_fn_success():
+    return AsyncMock(return_value=True)
+
+
+@pytest.fixture
+def mock_async_run_job_fn_failure():
+    return AsyncMock(return_value=False)
 
 
 @pytest.mark.parametrize("concurrency", [0, -1, -25])
-def test_invalid_concurrency_raises(concurrency):
+def test_invalid_concurrency_raises(concurrency, mock_async_run_job_fn_success):
     with pytest.raises(ValueError):
-        AsyncJobRunner(concurrency=concurrency)
+        AsyncJobRunner(
+            concurrency=concurrency,
+            jobs=[],
+            run_job_fn=mock_async_run_job_fn_success,
+        )
 
 
 # Test with and without concurrency
 @pytest.mark.parametrize("concurrency", [1, 25])
 @pytest.mark.asyncio
-async def test_async_job_runner_status_updates(concurrency):
+async def test_async_job_runner_status_updates(
+    concurrency, mock_async_run_job_fn_success
+):
     job_count = 50
     jobs = [{"id": i} for i in range(job_count)]
 
-    runner = AsyncJobRunner(concurrency=concurrency)
-
-    # fake run_job that succeeds
-    mock_run_job_success = AsyncMock(return_value=True)
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_async_run_job_fn_success,
+    )
 
     # Expect the status updates in order, and 1 for each job
     expected_completed_count = 0
-    async for progress in runner.run(jobs, mock_run_job_success):
+    async for progress in runner.run():
         assert progress.complete == expected_completed_count
         expected_completed_count += 1
         assert progress.errors == 0
@@ -36,26 +57,29 @@ async def test_async_job_runner_status_updates(concurrency):
     assert expected_completed_count == job_count + 1
 
     # Verify run_job was called for each job
-    assert mock_run_job_success.call_count == job_count
+    assert mock_async_run_job_fn_success.call_count == job_count
 
     # Verify run_job was called with the correct arguments
     for i in range(job_count):
-        mock_run_job_success.assert_any_await(jobs[i])
+        mock_async_run_job_fn_success.assert_any_await(jobs[i])
 
 
 # Test with and without concurrency
 @pytest.mark.parametrize("concurrency", [1, 25])
 @pytest.mark.asyncio
-async def test_async_job_runner_status_updates_empty_job_list(concurrency):
+async def test_async_job_runner_status_updates_empty_job_list(
+    concurrency, mock_async_run_job_fn_success
+):
     empty_job_list = []
 
-    runner = AsyncJobRunner(concurrency=concurrency)
-
-    # fake run_job that succeeds
-    mock_run_job_success = AsyncMock(return_value=True)
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=empty_job_list,
+        run_job_fn=mock_async_run_job_fn_success,
+    )
 
     updates: List[Progress] = []
-    async for progress in runner.run(empty_job_list, mock_run_job_success):
+    async for progress in runner.run():
         updates.append(progress)
 
     # Verify last status update was complete
@@ -66,23 +90,26 @@ async def test_async_job_runner_status_updates_empty_job_list(concurrency):
     assert updates[0].total == 0
 
     # Verify run_job was called for each job
-    assert mock_run_job_success.call_count == 0
+    assert mock_async_run_job_fn_success.call_count == 0
 
 
 @pytest.mark.parametrize("concurrency", [1, 25])
 @pytest.mark.asyncio
-async def test_async_job_runner_all_failures(concurrency):
+async def test_async_job_runner_all_failures(
+    concurrency, mock_async_run_job_fn_failure
+):
     job_count = 50
     jobs = [{"id": i} for i in range(job_count)]
 
-    runner = AsyncJobRunner(concurrency=concurrency)
-
-    # fake run_job that fails
-    mock_run_job_failure = AsyncMock(return_value=False)
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_async_run_job_fn_failure,
+    )
 
     # Expect the status updates in order, and 1 for each job
     expected_error_count = 0
-    async for progress in runner.run(jobs, mock_run_job_failure):
+    async for progress in runner.run():
         assert progress.complete == 0
         assert progress.errors == expected_error_count
         expected_error_count += 1
@@ -92,11 +119,11 @@ async def test_async_job_runner_all_failures(concurrency):
     assert expected_error_count == job_count + 1
 
     # Verify run_job was called for each job
-    assert mock_run_job_failure.call_count == job_count
+    assert mock_async_run_job_fn_failure.call_count == job_count
 
     # Verify run_job was called with the correct arguments
     for i in range(job_count):
-        mock_run_job_failure.assert_any_await(jobs[i])
+        mock_async_run_job_fn_failure.assert_any_await(jobs[i])
 
 
 @pytest.mark.parametrize("concurrency", [1, 25])
@@ -108,16 +135,20 @@ async def test_async_job_runner_partial_failures(concurrency):
     # we want to fail on some jobs and succeed on others
     jobs_to_fail = set([0, 2, 4, 6, 8, 20, 25])
 
-    runner = AsyncJobRunner(concurrency=concurrency)
-
     # fake run_job that fails
     mock_run_job_partial_success = AsyncMock(
         # return True for jobs that should succeed
        side_effect=lambda job: job["id"] not in jobs_to_fail
     )
 
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_run_job_partial_success,
+    )
+
     # Expect the status updates in order, and 1 for each job
-    async for progress in runner.run(jobs, mock_run_job_partial_success):
+    async for progress in runner.run():
         assert progress.total == job_count
 
         # Verify last status update was complete
@@ -140,8 +171,6 @@ async def test_async_job_runner_partial_raises(concurrency):
     job_count = 50
     jobs = [{"id": i} for i in range(job_count)]
 
-    runner = AsyncJobRunner(concurrency=concurrency)
-
     ids_to_fail = set([10, 25])
 
     def failure_fn(job):
@@ -152,6 +181,12 @@ async def test_async_job_runner_partial_raises(concurrency):
     # fake run_job that fails
     mock_run_job_partial_success = AsyncMock(side_effect=failure_fn)
 
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_run_job_partial_success,
+    )
+
     # generate all the values we expect to see in progress updates
     complete_values_expected = set([i for i in range(job_count - len(ids_to_fail) + 1)])
     errors_values_expected = set([i for i in range(len(ids_to_fail) + 1)])
@@ -164,7 +199,7 @@ async def test_async_job_runner_partial_raises(concurrency):
     errors_values_actual = set()
 
     # Expect the status updates in order, and 1 for each job
-    async for progress in runner.run(jobs, mock_run_job_partial_success):
+    async for progress in runner.run():
         updates.append(progress)
         complete_values_actual.add(progress.complete)
         errors_values_actual.add(progress.errors)
@@ -184,9 +219,13 @@ async def test_async_job_runner_partial_raises(concurrency):
 
 @pytest.mark.parametrize("concurrency", [1, 25])
 @pytest.mark.asyncio
-async def test_async_job_runner_cancelled(concurrency):
-    runner = AsyncJobRunner(concurrency=concurrency)
+async def test_async_job_runner_cancelled(concurrency, mock_async_run_job_fn_success):
     jobs = [{"id": i} for i in range(10)]
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=mock_async_run_job_fn_success,
+    )
 
     with patch.object(
         runner,
@@ -195,5 +234,82 @@ async def test_async_job_runner_cancelled(concurrency):
     ):
         # if an exception is raised in the task, we should see it bubble up
         with pytest.raises(Exception, match="run_worker raised an exception"):
-            async for _ in runner.run(jobs, AsyncMock(return_value=True)):
+            async for _ in runner.run():
                 pass
+
+
+@pytest.mark.parametrize("concurrency", [1, 25])
+@pytest.mark.asyncio
+async def test_async_job_runner_observers(concurrency):
+    class MockAsyncJobRunnerObserver(AsyncJobRunnerObserver[dict[str, int]]):
+        def __init__(self):
+            self.on_error_calls = []
+            self.on_success_calls = []
+
+        async def on_error(self, job: dict[str, int], error: Exception):
+            self.on_error_calls.append((job, error))
+
+        async def on_success(self, job: dict[str, int]):
+            self.on_success_calls.append(job)
+
+    mock_observer_a = MockAsyncJobRunnerObserver()
+    mock_observer_b = MockAsyncJobRunnerObserver()
+
+    jobs = [{"id": i} for i in range(10)]
+
+    async def run_job_fn(job: dict[str, int]) -> bool:
+        # we simulate the job 5 and 6 crashing, which should trigger the observers on_error handlers
+        if job["id"] == 5 or job["id"] == 6:
+            raise ValueError(f"job failed unexpectedly {job['id']}")
+        return True
+
+    runner = AsyncJobRunner(
+        concurrency=concurrency,
+        jobs=jobs,
+        run_job_fn=run_job_fn,
+        observers=[mock_observer_a, mock_observer_b],
+    )
+
+    async for _ in runner.run():
+        pass
+
+    assert len(mock_observer_a.on_error_calls) == 2
+    assert len(mock_observer_b.on_error_calls) == 2
+
+    # not necessarily in order, but we should have seen both 5 and 6
+    assert len(mock_observer_a.on_success_calls) == 8
+    assert len(mock_observer_b.on_success_calls) == 8
+
+    # check that 5 and 6 are in the error calls
+    for job_idx in [5, 6]:
+        # check that 5 and 6 are in the error calls for both observers
+        assert any(call[0] == jobs[job_idx] for call in mock_observer_a.on_error_calls)
+        assert any(call[0] == jobs[job_idx] for call in mock_observer_b.on_error_calls)
+
+        # check that the error is the correct exception
+        assert (
+            str(mock_observer_a.on_error_calls[0][1]) == "job failed unexpectedly 5"
+            or str(mock_observer_a.on_error_calls[1][1]) == "job failed unexpectedly 6"
+        )
+        assert (
+            str(mock_observer_b.on_error_calls[0][1]) == "job failed unexpectedly 5"
+            or str(mock_observer_b.on_error_calls[1][1]) == "job failed unexpectedly 6"
+        )
+
+        # check that 5 and 6 are not in the success calls for both observers
+        assert not any(
+            call == jobs[job_idx] for call in mock_observer_a.on_success_calls
+        )
+        assert not any(
+            call == jobs[job_idx] for call in mock_observer_b.on_success_calls
+        )
+
+    # check that the other jobs are in the success calls for both observers
+    for job_idx in range(10):
+        if job_idx not in [5, 6]:
+            assert any(
+                call == jobs[job_idx] for call in mock_observer_a.on_success_calls
+            )
+            assert any(
+                call == jobs[job_idx] for call in mock_observer_b.on_success_calls
+            )
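
The hunks above capture an API change in AsyncJobRunner: the job list and job handler now go to the constructor (jobs=, run_job_fn=, plus an optional observers= list), and run() takes no arguments, yielding Progress updates. Below is a minimal usage sketch inferred from these tests; LoggingObserver and process_job are illustrative names, not part of the package.

import asyncio

from kiln_ai.utils.async_job_runner import AsyncJobRunner, AsyncJobRunnerObserver


class LoggingObserver(AsyncJobRunnerObserver[dict[str, int]]):
    # Illustrative observer: the tests show observers receive on_success(job)
    # and on_error(job, error) callbacks as each job finishes.
    async def on_success(self, job: dict[str, int]) -> None:
        print(f"job {job['id']} succeeded")

    async def on_error(self, job: dict[str, int], error: Exception) -> None:
        print(f"job {job['id']} failed: {error}")


async def process_job(job: dict[str, int]) -> bool:
    # Per the tests: return True to count as complete, False to count as an
    # error; raising also counts as an error and triggers on_error.
    return True


async def main() -> None:
    jobs = [{"id": i} for i in range(10)]
    runner = AsyncJobRunner(
        concurrency=5,
        jobs=jobs,
        run_job_fn=process_job,
        observers=[LoggingObserver()],
    )
    # run() now takes no arguments and yields Progress(complete, errors, total).
    async for progress in runner.run():
        print(f"{progress.complete}/{progress.total} done, {progress.errors} errors")


asyncio.run(main())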
kiln_ai/utils/test_env.py (new file)

@@ -0,0 +1,142 @@
+import os
+
+import pytest
+
+from kiln_ai.utils.env import temporary_env
+
+
+class TestTemporaryEnv:
+    def test_set_new_env_var(self):
+        """Test setting a new environment variable that doesn't exist."""
+        var_name = "TEST_NEW_VAR"
+        test_value = "test_value"
+
+        # Ensure the variable doesn't exist initially
+        assert var_name not in os.environ
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        # Verify it's removed after context
+        assert var_name not in os.environ
+
+    def test_modify_existing_env_var(self):
+        """Test modifying an existing environment variable."""
+        var_name = "TEST_EXISTING_VAR"
+        original_value = "original_value"
+        new_value = "new_value"
+
+        # Set up initial state
+        os.environ[var_name] = original_value
+
+        with temporary_env(var_name, new_value):
+            assert os.environ[var_name] == new_value
+
+        # Verify original value is restored
+        assert os.environ[var_name] == original_value
+
+    def test_restore_nonexistent_var(self):
+        """Test that a variable that didn't exist is properly removed."""
+        var_name = "TEST_NONEXISTENT_VAR"
+        test_value = "test_value"
+
+        # Ensure the variable doesn't exist initially
+        if var_name in os.environ:
+            del os.environ[var_name]
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        # Verify it's removed after context
+        assert var_name not in os.environ
+
+    def test_exception_handling(self):
+        """Test that environment is restored even when an exception occurs."""
+        var_name = "TEST_EXCEPTION_VAR"
+        original_value = "original_value"
+        new_value = "new_value"
+
+        # Set up initial state
+        os.environ[var_name] = original_value
+
+        with pytest.raises(ValueError):
+            with temporary_env(var_name, new_value):
+                assert os.environ[var_name] == new_value
+                raise ValueError("Test exception")
+
+        # Verify original value is restored even after exception
+        assert os.environ[var_name] == original_value
+
+    def test_exception_handling_new_var(self):
+        """Test that new variable is removed even when an exception occurs."""
+        var_name = "TEST_EXCEPTION_NEW_VAR"
+        test_value = "test_value"
+
+        # Ensure the variable doesn't exist initially
+        if var_name in os.environ:
+            del os.environ[var_name]
+
+        with pytest.raises(RuntimeError):
+            with temporary_env(var_name, test_value):
+                assert os.environ[var_name] == test_value
+                raise RuntimeError("Test exception")
+
+        # Verify variable is removed even after exception
+        assert var_name not in os.environ
+
+    def test_nested_context_managers(self):
+        """Test using multiple temporary_env context managers."""
+        var1 = "TEST_NESTED_VAR1"
+        var2 = "TEST_NESTED_VAR2"
+        value1 = "value1"
+        value2 = "value2"
+
+        # Set up initial state
+        os.environ[var1] = "original1"
+        if var2 in os.environ:
+            del os.environ[var2]
+
+        with temporary_env(var1, value1):
+            assert os.environ[var1] == value1
+
+            with temporary_env(var2, value2):
+                assert os.environ[var1] == value1
+                assert os.environ[var2] == value2
+
+            # Inner context should be cleaned up
+            assert var2 not in os.environ
+            assert os.environ[var1] == value1
+
+        # Both contexts should be cleaned up
+        assert os.environ[var1] == "original1"
+        assert var2 not in os.environ
+
+    def test_empty_string_value(self):
+        """Test setting an empty string value."""
+        var_name = "TEST_EMPTY_VAR"
+        test_value = ""
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        assert var_name not in os.environ
+
+    def test_none_value_handling(self):
+        """Test that None values are handled properly."""
+        var_name = "TEST_NONE_VAR"
+        test_value = "test_value"
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        assert var_name not in os.environ
+
+    def test_unicode_value(self):
+        """Test setting unicode values."""
+        var_name = "TEST_UNICODE_VAR"
+        test_value = "测试值 🚀"
+
+        with temporary_env(var_name, test_value):
+            assert os.environ[var_name] == test_value
+
+        assert var_name not in os.environ
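
The new test_env.py pins down temporary_env's contract: the variable is set for the duration of the with block, any prior value is restored on exit, a previously unset variable is removed entirely, and cleanup happens even when the body raises. A context manager along these lines would satisfy these tests; this is a sketch, not necessarily the actual implementation in kiln_ai/utils/env.py.

import os
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def temporary_env(name: str, value: str) -> Iterator[None]:
    # Sketch: capture the prior value (None means the variable was unset),
    # set the new value, then restore or remove it in a finally block so
    # cleanup runs even if the with-body raises.
    original = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if original is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = original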