batchgenerators 0.25.2__tar.gz → 0.25.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/PKG-INFO +1 -1
  2. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/multi_threaded_augmenter.py +139 -79
  3. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py +131 -56
  4. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators.egg-info/PKG-INFO +1 -1
  5. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators.egg-info/SOURCES.txt +0 -0
  6. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators.egg-info/dependency_links.txt +0 -0
  7. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators.egg-info/requires.txt +0 -0
  8. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators.egg-info/top_level.txt +0 -0
  9. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/pyproject.toml +1 -1
  10. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_DataLoader.py +1 -1
  11. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_axis_mirroring.py +1 -2
  12. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_resample_augmentations.py +2 -2
  13. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/LICENSE +0 -0
  14. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/README.md +0 -0
  15. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/__init__.py +0 -0
  16. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/__init__.py +0 -0
  17. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/color_augmentations.py +0 -0
  18. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/crop_and_pad_augmentations.py +0 -0
  19. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/noise_augmentations.py +0 -0
  20. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/normalizations.py +0 -0
  21. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/resample_augmentations.py +0 -0
  22. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/spatial_transformations.py +0 -0
  23. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/augmentations/utils.py +0 -0
  24. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/__init__.py +0 -0
  25. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/data_loader.py +0 -0
  26. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/dataset.py +0 -0
  27. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/dataloading/single_threaded_augmenter.py +0 -0
  28. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/datasets/__init__.py +0 -0
  29. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/datasets/cifar.py +0 -0
  30. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/__init__.py +0 -0
  31. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/brats2017/__init__.py +0 -0
  32. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/brats2017/brats2017_dataloader_2D.py +0 -0
  33. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py +0 -0
  34. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/brats2017/brats2017_preprocessing.py +0 -0
  35. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/brats2017/config.py +0 -0
  36. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/cifar10.py +0 -0
  37. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/examples/multithreaded_dataloading.py +0 -0
  38. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/__init__.py +0 -0
  39. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/abstract_transforms.py +0 -0
  40. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/channel_selection_transforms.py +0 -0
  41. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/color_transforms.py +0 -0
  42. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/crop_and_pad_transforms.py +0 -0
  43. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/local_transforms.py +0 -0
  44. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/noise_transforms.py +0 -0
  45. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/resample_transforms.py +0 -0
  46. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/sample_normalization_transforms.py +0 -0
  47. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/spatial_transforms.py +0 -0
  48. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/transforms/utility_transforms.py +0 -0
  49. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/utilities/__init__.py +0 -0
  50. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/utilities/custom_types.py +0 -0
  51. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/utilities/data_splitting.py +0 -0
  52. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/batchgenerators/utilities/file_and_folder_operations.py +0 -0
  53. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/setup.cfg +0 -0
  54. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/setup.py +0 -0
  55. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_augment_zoom.py +0 -0
  56. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_color_augmentations.py +0 -0
  57. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_crop.py +0 -0
  58. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_multithreaded_augmenter.py +0 -0
  59. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_normalizations.py +0 -0
  60. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_random_crop.py +0 -0
  61. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_sanity.py +0 -0
  62. {batchgenerators-0.25.2 → batchgenerators-0.25.3}/tests/test_spatial_transformations.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: batchgenerators
3
- Version: 0.25.2
3
+ Version: 0.25.3
4
4
  Summary: Data augmentation toolkit
5
5
  Author-email: "Division of Medical Image Computing, German Cancer Research Center AND Applied Computer Vision Lab, Helmholtz Imaging Platform" <f.isensee@dkfz-heidelberg.de>
6
6
  License: Apache License
@@ -17,7 +17,7 @@ import traceback
17
17
  from typing import List, Union
18
18
  import threading
19
19
  from multiprocessing import Process, Queue
20
- from queue import Queue as thrQueue
20
+ from queue import Queue as thrQueue, Full, Empty
21
21
  import numpy as np
22
22
  import sys
23
23
  import logging
@@ -25,38 +25,42 @@ from multiprocessing import Event
25
25
  from time import sleep, time
26
26
  from threadpoolctl import threadpool_limits
27
27
 
28
+ from batchgenerators.dataloading.nondet_multi_threaded_augmenter import pin_memory_of_all_eligible_items_in_dict
29
+
28
30
  try:
29
31
  import torch
30
32
  except ImportError:
31
33
  torch = None
32
34
 
33
35
 
