batchgenerators 0.25.2__tar.gz → 0.25.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/PKG-INFO +1 -1
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/multi_threaded_augmenter.py +139 -79
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py +131 -56
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators.egg-info/PKG-INFO +1 -1
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators.egg-info/SOURCES.txt +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators.egg-info/dependency_links.txt +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators.egg-info/requires.txt +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators.egg-info/top_level.txt +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/pyproject.toml +1 -1
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_DataLoader.py +1 -1
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_axis_mirroring.py +1 -2
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_resample_augmentations.py +2 -2
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/LICENSE +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/README.md +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/__init__.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/__init__.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/color_augmentations.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/crop_and_pad_augmentations.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/noise_augmentations.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/normalizations.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/resample_augmentations.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/spatial_transformations.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/utils.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/__init__.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/data_loader.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/dataset.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/single_threaded_augmenter.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/datasets/__init__.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/datasets/cifar.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/__init__.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/brats2017/__init__.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/brats2017/brats2017_dataloader_2D.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/brats2017/brats2017_preprocessing.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/brats2017/config.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/cifar10.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/multithreaded_dataloading.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/__init__.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/abstract_transforms.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/channel_selection_transforms.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/color_transforms.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/crop_and_pad_transforms.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/local_transforms.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/noise_transforms.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/resample_transforms.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/sample_normalization_transforms.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/spatial_transforms.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/utility_transforms.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/utilities/__init__.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/utilities/custom_types.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/utilities/data_splitting.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/utilities/file_and_folder_operations.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/setup.cfg +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/setup.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_augment_zoom.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_color_augmentations.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_crop.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_multithreaded_augmenter.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_normalizations.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_random_crop.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_sanity.py +0 -0
- {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_spatial_transformations.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: batchgenerators
|
|
3
|
-
Version: 0.25.2
|
|
3
|
+
Version: 0.25.3
|
|
4
4
|
Summary: Data augmentation toolkit
|
|
5
5
|
Author-email: "Division of Medical Image Computing, German Cancer Research Center AND Applied Computer Vision Lab, Helmholtz Imaging Platform" <f.isensee@dkfz-heidelberg.de>
|
|
6
6
|
License: Apache License
|
|
@@ -17,7 +17,7 @@ import traceback
|
|
|
17
17
|
from typing import List, Union
|
|
18
18
|
import threading
|
|
19
19
|
from multiprocessing import Process, Queue
|
|
20
|
-
from queue import Queue as thrQueue
|
|
20
|
+
from queue import Queue as thrQueue, Full, Empty
|
|
21
21
|
import numpy as np
|
|
22
22
|
import sys
|
|
23
23
|
import logging
|
|
@@ -25,38 +25,42 @@ from multiprocessing import Event
|
|
|
25
25
|
from time import sleep, time
|
|
26
26
|
from threadpoolctl import threadpool_limits
|
|
27
27
|
|
|
28
|
+
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import pin_memory_of_all_eligible_items_in_dict
|
|
29
|
+
|
|
28
30
|
try:
|
|
29
31
|
import torch
|
|
30
32
|
except ImportError:
|
|
31
33
|
torch = None
|
|
32
34
|
|
|
33
35
|
|
|
34
|
-
def producer(queue, data_loader, transform, thread_id, seed, abort_event,
|
|
36
|
+
def producer(queue, data_loader, transform, thread_id, seed, abort_event,
|
|
37
|
+
pause_event=None, wait_time: float = 0.02):
|
|
35
38
|
np.random.seed(seed)
|
|
36
39
|
data_loader.set_thread_id(thread_id)
|
|
37
40
|
item = None
|
|
38
41
|
|
|
39
42
|
try:
|
|
40
|
-
while
|
|
41
|
-
#
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
item =
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
43
|
+
while not abort_event.is_set():
|
|
44
|
+
# When paused, hold any in-flight item and stop interacting with the
|
|
45
|
+
# queue. This guarantees no new bytes hit the pipe after pause is
|
|
46
|
+
# observed, which is what _finish() relies on for safe shutdown.
|
|
47
|
+
if pause_event is not None and pause_event.is_set():
|
|
48
|
+
sleep(wait_time)
|
|
49
|
+
continue
|
|
50
|
+
|
|
51
|
+
if item is None:
|
|
52
|
+
try:
|
|
53
|
+
item = next(data_loader)
|
|
54
|
+
if transform is not None:
|
|
55
|
+
item = transform(**item)
|
|
56
|
+
except StopIteration:
|
|
57
|
+
item = "end"
|
|
58
|
+
|
|
59
|
+
try:
|
|
60
|
+
queue.put(item, timeout=wait_time)
|
|
61
|
+
item = None
|
|
62
|
+
except Full:
|
|
63
|
+
pass # loop; abort/pause are re-checked at top
|
|
60
64
|
except KeyboardInterrupt:
|
|
61
65
|
abort_event.set()
|
|
62
66
|
return
|
|
@@ -86,44 +90,41 @@ def results_loop(in_queues: List[Queue], out_queue: thrQueue, abort_event: Event
|
|
|
86
90
|
if abort_event.is_set():
|
|
87
91
|
return
|
|
88
92
|
|
|
89
|
-
#
|
|
90
|
-
|
|
93
|
+
# Check that all workers are still alive — but only when we haven't
|
|
94
|
+
# started a graceful shutdown. During shutdown workers exit cleanly,
|
|
95
|
+
# which would otherwise trip the RuntimeError below.
|
|
96
|
+
if not abort_event.is_set() and not all([i.is_alive() for i in worker_list]):
|
|
91
97
|
abort_event.set()
|
|
92
98
|
raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the print"
|
|
93
99
|
" statements above for the actual error message")
|
|
94
100
|
|
|
95
|
-
# if we don't have an item we need to fetch it first.
|
|
96
|
-
#
|
|
101
|
+
# if we don't have an item we need to fetch it first. Round-robin across the worker queues to keep the
|
|
102
|
+
# batch ordering deterministic; block on the current queue for up to wait_time so we don't busy-wait.
|
|
97
103
|
if item is None:
|
|
98
104
|
current_queue = in_queues[queue_ctr % len(in_queues)]
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
#
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
else:
|
|
117
|
-
sleep(wait_time)
|
|
118
|
-
continue
|
|
105
|
+
try:
|
|
106
|
+
item = current_queue.get(timeout=wait_time)
|
|
107
|
+
except Empty:
|
|
108
|
+
continue # retry the same queue; abort_event re-checked at top of loop
|
|
109
|
+
# if we do pin memory, do it now, otherwise skip this. The isinstance(dict) guard keeps the 'end'
|
|
110
|
+
# string sentinel from reaching the pinning logic.
|
|
111
|
+
if do_pin_memory:
|
|
112
|
+
if isinstance(item, dict):
|
|
113
|
+
item = pin_memory_of_all_eligible_items_in_dict(item)
|
|
114
|
+
queue_ctr += 1
|
|
115
|
+
|
|
116
|
+
if isinstance(item, str) and item == 'end':
|
|
117
|
+
end_ctr += 1
|
|
118
|
+
if end_ctr == len(in_queues):
|
|
119
|
+
end_ctr = 0
|
|
120
|
+
queue_ctr = 0
|
|
119
121
|
|
|
120
122
|
# we only arrive here if item is not None. Now put item in to the out_queue
|
|
121
|
-
|
|
122
|
-
out_queue.put(item)
|
|
123
|
+
try:
|
|
124
|
+
out_queue.put(item, timeout=wait_time)
|
|
123
125
|
item = None
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
continue
|
|
126
|
+
except Full:
|
|
127
|
+
continue # abort_event is re-checked at top of loop
|
|
127
128
|
except KeyboardInterrupt:
|
|
128
129
|
abort_event.set()
|
|
129
130
|
raise KeyboardInterrupt
|
|
@@ -171,6 +172,7 @@ class MultiThreadedAugmenter(object):
|
|
|
171
172
|
self.pin_memory_thread = None
|
|
172
173
|
self.pin_memory_queue = None
|
|
173
174
|
self.abort_event = Event()
|
|
175
|
+
self.pause_event = Event()
|
|
174
176
|
self.wait_time = wait_time
|
|
175
177
|
self.was_initialized = False
|
|
176
178
|
|
|
@@ -181,20 +183,16 @@ class MultiThreadedAugmenter(object):
|
|
|
181
183
|
return self.__next__()
|
|
182
184
|
|
|
183
185
|
def __get_next_item(self):
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
while item is None:
|
|
186
|
+
while True:
|
|
187
187
|
if self.abort_event.is_set():
|
|
188
188
|
self._finish()
|
|
189
189
|
raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
|
|
190
190
|
"print statements above for the actual error message")
|
|
191
191
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
return item
|
|
192
|
+
try:
|
|
193
|
+
return self.pin_memory_queue.get(timeout=self.wait_time)
|
|
194
|
+
except Empty:
|
|
195
|
+
continue # abort_event re-checked at top of loop
|
|
198
196
|
|
|
199
197
|
def __next__(self):
|
|
200
198
|
if not self.was_initialized:
|
|
@@ -225,6 +223,7 @@ class MultiThreadedAugmenter(object):
|
|
|
225
223
|
if not self.was_initialized:
|
|
226
224
|
self._finish()
|
|
227
225
|
self.abort_event.clear()
|
|
226
|
+
self.pause_event.clear()
|
|
228
227
|
|
|
229
228
|
logging.debug("starting workers")
|
|
230
229
|
self._queue_ctr = 0
|
|
@@ -237,7 +236,8 @@ class MultiThreadedAugmenter(object):
|
|
|
237
236
|
for i in range(self.num_processes):
|
|
238
237
|
self._queues.append(Queue(self.num_cached_per_queue))
|
|
239
238
|
self._processes.append(Process(target=producer, args=(
|
|
240
|
-
self._queues[i], self.generator, self.transform, i, self.seeds[i],
|
|
239
|
+
self._queues[i], self.generator, self.transform, i, self.seeds[i],
|
|
240
|
+
self.abort_event, self.pause_event, self.wait_time)))
|
|
241
241
|
self._processes[-1].daemon = True
|
|
242
242
|
self._processes[-1].start()
|
|
243
243
|
|
|
@@ -261,28 +261,86 @@ class MultiThreadedAugmenter(object):
|
|
|
261
261
|
logging.debug("MultiThreadedGenerator Warning: start() has been called but it has already been "
|
|
262
262
|
"initialized previously")
|
|
263
263
|
|
|
264
|
-
def _finish(self, timeout=10):
|
|
265
|
-
|
|
264
|
+
def _finish(self, timeout=10, force=False):
|
|
265
|
+
"""Shut down workers and the pin-memory thread.
|
|
266
266
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
267
|
+
Graceful path (force=False):
|
|
268
|
+
pause producers -> abort everyone -> join pin_memory_thread ->
|
|
269
|
+
drain pin_memory_queue -> drain _queues while joining workers ->
|
|
270
|
+
terminate stragglers -> final drain -> close queues.
|
|
270
271
|
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
272
|
+
Force path (force=True): skip the producer pause and rely on
|
|
273
|
+
abort + terminate fallback. Used by __del__ where the interpreter
|
|
274
|
+
may already be tearing down multiprocessing.
|
|
275
|
+
"""
|
|
276
|
+
if not self.was_initialized and len(self._processes) == 0:
|
|
277
|
+
return
|
|
274
278
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
279
|
+
# 1. Stop producers from initiating new puts. Existing in-flight puts
|
|
280
|
+
# are still allowed to land; we drain them below.
|
|
281
|
+
if not force and self.pause_event is not None:
|
|
282
|
+
self.pause_event.set()
|
|
278
283
|
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
+
# 2. Signal everyone to exit.
|
|
285
|
+
self.abort_event.set()
|
|
286
|
+
|
|
287
|
+
# 3. Join pin_memory_thread first so it stops refilling
|
|
288
|
+
# pin_memory_queue before we drain that queue.
|
|
289
|
+
if self.pin_memory_thread is not None:
|
|
290
|
+
self.pin_memory_thread.join(timeout=timeout)
|
|
291
|
+
|
|
292
|
+
# 4. Drain pin_memory_queue. Safe now: pin_memory_thread is dead and
|
|
293
|
+
# producers cannot reach it. Dropping items releases torch
|
|
294
|
+
# shared-memory refs via refcount.
|
|
295
|
+
if self.pin_memory_queue is not None:
|
|
296
|
+
while not self.pin_memory_queue.empty():
|
|
297
|
+
try:
|
|
298
|
+
self.pin_memory_queue.get_nowait()
|
|
299
|
+
except Exception:
|
|
300
|
+
break
|
|
301
|
+
|
|
302
|
+
# 5. Drain _queues in a loop while joining workers. Workers' Queue
|
|
303
|
+
# feeder threads need pipe space to flush before the worker
|
|
304
|
+
# process can fully exit; draining concurrently unblocks them.
|
|
305
|
+
if len(self._processes) > 0:
|
|
306
|
+
logging.debug("MultiThreadedGenerator: shutting down workers...")
|
|
307
|
+
deadline = time() + timeout
|
|
308
|
+
drain_tick = max(self.wait_time, 0.01)
|
|
309
|
+
while time() < deadline and any(p.is_alive() for p in self._processes):
|
|
310
|
+
for q in self._queues:
|
|
311
|
+
while not q.empty():
|
|
312
|
+
try:
|
|
313
|
+
q.get_nowait()
|
|
314
|
+
except Exception:
|
|
315
|
+
break
|
|
316
|
+
sleep(drain_tick)
|
|
317
|
+
|
|
318
|
+
# 6. Anyone still alive past the deadline gets terminated.
|
|
319
|
+
for p in self._processes:
|
|
320
|
+
if p.is_alive():
|
|
321
|
+
p.terminate()
|
|
322
|
+
p.join(timeout=1.0)
|
|
323
|
+
|
|
324
|
+
# 7. Final drain — catches residual bytes that feeder threads
|
|
325
|
+
# pushed onto the pipe during their own shutdown.
|
|
326
|
+
for q in self._queues:
|
|
327
|
+
while not q.empty():
|
|
328
|
+
try:
|
|
329
|
+
q.get_nowait()
|
|
330
|
+
except Exception:
|
|
331
|
+
break
|
|
284
332
|
|
|
285
|
-
|
|
333
|
+
# 8. Close queues.
|
|
334
|
+
for q in self._queues:
|
|
335
|
+
q.close()
|
|
336
|
+
q.join_thread()
|
|
337
|
+
|
|
338
|
+
self._queues = []
|
|
339
|
+
self._processes = []
|
|
340
|
+
self.pin_memory_queue = None
|
|
341
|
+
self.pin_memory_thread = None
|
|
342
|
+
self._end_ctr = 0
|
|
343
|
+
self._queue_ctr = 0
|
|
286
344
|
self.was_initialized = False
|
|
287
345
|
|
|
288
346
|
def restart(self):
|
|
@@ -291,4 +349,6 @@ class MultiThreadedAugmenter(object):
|
|
|
291
349
|
|
|
292
350
|
def __del__(self):
|
|
293
351
|
logging.debug("MultiThreadedGenerator: destructor was called")
|
|
294
|
-
|
|
352
|
+
# Interpreter shutdown may have already torn down parts of
|
|
353
|
+
# multiprocessing; take the fast path with a short timeout.
|
|
354
|
+
self._finish(timeout=2, force=True)
|
|
@@ -15,12 +15,12 @@
|
|
|
15
15
|
|
|
16
16
|
import traceback
|
|
17
17
|
from copy import deepcopy
|
|
18
|
-
from typing import List, Union
|
|
18
|
+
from typing import List, Union
|
|
19
19
|
import threading
|
|
20
20
|
from builtins import range
|
|
21
21
|
from multiprocessing import Process
|
|
22
22
|
from multiprocessing import Queue
|
|
23
|
-
from queue import Queue as thrQueue
|
|
23
|
+
from queue import Queue as thrQueue, Full, Empty
|
|
24
24
|
import numpy as np
|
|
25
25
|
import logging
|
|
26
26
|
from multiprocessing import Event
|
|
@@ -36,7 +36,7 @@ except ImportError:
|
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
def producer(queue: Queue, data_loader, transform, thread_id: int, seed,
|
|
39
|
-
abort_event: Event, wait_time: float = 0.02):
|
|
39
|
+
abort_event: Event, pause_event=None, wait_time: float = 0.02):
|
|
40
40
|
if torch is not None:
|
|
41
41
|
torch.set_num_threads(1)
|
|
42
42
|
if seed is not None:
|
|
@@ -50,24 +50,26 @@ def producer(queue: Queue, data_loader, transform, thread_id: int, seed,
|
|
|
50
50
|
item = None
|
|
51
51
|
|
|
52
52
|
try:
|
|
53
|
-
while
|
|
53
|
+
while not abort_event.is_set():
|
|
54
|
+
# When paused, stop producing and stop putting. This is the
|
|
55
|
+
# handshake _finish() uses to drain queues safely.
|
|
56
|
+
if pause_event is not None and pause_event.is_set():
|
|
57
|
+
sleep(wait_time)
|
|
58
|
+
continue
|
|
59
|
+
|
|
60
|
+
if item is None:
|
|
61
|
+
item = next(data_loader)
|
|
62
|
+
if transform is not None:
|
|
63
|
+
item = transform(**item)
|
|
54
64
|
|
|
55
65
|
if abort_event.is_set():
|
|
56
66
|
return
|
|
57
|
-
else:
|
|
58
|
-
if item is None:
|
|
59
|
-
item = next(data_loader)
|
|
60
|
-
if transform is not None:
|
|
61
|
-
item = transform(**item)
|
|
62
|
-
|
|
63
|
-
if abort_event.is_set():
|
|
64
|
-
return
|
|
65
67
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
68
|
+
try:
|
|
69
|
+
queue.put(item, timeout=wait_time)
|
|
70
|
+
item = None
|
|
71
|
+
except Full:
|
|
72
|
+
pass # loop; abort/pause re-checked at top
|
|
71
73
|
|
|
72
74
|
except KeyboardInterrupt:
|
|
73
75
|
abort_event.set()
|
|
@@ -80,10 +82,29 @@ def producer(queue: Queue, data_loader, transform, thread_id: int, seed,
|
|
|
80
82
|
return
|
|
81
83
|
|
|
82
84
|
|
|
85
|
+
def _pin_memory_recursive(item):
|
|
86
|
+
# Pin every torch.Tensor reachable through nested dicts/lists/tuples. This matters for things like deep
|
|
87
|
+
# supervision targets, which are lists of tensors: without recursing into the list those tensors stay unpinned
|
|
88
|
+
# and the subsequent .to(device, non_blocking=True) silently falls back to a blocking copy.
|
|
89
|
+
if isinstance(item, torch.Tensor):
|
|
90
|
+
return item.pin_memory()
|
|
91
|
+
elif isinstance(item, dict):
|
|
92
|
+
for k in item.keys():
|
|
93
|
+
item[k] = _pin_memory_recursive(item[k])
|
|
94
|
+
return item
|
|
95
|
+
elif isinstance(item, list):
|
|
96
|
+
for i in range(len(item)):
|
|
97
|
+
item[i] = _pin_memory_recursive(item[i])
|
|
98
|
+
return item
|
|
99
|
+
elif isinstance(item, tuple):
|
|
100
|
+
return tuple(_pin_memory_recursive(i) for i in item)
|
|
101
|
+
else:
|
|
102
|
+
return item
|
|
103
|
+
|
|
104
|
+
|
|
83
105
|
def pin_memory_of_all_eligible_items_in_dict(result_dict):
|
|
84
106
|
for k in result_dict.keys():
|
|
85
|
-
|
|
86
|
-
result_dict[k] = result_dict[k].pin_memory()
|
|
107
|
+
result_dict[k] = _pin_memory_recursive(result_dict[k])
|
|
87
108
|
return result_dict
|
|
88
109
|
|
|
89
110
|
|
|
@@ -103,28 +124,28 @@ def results_loop(in_queue: Queue, out_queue: thrQueue, abort_event: Event,
|
|
|
103
124
|
if abort_event.is_set():
|
|
104
125
|
return
|
|
105
126
|
|
|
106
|
-
# check
|
|
107
|
-
|
|
127
|
+
# Check workers, but skip the check once a graceful shutdown is in
|
|
128
|
+
# progress — workers exit cleanly then and would otherwise trip
|
|
129
|
+
# this RuntimeError.
|
|
130
|
+
if not abort_event.is_set() and not all([i.is_alive() for i in worker_list]):
|
|
108
131
|
abort_event.set()
|
|
109
132
|
raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
|
|
110
133
|
"print statements above for the actual error message")
|
|
111
134
|
|
|
112
135
|
if item is None:
|
|
113
|
-
|
|
114
|
-
item = in_queue.get()
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
continue
|
|
136
|
+
try:
|
|
137
|
+
item = in_queue.get(timeout=wait_time)
|
|
138
|
+
except Empty:
|
|
139
|
+
continue # abort_event/worker liveness re-checked at top of loop
|
|
140
|
+
if do_pin_memory:
|
|
141
|
+
item = pin_memory_of_all_eligible_items_in_dict(item)
|
|
120
142
|
|
|
121
143
|
# we only arrive here if item is not None. Now put item in to the out_queue
|
|
122
|
-
|
|
123
|
-
out_queue.put(item)
|
|
144
|
+
try:
|
|
145
|
+
out_queue.put(item, timeout=wait_time)
|
|
124
146
|
item = None
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
continue
|
|
147
|
+
except Full:
|
|
148
|
+
continue # abort_event is re-checked at top of loop
|
|
128
149
|
|
|
129
150
|
except Exception as e:
|
|
130
151
|
abort_event.set()
|
|
@@ -145,7 +166,7 @@ class NonDetMultiThreadedAugmenter(object):
|
|
|
145
166
|
"""
|
|
146
167
|
|
|
147
168
|
def __init__(self, data_loader, transform, num_processes, num_cached=2, seeds=None, pin_memory=False,
|
|
148
|
-
wait_time=0.02
|
|
169
|
+
wait_time=0.02):
|
|
149
170
|
self.pin_memory = pin_memory
|
|
150
171
|
self.transform = transform
|
|
151
172
|
self.num_cached = num_cached
|
|
@@ -157,10 +178,10 @@ class NonDetMultiThreadedAugmenter(object):
|
|
|
157
178
|
|
|
158
179
|
self._queue = None
|
|
159
180
|
self._processes = []
|
|
160
|
-
self.results_loop_fn = results_loop
|
|
161
181
|
self.results_loop_thread = None
|
|
162
182
|
self.results_loop_queue = None
|
|
163
183
|
self.abort_event = None
|
|
184
|
+
self.pause_event = None
|
|
164
185
|
self.initialized = False
|
|
165
186
|
|
|
166
187
|
self.wait_time = wait_time
|
|
@@ -178,23 +199,19 @@ class NonDetMultiThreadedAugmenter(object):
|
|
|
178
199
|
return self.__next__()
|
|
179
200
|
|
|
180
201
|
def __get_next_item(self):
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
while item is None:
|
|
184
|
-
#
|
|
202
|
+
while True:
|
|
185
203
|
if self.abort_event.is_set():
|
|
186
|
-
#
|
|
204
|
+
# the results loop checks for dead workers and will set the abort event if necessary
|
|
187
205
|
self._finish()
|
|
188
206
|
raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
|
|
189
207
|
"print statements above for the actual error message")
|
|
190
208
|
|
|
191
|
-
|
|
192
|
-
item = self.results_loop_queue.get()
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
return item
|
|
209
|
+
try:
|
|
210
|
+
item = self.results_loop_queue.get(timeout=self.wait_time)
|
|
211
|
+
except Empty:
|
|
212
|
+
continue # abort_event re-checked at top of loop
|
|
213
|
+
self.results_loop_queue.task_done()
|
|
214
|
+
return item
|
|
198
215
|
|
|
199
216
|
def __next__(self):
|
|
200
217
|
if not self.initialized:
|
|
@@ -210,6 +227,7 @@ class NonDetMultiThreadedAugmenter(object):
|
|
|
210
227
|
self._queue = Queue(self.num_cached)
|
|
211
228
|
self.results_loop_queue = thrQueue(self.num_cached)
|
|
212
229
|
self.abort_event = Event()
|
|
230
|
+
self.pause_event = Event()
|
|
213
231
|
|
|
214
232
|
logging.debug("starting workers")
|
|
215
233
|
if isinstance(self.generator, DataLoader):
|
|
@@ -218,7 +236,8 @@ class NonDetMultiThreadedAugmenter(object):
|
|
|
218
236
|
|
|
219
237
|
for i in range(self.num_processes):
|
|
220
238
|
self._processes.append(Process(target=producer, args=(
|
|
221
|
-
self._queue, self.generator, self.transform, i, self.seeds[i],
|
|
239
|
+
self._queue, self.generator, self.transform, i, self.seeds[i],
|
|
240
|
+
self.abort_event, self.pause_event, self.wait_time
|
|
222
241
|
)))
|
|
223
242
|
self._processes[-1].daemon = True
|
|
224
243
|
_ = [i.start() for i in self._processes]
|
|
@@ -230,7 +249,7 @@ class NonDetMultiThreadedAugmenter(object):
|
|
|
230
249
|
|
|
231
250
|
# in_queue: Queue, out_queue: thrQueue, abort_event: Event, pin_memory: bool, worker_list: List[Process],
|
|
232
251
|
# gpu: Union[int, None] = None, wait_time: float = 0.02
|
|
233
|
-
self.results_loop_thread = threading.Thread(target=
|
|
252
|
+
self.results_loop_thread = threading.Thread(target=results_loop, args=(
|
|
234
253
|
self._queue, self.results_loop_queue, self.abort_event, self.pin_memory, self._processes, gpu,
|
|
235
254
|
self.wait_time)
|
|
236
255
|
)
|
|
@@ -241,14 +260,70 @@ class NonDetMultiThreadedAugmenter(object):
|
|
|
241
260
|
else:
|
|
242
261
|
logging.debug("MultiThreadedGenerator Warning: start() has been called but workers are already running")
|
|
243
262
|
|
|
244
|
-
def _finish(self):
|
|
245
|
-
|
|
263
|
+
def _finish(self, timeout=10, force=False):
|
|
264
|
+
"""Graceful shutdown — same pause-drain-exit handshake as MTA.
|
|
265
|
+
|
|
266
|
+
Single shared in-queue (self._queue), so the drain logic is simpler
|
|
267
|
+
than the per-worker variant.
|
|
268
|
+
"""
|
|
269
|
+
if not self.initialized and len(self._processes) == 0:
|
|
270
|
+
return
|
|
271
|
+
|
|
272
|
+
# 1. Pause producers so no new bytes hit the pipe.
|
|
273
|
+
if not force and self.pause_event is not None:
|
|
274
|
+
self.pause_event.set()
|
|
275
|
+
|
|
276
|
+
# 2. Tell everyone to exit.
|
|
277
|
+
if self.abort_event is not None:
|
|
246
278
|
self.abort_event.set()
|
|
247
|
-
sleep(self.wait_time)
|
|
248
|
-
[i.terminate() for i in self._processes if i.is_alive()]
|
|
249
279
|
|
|
250
|
-
|
|
251
|
-
|
|
280
|
+
# 3. Join results_loop_thread before draining its output queue.
|
|
281
|
+
if self.results_loop_thread is not None:
|
|
282
|
+
self.results_loop_thread.join(timeout=timeout)
|
|
283
|
+
|
|
284
|
+
# 4. Drain results_loop_queue (no writer now).
|
|
285
|
+
if self.results_loop_queue is not None:
|
|
286
|
+
while not self.results_loop_queue.empty():
|
|
287
|
+
try:
|
|
288
|
+
self.results_loop_queue.get_nowait()
|
|
289
|
+
except Exception:
|
|
290
|
+
break
|
|
291
|
+
|
|
292
|
+
# 5. Drain the shared mp.Queue while workers exit, so their feeder
|
|
293
|
+
# threads can flush and the worker processes can terminate.
|
|
294
|
+
if len(self._processes) > 0 and self._queue is not None:
|
|
295
|
+
deadline = time() + timeout
|
|
296
|
+
drain_tick = max(self.wait_time, 0.01)
|
|
297
|
+
while time() < deadline and any(p.is_alive() for p in self._processes):
|
|
298
|
+
while not self._queue.empty():
|
|
299
|
+
try:
|
|
300
|
+
self._queue.get_nowait()
|
|
301
|
+
except Exception:
|
|
302
|
+
break
|
|
303
|
+
sleep(drain_tick)
|
|
304
|
+
|
|
305
|
+
# 6. Terminate stragglers.
|
|
306
|
+
for p in self._processes:
|
|
307
|
+
if p.is_alive():
|
|
308
|
+
p.terminate()
|
|
309
|
+
p.join(timeout=1.0)
|
|
310
|
+
|
|
311
|
+
# 7. Final drain.
|
|
312
|
+
while not self._queue.empty():
|
|
313
|
+
try:
|
|
314
|
+
self._queue.get_nowait()
|
|
315
|
+
except Exception:
|
|
316
|
+
break
|
|
317
|
+
|
|
318
|
+
# 8. Close the queue.
|
|
319
|
+
self._queue.close()
|
|
320
|
+
self._queue.join_thread()
|
|
321
|
+
|
|
322
|
+
self._queue = None
|
|
323
|
+
self.results_loop_queue = None
|
|
324
|
+
self.results_loop_thread = None
|
|
325
|
+
self.abort_event = None
|
|
326
|
+
self.pause_event = None
|
|
252
327
|
self._processes = []
|
|
253
328
|
self.initialized = False
|
|
254
329
|
|
|
@@ -258,7 +333,7 @@ class NonDetMultiThreadedAugmenter(object):
|
|
|
258
333
|
|
|
259
334
|
def __del__(self):
|
|
260
335
|
logging.debug("MultiThreadedGenerator: destructor was called")
|
|
261
|
-
self._finish()
|
|
336
|
+
self._finish(timeout=2, force=True)
|
|
262
337
|
|
|
263
338
|
|
|
264
339
|
if __name__ == '__main__':
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: batchgenerators
|
|
3
|
-
Version: 0.25.2
|
|
3
|
+
Version: 0.25.3
|
|
4
4
|
Summary: Data augmentation toolkit
|
|
5
5
|
Author-email: "Division of Medical Image Computing, German Cancer Research Center AND Applied Computer Vision Lab, Helmholtz Imaging Platform" <f.isensee@dkfz-heidelberg.de>
|
|
6
6
|
License: Apache License
|
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import unittest
|
|
17
|
-
import unittest2
|
|
18
17
|
import numpy as np
|
|
19
18
|
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
|
|
20
19
|
from skimage import data
|
|
@@ -23,7 +22,7 @@ from tests.DataGenerators import BasicDataLoader
|
|
|
23
22
|
from batchgenerators.transforms.spatial_transforms import MirrorTransform
|
|
24
23
|
|
|
25
24
|
|
|
26
|
-
class TestMirrorAxis(
|
|
25
|
+
class TestMirrorAxis(unittest.TestCase):
|
|
27
26
|
def setUp(self):
|
|
28
27
|
self.seed = 1234
|
|
29
28
|
|
|
@@ -25,8 +25,8 @@ class TestAugmentResample(unittest.TestCase):
|
|
|
25
25
|
self.data_3D = np.random.random((2, 64, 56, 48))
|
|
26
26
|
self.data_2D = np.random.random((2, 64, 56))
|
|
27
27
|
|
|
28
|
-
self.data_3D_unique = np.reshape(range(2 * 64 * 56 * 48),
|
|
29
|
-
self.data_2D_unique = np.reshape(range(2 * 64 * 56),
|
|
28
|
+
self.data_3D_unique = np.reshape(range(2 * 64 * 56 * 48), shape=(2, 64, 56, 48))
|
|
29
|
+
self.data_2D_unique = np.reshape(range(2 * 64 * 56), shape=(2, 64, 56))
|
|
30
30
|
|
|
31
31
|
self.d_3D = augment_linear_downsampling_scipy(np.copy(self.data_3D), zoom_range=[0.5, 1.5], per_channel=False)
|
|
32
32
|
self.d_2D = augment_linear_downsampling_scipy(np.copy(self.data_2D), zoom_range=[0.5, 1.5], per_channel=False)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/normalizations.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/data_loader.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/brats2017/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/brats2017/config.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/abstract_transforms.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/color_transforms.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/local_transforms.py
RENAMED
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/noise_transforms.py
RENAMED
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/resample_transforms.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/spatial_transforms.py
RENAMED
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/utility_transforms.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/utilities/data_splitting.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|