34
- def producer(queue, data_loader, transform, thread_id, seed, abort_event, wait_time: float = 0.02):
36
+ def producer(queue, data_loader, transform, thread_id, seed, abort_event,
37
+ pause_event=None, wait_time: float = 0.02):
35
38
  np.random.seed(seed)
36
39
  data_loader.set_thread_id(thread_id)
37
40
  item = None
38
41
 
39
42
  try:
40
- while True:
41
- # check if abort event was set
42
- if not abort_event.is_set():
43
- # print("worker %d event not set" % thread_id)
44
- if item is None:
45
- try:
46
- item = next(data_loader)
47
- if transform is not None:
48
- item = transform(**item)
49
- except StopIteration:
50
- item = "end"
51
-
52
- if not queue.full():
53
- queue.put(item)
54
- item = None
55
- else:
56
- sleep(wait_time)
57
- else:
58
- # print("worder %d event is now set, exiting" % thread_id)
59
- return
43
+ while not abort_event.is_set():
44
+ # When paused, hold any in-flight item and stop interacting with the
45
+ # queue. This guarantees no new bytes hit the pipe after pause is
46
+ # observed, which is what _finish() relies on for safe shutdown.
47
+ if pause_event is not None and pause_event.is_set():
48
+ sleep(wait_time)
49
+ continue
50
+
51
+ if item is None:
52
+ try:
53
+ item = next(data_loader)
54
+ if transform is not None:
55
+ item = transform(**item)
56
+ except StopIteration:
57
+ item = "end"
58
+
59
+ try:
60
+ queue.put(item, timeout=wait_time)
61
+ item = None
62
+ except Full:
63
+ pass # loop; abort/pause are re-checked at top
60
64
  except KeyboardInterrupt:
61
65
  abort_event.set()
62
66
  return
@@ -86,44 +90,41 @@ def results_loop(in_queues: List[Queue], out_queue: thrQueue, abort_event: Event
86
90
  if abort_event.is_set():
87
91
  return
88
92
 
89
- # check if all workers are still alive
90
- if not all([i.is_alive() for i in worker_list]):
93
+ # Check that all workers are still alive — but only when we haven't
94
+ # started a graceful shutdown. During shutdown workers exit cleanly,
95
+ # which would otherwise trip the RuntimeError below.
96
+ if not abort_event.is_set() and not all([i.is_alive() for i in worker_list]):
91
97
  abort_event.set()
92
98
  raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the print"
93
99
  " statements above for the actual error message")
94
100
 
95
- # if we don't have an item we need to fetch it first. If the queue we want to get it from it empty, try
96
- # again later
101
+ # if we don't have an item we need to fetch it first. Round-robin across the worker queues to keep the
102
+ # batch ordering deterministic; block on the current queue for up to wait_time so we don't busy-wait.
97
103
  if item is None:
98
104
  current_queue = in_queues[queue_ctr % len(in_queues)]
99
- if not current_queue.empty():
100
- # get the item
101
- item = current_queue.get()
102
- # if we do pin memory, do it now, otherwise skip this
103
- if do_pin_memory:
104
- if isinstance(item, dict):
105
- for k in item.keys():
106
- if isinstance(item[k], torch.Tensor):
107
- item[k] = item[k].pin_memory()
108
- queue_ctr += 1
109
-
110
- if isinstance(item, str) and item == 'end':
111
- end_ctr += 1
112
- if end_ctr == len(in_queues):
113
- end_ctr = 0
114
- queue_ctr = 0
115
-
116
- else:
117
- sleep(wait_time)
118
- continue
105
+ try:
106
+ item = current_queue.get(timeout=wait_time)
107
+ except Empty:
108
+ continue # retry the same queue; abort_event re-checked at top of loop
109
+ # if we do pin memory, do it now, otherwise skip this. The isinstance(dict) guard keeps the 'end'
110
+ # string sentinel from reaching the pinning logic.
111
+ if do_pin_memory:
112
+ if isinstance(item, dict):
113
+ item = pin_memory_of_all_eligible_items_in_dict(item)
114
+ queue_ctr += 1
115
+
116
+ if isinstance(item, str) and item == 'end':
117
+ end_ctr += 1
118
+ if end_ctr == len(in_queues):
119
+ end_ctr = 0
120
+ queue_ctr = 0
119
121
 
120
122
  # we only arrive here if item is not None. Now put item in to the out_queue
121
- if not out_queue.full():
122
- out_queue.put(item)
123
+ try:
124
+ out_queue.put(item, timeout=wait_time)
123
125
  item = None
124
- else:
125
- sleep(wait_time)
126
- continue
126
+ except Full:
127
+ continue # abort_event is re-checked at top of loop
127
128
  except KeyboardInterrupt:
128
129
  abort_event.set()
129
130
  raise KeyboardInterrupt
@@ -171,6 +172,7 @@ class MultiThreadedAugmenter(object):
171
172
  self.pin_memory_thread = None
172
173
  self.pin_memory_queue = None
173
174
  self.abort_event = Event()
175
+ self.pause_event = Event()
174
176
  self.wait_time = wait_time
175
177
  self.was_initialized = False
176
178
 
@@ -181,20 +183,16 @@ class MultiThreadedAugmenter(object):
181
183
  return self.__next__()
182
184
 
183
185
  def __get_next_item(self):
184
- item = None
185
-
186
- while item is None:
186
+ while True:
187
187
  if self.abort_event.is_set():
188
188
  self._finish()
189
189
  raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
190
190
  "print statements above for the actual error message")
191
191
 
192
- if not self.pin_memory_queue.empty():
193
- item = self.pin_memory_queue.get()
194
- else:
195
- sleep(self.wait_time)
196
-
197
- return item
192
+ try:
193
+ return self.pin_memory_queue.get(timeout=self.wait_time)
194
+ except Empty:
195
+ continue # abort_event re-checked at top of loop
198
196
 
199
197
  def __next__(self):
200
198
  if not self.was_initialized:
@@ -225,6 +223,7 @@ class MultiThreadedAugmenter(object):
225
223
  if not self.was_initialized:
226
224
  self._finish()
227
225
  self.abort_event.clear()
226
+ self.pause_event.clear()
228
227
 
229
228
  logging.debug("starting workers")
230
229
  self._queue_ctr = 0
@@ -237,7 +236,8 @@ class MultiThreadedAugmenter(object):
237
236
  for i in range(self.num_processes):
238
237
  self._queues.append(Queue(self.num_cached_per_queue))
239
238
  self._processes.append(Process(target=producer, args=(
240
- self._queues[i], self.generator, self.transform, i, self.seeds[i], self.abort_event)))
239
+ self._queues[i], self.generator, self.transform, i, self.seeds[i],
240
+ self.abort_event, self.pause_event, self.wait_time)))
241
241
  self._processes[-1].daemon = True
242
242
  self._processes[-1].start()
243
243
 
@@ -261,28 +261,86 @@ class MultiThreadedAugmenter(object):
261
261
  logging.debug("MultiThreadedGenerator Warning: start() has been called but it has already been "
262
262
  "initialized previously")
263
263
 
264
- def _finish(self, timeout=10):
265
- self.abort_event.set()
264
+ def _finish(self, timeout=10, force=False):
265
+ """Shut down workers and the pin-memory thread.
266
266
 
267
- start = time()
268
- while self.pin_memory_thread is not None and self.pin_memory_thread.is_alive() and start + timeout > time():
269
- sleep(0.2)
267
+ Graceful path (force=False):
268
+ pause producers -> abort everyone -> join pin_memory_thread ->
269
+ drain pin_memory_queue -> drain _queues while joining workers ->
270
+ terminate stragglers -> final drain -> close queues.
270
271
 
271
- if len(self._processes) != 0:
272
- logging.debug("MultiThreadedGenerator: shutting down workers...")
273
- [i.terminate() for i in self._processes]
272
+ Force path (force=True): skip the producer pause and rely on
273
+ abort + terminate fallback. Used by __del__ where the interpreter
274
+ may already be tearing down multiprocessing.
275
+ """
276
+ if not self.was_initialized and len(self._processes) == 0:
277
+ return
274
278
 
275
- for i, p in enumerate(self._processes):
276
- self._queues[i].close()
277
- self._queues[i].join_thread()
279
+ # 1. Stop producers from initiating new puts. Existing in-flight puts
280
+ # are still allowed to land; we drain them below.
281
+ if not force and self.pause_event is not None:
282
+ self.pause_event.set()
278
283
 
279
- self._queues = []
280
- self._processes = []
281
- self._queue = None
282
- self._end_ctr = 0
283
- self._queue_ctr = 0
284
+ # 2. Signal everyone to exit.
285
+ self.abort_event.set()
286
+
287
+ # 3. Join pin_memory_thread first so it stops refilling
288
+ # pin_memory_queue before we drain that queue.
289
+ if self.pin_memory_thread is not None:
290
+ self.pin_memory_thread.join(timeout=timeout)
291
+
292
+ # 4. Drain pin_memory_queue. Safe now: pin_memory_thread is dead and
293
+ # producers cannot reach it. Dropping items releases torch
294
+ # shared-memory refs via refcount.
295
+ if self.pin_memory_queue is not None:
296
+ while not self.pin_memory_queue.empty():
297
+ try:
298
+ self.pin_memory_queue.get_nowait()
299
+ except Exception:
300
+ break
301
+
302
+ # 5. Drain _queues in a loop while joining workers. Workers' Queue
303
+ # feeder threads need pipe space to flush before the worker
304
+ # process can fully exit; draining concurrently unblocks them.
305
+ if len(self._processes) > 0:
306
+ logging.debug("MultiThreadedGenerator: shutting down workers...")
307
+ deadline = time() + timeout
308
+ drain_tick = max(self.wait_time, 0.01)
309
+ while time() < deadline and any(p.is_alive() for p in self._processes):
310
+ for q in self._queues:
311
+ while not q.empty():
312
+ try:
313
+ q.get_nowait()
314
+ except Exception:
315
+ break
316
+ sleep(drain_tick)
317
+
318
+ # 6. Anyone still alive past the deadline gets terminated.
319
+ for p in self._processes:
320
+ if p.is_alive():
321
+ p.terminate()
322
+ p.join(timeout=1.0)
323
+
324
+ # 7. Final drain — catches residual bytes that feeder threads
325
+ # pushed onto the pipe during their own shutdown.
326
+ for q in self._queues:
327
+ while not q.empty():
328
+ try:
329
+ q.get_nowait()
330
+ except Exception:
331
+ break
284
332
 
285
- del self.pin_memory_queue
333
+ # 8. Close queues.
334
+ for q in self._queues:
335
+ q.close()
336
+ q.join_thread()
337
+
338
+ self._queues = []
339
+ self._processes = []
340
+ self.pin_memory_queue = None
341
+ self.pin_memory_thread = None
342
+ self._end_ctr = 0
343
+ self._queue_ctr = 0
286
344
  self.was_initialized = False
287
345
 
288
346
  def restart(self):
@@ -291,4 +349,6 @@ class MultiThreadedAugmenter(object):
291
349
 
292
350
  def __del__(self):
293
351
  logging.debug("MultiThreadedGenerator: destructor was called")
294
- self._finish()
352
+ # Interpreter shutdown may have already torn down parts of
353
+ # multiprocessing; take the fast path with a short timeout.
354
+ self._finish(timeout=2, force=True)
@@ -15,12 +15,12 @@
15
15
 
16
16
  import traceback
17
17
  from copy import deepcopy
18
- from typing import List, Union, Callable
18
+ from typing import List, Union
19
19
  import threading
20
20
  from builtins import range
21
21
  from multiprocessing import Process
22
22
  from multiprocessing import Queue
23
- from queue import Queue as thrQueue
23
+ from queue import Queue as thrQueue, Full, Empty
24
24
  import numpy as np
25
25
  import logging
26
26
  from multiprocessing import Event
@@ -36,7 +36,7 @@ except ImportError:
36
36
 
37
37
 
38
38
  def producer(queue: Queue, data_loader, transform, thread_id: int, seed,
39
- abort_event: Event, wait_time: float = 0.02):
39
+ abort_event: Event, pause_event=None, wait_time: float = 0.02):
40
40
  if torch is not None:
41
41
  torch.set_num_threads(1)
42
42
  if seed is not None:
@@ -50,24 +50,26 @@ def producer(queue: Queue, data_loader, transform, thread_id: int, seed,
50
50
  item = None
51
51
 
52
52
  try:
53
- while True:
53
+ while not abort_event.is_set():
54
+ # When paused, stop producing and stop putting. This is the
55
+ # handshake _finish() uses to drain queues safely.
56
+ if pause_event is not None and pause_event.is_set():
57
+ sleep(wait_time)
58
+ continue
59
+
60
+ if item is None:
61
+ item = next(data_loader)
62
+ if transform is not None:
63
+ item = transform(**item)
54
64
 
55
65
  if abort_event.is_set():
56
66
  return
57
- else:
58
- if item is None:
59
- item = next(data_loader)
60
- if transform is not None:
61
- item = transform(**item)
62
-
63
- if abort_event.is_set():
64
- return
65
67
 
66
- if not queue.full():
67
- queue.put(item)
68
- item = None
69
- else:
70
- sleep(wait_time)
68
+ try:
69
+ queue.put(item, timeout=wait_time)
70
+ item = None
71
+ except Full:
72
+ pass # loop; abort/pause re-checked at top
71
73
 
72
74
  except KeyboardInterrupt:
73
75
  abort_event.set()
@@ -80,10 +82,29 @@ def producer(queue: Queue, data_loader, transform, thread_id: int, seed,
80
82
  return
81
83
 
82
84
 
85
+ def _pin_memory_recursive(item):
86
+ # Pin every torch.Tensor reachable through nested dicts/lists/tuples. This matters for things like deep
87
+ # supervision targets, which are lists of tensors: without recursing into the list those tensors stay unpinned
88
+ # and the subsequent .to(device, non_blocking=True) silently falls back to a blocking copy.
89
+ if isinstance(item, torch.Tensor):
90
+ return item.pin_memory()
91
+ elif isinstance(item, dict):
92
+ for k in item.keys():
93
+ item[k] = _pin_memory_recursive(item[k])
94
+ return item
95
+ elif isinstance(item, list):
96
+ for i in range(len(item)):
97
+ item[i] = _pin_memory_recursive(item[i])
98
+ return item
99
+ elif isinstance(item, tuple):
100
+ return tuple(_pin_memory_recursive(i) for i in item)
101
+ else:
102
+ return item
103
+
104
+
83
105
  def pin_memory_of_all_eligible_items_in_dict(result_dict):
84
106
  for k in result_dict.keys():
85
- if isinstance(result_dict[k], torch.Tensor):
86
- result_dict[k] = result_dict[k].pin_memory()
107
+ result_dict[k] = _pin_memory_recursive(result_dict[k])
87
108
  return result_dict
88
109
 
89
110
 
@@ -103,28 +124,28 @@ def results_loop(in_queue: Queue, out_queue: thrQueue, abort_event: Event,
103
124
  if abort_event.is_set():
104
125
  return
105
126
 
106
- # check if all workers are still alive
107
- if not all([i.is_alive() for i in worker_list]):
127
+ # Check workers, but skip the check once a graceful shutdown is in
128
+ # progress — workers exit cleanly then and would otherwise trip
129
+ # this RuntimeError.
130
+ if not abort_event.is_set() and not all([i.is_alive() for i in worker_list]):
108
131
  abort_event.set()
109
132
  raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
110
133
  "print statements above for the actual error message")
111
134
 
112
135
  if item is None:
113
- if not in_queue.empty():
114
- item = in_queue.get()
115
- if do_pin_memory:
116
- item = pin_memory_of_all_eligible_items_in_dict(item)
117
- else:
118
- sleep(wait_time)
119
- continue
136
+ try:
137
+ item = in_queue.get(timeout=wait_time)
138
+ except Empty:
139
+ continue # abort_event/worker liveness re-checked at top of loop
140
+ if do_pin_memory:
141
+ item = pin_memory_of_all_eligible_items_in_dict(item)
120
142
 
121
143
  # we only arrive here if item is not None. Now put item in to the out_queue
122
- if not out_queue.full():
123
- out_queue.put(item)
144
+ try:
145
+ out_queue.put(item, timeout=wait_time)
124
146
  item = None
125
- else:
126
- sleep(wait_time)
127
- continue
147
+ except Full:
148
+ continue # abort_event is re-checked at top of loop
128
149
 
129
150
  except Exception as e:
130
151
  abort_event.set()
@@ -145,7 +166,7 @@ class NonDetMultiThreadedAugmenter(object):
145
166
  """
146
167
 
147
168
  def __init__(self, data_loader, transform, num_processes, num_cached=2, seeds=None, pin_memory=False,
148
- wait_time=0.02, results_loop_fn: Callable = results_loop):
169
+ wait_time=0.02):
149
170
  self.pin_memory = pin_memory
150
171
  self.transform = transform
151
172
  self.num_cached = num_cached
@@ -157,10 +178,10 @@ class NonDetMultiThreadedAugmenter(object):
157
178
 
158
179
  self._queue = None
159
180
  self._processes = []
160
- self.results_loop_fn = results_loop
161
181
  self.results_loop_thread = None
162
182
  self.results_loop_queue = None
163
183
  self.abort_event = None
184
+ self.pause_event = None
164
185
  self.initialized = False
165
186
 
166
187
  self.wait_time = wait_time
@@ -178,23 +199,19 @@ class NonDetMultiThreadedAugmenter(object):
178
199
  return self.__next__()
179
200
 
180
201
  def __get_next_item(self):
181
- item = None
182
-
183
- while item is None:
184
- #
202
+ while True:
185
203
  if self.abort_event.is_set():
186
- # self.communication_thread handles checking for dead workers and will set the abort event if necessary
204
+ # the results loop checks for dead workers and will set the abort event if necessary
187
205
  self._finish()
188
206
  raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
189
207
  "print statements above for the actual error message")
190
208
 
191
- if not self.results_loop_queue.empty():
192
- item = self.results_loop_queue.get()
193
- self.results_loop_queue.task_done()
194
- else:
195
- sleep(self.wait_time)
196
-
197
- return item
209
+ try:
210
+ item = self.results_loop_queue.get(timeout=self.wait_time)
211
+ except Empty:
212
+ continue # abort_event re-checked at top of loop
213
+ self.results_loop_queue.task_done()
214
+ return item
198
215
 
199
216
  def __next__(self):
200
217
  if not self.initialized:
@@ -210,6 +227,7 @@ class NonDetMultiThreadedAugmenter(object):
210
227
  self._queue = Queue(self.num_cached)
211
228
  self.results_loop_queue = thrQueue(self.num_cached)
212
229
  self.abort_event = Event()
230
+ self.pause_event = Event()
213
231
 
214
232
  logging.debug("starting workers")
215
233
  if isinstance(self.generator, DataLoader):
@@ -218,7 +236,8 @@ class NonDetMultiThreadedAugmenter(object):
218
236
 
219
237
  for i in range(self.num_processes):
220
238
  self._processes.append(Process(target=producer, args=(
221
- self._queue, self.generator, self.transform, i, self.seeds[i], self.abort_event, self.wait_time
239
+ self._queue, self.generator, self.transform, i, self.seeds[i],
240
+ self.abort_event, self.pause_event, self.wait_time
222
241
  )))
223
242
  self._processes[-1].daemon = True
224
243
  _ = [i.start() for i in self._processes]
@@ -230,7 +249,7 @@ class NonDetMultiThreadedAugmenter(object):
230
249
 
231
250
  # in_queue: Queue, out_queue: thrQueue, abort_event: Event, pin_memory: bool, worker_list: List[Process],
232
251
  # gpu: Union[int, None] = None, wait_time: float = 0.02
233
- self.results_loop_thread = threading.Thread(target=self.results_loop_fn, args=(
252
+ self.results_loop_thread = threading.Thread(target=results_loop, args=(
234
253
  self._queue, self.results_loop_queue, self.abort_event, self.pin_memory, self._processes, gpu,
235
254
  self.wait_time)
236
255
  )
@@ -241,14 +260,70 @@ class NonDetMultiThreadedAugmenter(object):
241
260
  else:
242
261
  logging.debug("MultiThreadedGenerator Warning: start() has been called but workers are already running")
243
262
 
244
- def _finish(self):
245
- if self.initialized:
263
+ def _finish(self, timeout=10, force=False):
264
+ """Graceful shutdown — same pause-drain-exit handshake as MTA.
265
+
266
+ Single shared in-queue (self._queue), so the drain logic is simpler
267
+ than the per-worker variant.
268
+ """
269
+ if not self.initialized and len(self._processes) == 0:
270
+ return
271
+
272
+ # 1. Pause producers so no new bytes hit the pipe.
273
+ if not force and self.pause_event is not None:
274
+ self.pause_event.set()
275
+
276
+ # 2. Tell everyone to exit.
277
+ if self.abort_event is not None:
246
278
  self.abort_event.set()
247
- sleep(self.wait_time)
248
- [i.terminate() for i in self._processes if i.is_alive()]
249
279
 
250
- del self._queue, self.results_loop_queue, self.results_loop_thread, self.abort_event, self._processes
251
- self._queue, self.results_loop_queue, self.results_loop_thread, self.abort_event = None, None, None, None
280
+ # 3. Join results_loop_thread before draining its output queue.
281
+ if self.results_loop_thread is not None:
282
+ self.results_loop_thread.join(timeout=timeout)
283
+
284
+ # 4. Drain results_loop_queue (no writer now).
285
+ if self.results_loop_queue is not None:
286
+ while not self.results_loop_queue.empty():
287
+ try:
288
+ self.results_loop_queue.get_nowait()
289
+ except Exception:
290
+ break
291
+
292
+ # 5. Drain the shared mp.Queue while workers exit, so their feeder
293
+ # threads can flush and the worker processes can terminate.
294
+ if len(self._processes) > 0 and self._queue is not None:
295
+ deadline = time() + timeout
296
+ drain_tick = max(self.wait_time, 0.01)
297
+ while time() < deadline and any(p.is_alive() for p in self._processes):
298
+ while not self._queue.empty():
299
+ try:
300
+ self._queue.get_nowait()
301
+ except Exception:
302
+ break
303
+ sleep(drain_tick)
304
+
305
+ # 6. Terminate stragglers.
306
+ for p in self._processes:
307
+ if p.is_alive():
308
+ p.terminate()
309
+ p.join(timeout=1.0)
310
+
311
+ # 7. Final drain.
312
+ while not self._queue.empty():
313
+ try:
314
+ self._queue.get_nowait()
315
+ except Exception:
316
+ break
317
+
318
+ # 8. Close the queue.
319
+ self._queue.close()
320
+ self._queue.join_thread()
321
+
322
+ self._queue = None
323
+ self.results_loop_queue = None
324
+ self.results_loop_thread = None
325
+ self.abort_event = None
326
+ self.pause_event = None
252
327
  self._processes = []
253
328
  self.initialized = False
254
329
 
@@ -258,7 +333,7 @@ class NonDetMultiThreadedAugmenter(object):
258
333
 
259
334
  def __del__(self):
260
335
  logging.debug("MultiThreadedGenerator: destructor was called")
261
- self._finish()
336
+ self._finish(timeout=2, force=True)
262
337
 
263
338
 
264
339
  if __name__ == '__main__':
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: batchgenerators
3
- Version: 0.25.2
3
+ Version: 0.25.3
4
4
  Summary: Data augmentation toolkit
5
5
  Author-email: "Division of Medical Image Computing, German Cancer Research Center AND Applied Computer Vision Lab, Helmholtz Imaging Platform" <f.isensee@dkfz-heidelberg.de>
6
6
  License: Apache License
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "batchgenerators"
7
- version = "0.25.2"
7
+ version = "0.25.3"
8
8
  description = "Data augmentation toolkit"
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -201,7 +201,7 @@ class TestDataLoader(unittest.TestCase):
201
201
 
202
202
  def test_thoroughly(self):
203
203
  data_list = [list(range(123)),
204
- list(range(1243)),
204
+ list(range(45)),
205
205
  list(range(1)),
206
206
  list(range(7)),
207
207
  ]
@@ -14,7 +14,6 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import unittest
17
- import unittest2
18
17
  import numpy as np
19
18
  from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
20
19
  from skimage import data
@@ -23,7 +22,7 @@ from tests.DataGenerators import BasicDataLoader
23
22
  from batchgenerators.transforms.spatial_transforms import MirrorTransform
24
23
 
25
24
 
26
- class TestMirrorAxis(unittest2.TestCase):
25
+ class TestMirrorAxis(unittest.TestCase):
27
26
  def setUp(self):
28
27
  self.seed = 1234
29
28
 
@@ -25,8 +25,8 @@ class TestAugmentResample(unittest.TestCase):
25
25
  self.data_3D = np.random.random((2, 64, 56, 48))
26
26
  self.data_2D = np.random.random((2, 64, 56))
27
27
 
28
- self.data_3D_unique = np.reshape(range(2 * 64 * 56 * 48), newshape=(2, 64, 56, 48))
29
- self.data_2D_unique = np.reshape(range(2 * 64 * 56), newshape=(2, 64, 56))
28
+ self.data_3D_unique = np.reshape(range(2 * 64 * 56 * 48), shape=(2, 64, 56, 48))
29
+ self.data_2D_unique = np.reshape(range(2 * 64 * 56), shape=(2, 64, 56))
30
30
 
31
31
  self.d_3D = augment_linear_downsampling_scipy(np.copy(self.data_3D), zoom_range=[0.5, 1.5], per_channel=False)
32
32
  self.d_2D = augment_linear_downsampling_scipy(np.copy(self.data_2D), zoom_range=[0.5, 1.5], per_channel=False)