nnInteractive 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. nnInteractive/__init__.py +3 -0
  2. nnInteractive/inference/__init__.py +0 -0
  3. nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  4. nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
  5. nnInteractive/inference/inference_session.py +1400 -0
  6. nnInteractive/interaction/__init__.py +0 -0
  7. nnInteractive/interaction/point.py +166 -0
  8. nnInteractive/supervoxel/setup.py +4 -0
  9. nnInteractive/supervoxel/src/metadata.py +118 -0
  10. nnInteractive/supervoxel/src/reader.py +175 -0
  11. nnInteractive/supervoxel/src/run.py +136 -0
  12. nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
  13. nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
  14. nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
  15. nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
  16. nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
  17. nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
  18. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
  19. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
  20. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
  21. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
  22. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
  23. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
  24. nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
  25. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
  26. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
  27. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
  28. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
  29. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
  30. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
  31. nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
  32. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
  33. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
  34. nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
  35. nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
  36. nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
  37. nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
  38. nnInteractive/supervoxel/src/sam2/setup.py +174 -0
  39. nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
  40. nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
  41. nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
  42. nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
  43. nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
  44. nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
  45. nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
  46. nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
  47. nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
  48. nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
  49. nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
  50. nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
  51. nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
  52. nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
  53. nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
  54. nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
  55. nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
  56. nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
  57. nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
  58. nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
  59. nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
  60. nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
  61. nnInteractive/supervoxel/src/supervoxel.py +198 -0
  62. nnInteractive/trainer/__init__.py +0 -0
  63. nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  64. nnInteractive/utils/__init__.py +0 -0
  65. nnInteractive/utils/bboxes.py +217 -0
  66. nnInteractive/utils/checkpoint_cleansing.py +9 -0
  67. nnInteractive/utils/crop.py +268 -0
  68. nnInteractive/utils/erosion_dilation.py +48 -0
  69. nnInteractive/utils/inference_helpers.py +45 -0
  70. nnInteractive/utils/os_shennanigans.py +16 -0
  71. nnInteractive/utils/rounding.py +13 -0
  72. nninteractive-2.0.0.dist-info/METADATA +511 -0
  73. nninteractive-2.0.0.dist-info/RECORD +76 -0
  74. nninteractive-2.0.0.dist-info/WHEEL +5 -0
  75. nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
  76. nninteractive-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,560 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import datetime
8
+ import functools
9
+ import io
10
+ import logging
11
+ import os
12
+ import random
13
+ import tempfile
14
+ import time
15
+ from typing import Any, Callable, List, Tuple
16
+
17
+ import torch
18
+ import torch.autograd as autograd
19
+ import torch.distributed as dist
20
+
21
# Rank treated as the primary ("main") process of a job.
_PRIMARY_RANK = 0

# Value stored in _cuda_device_index to mean "this process runs on CPU".
_CPU_DEVICE_INDEX = -1

# CUDA device used by this process: GPU 0 unless changed through
# set_cuda_device_index() / set_cpu_device().
_cuda_device_index: int = 0
27
+
28
+
29
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group containing all ranks that can communicate CPU data.

    When the default backend is NCCL, a new gloo-backed group is created.
    Otherwise the default WORLD group is returned as-is. The result is cached,
    so the group is built at most once per process.
    """
    if dist.get_backend() != "nccl":
        return dist.group.WORLD

    # Raise the timeout from 1800 s to 43200 s (12 h) so that ranks which run
    # much slower than others do not time out. The original note cites LVIS
    # class mAP evaluation as one such case.
    twelve_hours = datetime.timedelta(seconds=43200)
    return dist.new_group(backend="gloo", timeout=twelve_hours)
47
+
48
+
49
def is_main_process():
    """Return True when this process holds rank 0, False otherwise."""
    rank = get_rank()
    return rank == 0
52
+
53
+
54
def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors), like
    `all_gather`, but exchange the data through files on disk instead of
    collective communication ops.

    Args:
        data: any picklable object contributed by this rank.
        filesys_save_dir: base directory; an ``all_gather_via_filesys``
            sub-directory inside it is used. If None, falls back to the
            ``EXP_DIR`` environment variable, then to the directory containing
            this source file. It must be readable/writable by every rank.
        gather_to_rank_0_only: if True, only rank 0 loads the gathered object
            list and all other ranks receive an empty list.

    Returns:
        list: the objects from ranks ``0 .. world_size - 1`` in rank order
        (``[]`` on non-zero ranks when ``gather_to_rank_0_only`` is True).
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    print("gathering via files")
    cpu_group = _get_global_gloo_group()

    # Base directory: explicit argument, then $EXP_DIR, then this file's directory.
    if filesys_save_dir is not None:
        save_dir = filesys_save_dir
    elif "EXP_DIR" in os.environ:
        save_dir = os.environ["EXP_DIR"]
    else:
        save_dir = os.path.dirname(__file__)
    save_dir = os.path.join(save_dir, "all_gather_via_filesys")
    if is_main_process():
        os.makedirs(save_dir, exist_ok=True)

    # Rank 0 draws a timestamp and a random salt so that files from different
    # all_gather calls do not collide. Every other rank contributes zeros, so the
    # (SUM) all-reduce effectively broadcasts rank 0's values.
    timestamp = int(time.time()) if is_main_process() else 0
    salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
    timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
    dist.all_reduce(timestamp_and_salt, group=cpu_group)
    timestamp, salt = timestamp_and_salt.tolist()

    def _path_for(rank):
        # File holding the object contributed by `rank` for this call.
        return os.path.join(save_dir, f"data_to_gather_{timestamp}_{salt}_{rank}.pkl")

    # Write this rank's data, then wait until every rank has written its file.
    rank_save = get_rank()
    save_data_path = _path_for(rank_save)
    assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
    torch.save(data, save_data_path)
    dist.barrier(group=cpu_group)

    # Read the data from the files (on every rank, or on rank 0 only).
    data_list = []
    if rank_save == 0 or not gather_to_rank_0_only:
        for rank_load in range(world_size):
            load_data_path = _path_for(rank_load)
            assert os.path.exists(load_data_path), f"cannot read {load_data_path}"
            # weights_only=False: `data` may be any picklable object, which the
            # weights-only unpickler (torch.load's default since PyTorch 2.6)
            # rejects. The files were written by ranks of this same job.
            data_list.append(torch.load(load_data_path, weights_only=False))
    # No rank may delete its file until every reader has finished.
    dist.barrier(group=cpu_group)

    os.remove(save_data_path)
    return data_list
111
+
112
+
113
def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        force_cpu: gather over the gloo CPU group instead of on CUDA tensors.
            The environment variable MDETR_CPU_REDUCE=1 has the same effect.
        force_filesys: exchange the data through files via
            `all_gather_via_filesys`. MDETR_FILESYS_REDUCE=1 has the same
            effect. MDETR_FILESYS_REDUCE_RANK_0_ONLY=1 also uses files, but
            only rank 0 receives the gathered list.
        filesys_save_dir: forwarded to `all_gather_via_filesys`.

    Returns:
        list[data]: list of data gathered from each rank, in rank order
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1":
        return all_gather_via_filesys(data, filesys_save_dir, gather_to_rank_0_only=True)

    if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys:
        return all_gather_via_filesys(data, filesys_save_dir)

    cpu_group = None
    if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu:
        cpu_group = _get_global_gloo_group()

    # Serialize the object into a flat uint8 tensor on the communication device.
    buffer = io.BytesIO()
    torch.save(data, buffer)
    data_view = buffer.getbuffer()
    device = "cuda" if cpu_group is None else "cpu"
    tensor = torch.ByteTensor(data_view).to(device)

    # Obtain the serialized byte size of each rank.
    local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
    size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)]
    if cpu_group is None:
        dist.all_gather(size_list, local_size)
    else:
        print("gathering on cpu")
        dist.all_gather(size_list, local_size, group=cpu_group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    assert isinstance(local_size.item(), int)
    local_size = int(local_size.item())

    # torch all_gather does not support tensors of different shapes, so every
    # payload is padded to max_size bytes. The padding is sliced off below.
    tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device=device) for _ in size_list]
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device)
        tensor = torch.cat((tensor, padding), dim=0)
    if cpu_group is None:
        dist.all_gather(tensor_list, tensor)
    else:
        dist.all_gather(tensor_list, tensor, group=cpu_group)

    data_list = []
    for size, padded in zip(size_list, tensor_list):
        payload = padded[:size]
        buffer = io.BytesIO(payload.cpu().numpy())
        # weights_only=False: `data` may be any picklable object, which the
        # weights-only unpickler (torch.load's default since PyTorch 2.6)
        # rejects. The bytes come from ranks of this same job.
        data_list.append(torch.load(buffer, weights_only=False))

    return data_list
177
+
178
+
179
def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This helper moves a CPU tensor to CUDA when the
    default process group uses NCCL.

    Args:
        tensor: tensor about to be passed to a collective op.

    Returns:
        (tensor, orig_device): the possibly-moved tensor, and ``"cpu"`` or
        ``"gpu"`` recording where the input lived. Pass the latter to
        `convert_to_normal_tensor`.
    """
    orig_device = "cpu" if not tensor.is_cuda else "gpu"
    if (
        torch.distributed.is_available()
        # get_backend() raises when no process group exists. In that case
        # there is no backend constraint, so leave the tensor where it is.
        and torch.distributed.is_initialized()
        and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
        and not tensor.is_cuda
    ):
        tensor = tensor.cuda()
    return (tensor, orig_device)
193
+
194
+
195
def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
    """
    Undo `convert_to_distributed_tensor` for backends (such as NCCL) that
    require GPU tensors. If the tensor originally lived on the CPU
    (``orig_device == "cpu"``) and is now on CUDA, it is moved back. Otherwise
    it is returned unchanged.
    """
    needs_move_back = orig_device == "cpu" and tensor.is_cuda
    return tensor.cpu() if needs_move_back else tensor
203
+
204
+
205
def is_distributed_training_run() -> bool:
    """
    Return True only if torch.distributed is available, a process group is
    initialized, and that group spans more than one process.
    """
    if not torch.distributed.is_available():
        return False
    if not torch.distributed.is_initialized():
        return False
    return torch.distributed.get_world_size() > 1
211
+
212
+
213
def is_primary() -> bool:
    """
    Tell whether this process is the primary one: rank 0 of a distributed
    training job, or the sole process of a single-trainer job.
    """
    own_rank = get_rank()
    return own_rank == _PRIMARY_RANK
219
+
220
+
221
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """
    Average ``tensor`` over all processes: SUM all-reduce, then divide by the
    world size (via torch.distributed.all_reduce).
    """

    def _divide_by_world_size(summed: torch.Tensor) -> torch.Tensor:
        # Turns the element-wise sum into a mean.
        return summed / torch.distributed.get_world_size()

    return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM, _divide_by_world_size)
231
+
232
+
233
def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
    """
    Element-wise sum of ``tensor`` over all processes (via
    torch.distributed.all_reduce). Works in both distributed and
    non-distributed runs.
    """
    sum_op = torch.distributed.ReduceOp.SUM
    return all_reduce_op(tensor, sum_op)
240
+
241
+
242
def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor:
    """
    Element-wise minimum of ``tensor`` over all processes (via
    torch.distributed.all_reduce). Works in both distributed and
    non-distributed runs.
    """
    min_op = torch.distributed.ReduceOp.MIN
    return all_reduce_op(tensor, min_op)
249
+
250
+
251
def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing max
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX)
258
+
259
+
260
def all_reduce_op(
    tensor: torch.Tensor,
    op: torch.distributed.ReduceOp,
    after_op_func: Callable[[torch.Tensor], torch.Tensor] = None,
) -> torch.Tensor:
    """
    Reduce ``tensor`` over all processes with ``op``.

    In a multi-process run, the tensor is moved to a backend-compatible device,
    reduced with torch.distributed.all_reduce, optionally transformed by
    ``after_op_func``, and moved back to its original device. Outside such a
    run, the input is returned untouched and ``after_op_func`` is not applied.
    """
    if not is_distributed_training_run():
        return tensor
    work_tensor, orig_device = convert_to_distributed_tensor(tensor)
    torch.distributed.all_reduce(work_tensor, op)
    if after_op_func is not None:
        work_tensor = after_op_func(work_tensor)
    return convert_to_normal_tensor(work_tensor, orig_device)
277
+
278
+
279
def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    Collect ``tensor`` from every process (via torch.distributed.all_gather)
    into a list ordered by rank.

    A 0-dim tensor cannot be gathered, so it is unsqueezed to shape ``(1,)``
    first. Outside a multi-process run the result is a one-element list.
    """
    if tensor.ndim == 0:
        tensor = tensor.unsqueeze(0)

    if not is_distributed_training_run():
        return [tensor]

    tensor, orig_device = convert_to_distributed_tensor(tensor)
    world_size = torch.distributed.get_world_size()
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor)
    return [convert_to_normal_tensor(part, orig_device) for part in gathered]
298
+
299
+
300
def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """Gather ``tensor`` from every process and concatenate the pieces along dim 0."""
    return torch.cat(gather_tensors_from_all(tensor), 0)
304
+
305
+
306
def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """
    Send ``tensor`` from rank ``src`` to every process (via
    torch.distributed.broadcast) and return the received value. Outside a
    multi-process run the input is returned unchanged.
    """
    if not is_distributed_training_run():
        return tensor
    moved, orig_device = convert_to_distributed_tensor(tensor)
    torch.distributed.broadcast(moved, src)
    return convert_to_normal_tensor(moved, orig_device)
316
+
317
+
318
def barrier() -> None:
    """
    Wait at torch.distributed.barrier for all processes. If torch.distributed
    is unavailable or no process group is initialized, return immediately
    instead of raising an error.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
326
+
327
+
328
def get_world_size() -> int:
    """
    Return the world size of the default process group, or 1 when
    torch.distributed is unavailable or not initialized.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 1
    return torch.distributed.get_world_size()
338
+
339
+
340
def get_rank() -> int:
    """
    Return this process's rank in the default process group, or 0 when
    torch.distributed is unavailable or not initialized.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_rank()
348
+
349
+
350
def get_primary_rank() -> int:
    """Return the rank that is treated as the primary process."""
    primary = _PRIMARY_RANK
    return primary
352
+
353
+
354
def set_cuda_device_index(idx: int) -> None:
    """Record ``idx`` as this process's CUDA device and make it torch's current device."""
    global _cuda_device_index
    _cuda_device_index = idx
    torch.cuda.set_device(idx)
358
+
359
+
360
def set_cpu_device() -> None:
    """Switch this process to CPU by storing the CPU sentinel as its device index."""
    global _cuda_device_index
    cpu_index = _CPU_DEVICE_INDEX
    _cuda_device_index = cpu_index
363
+
364
+
365
def get_cuda_device_index() -> int:
    """Return the recorded CUDA device index (``_CPU_DEVICE_INDEX`` when running on CPU)."""
    current = _cuda_device_index
    return current
367
+
368
+
369
def init_distributed_data_parallel_model(
    model: torch.nn.Module,
    broadcast_buffers: bool = False,
    find_unused_parameters: bool = True,
    bucket_cap_mb: int = 25,
) -> torch.nn.parallel.DistributedDataParallel:
    """
    Wrap ``model`` in torch.nn.parallel.DistributedDataParallel.

    If the process was switched to CPU (device index equals
    ``_CPU_DEVICE_INDEX``), no device is specified. Otherwise the recorded CUDA
    device index is passed as both the single ``device_ids`` entry and
    ``output_device``.
    """
    ddp_kwargs = dict(
        broadcast_buffers=broadcast_buffers,
        find_unused_parameters=find_unused_parameters,
        bucket_cap_mb=bucket_cap_mb,
    )
    if _cuda_device_index != _CPU_DEVICE_INDEX:
        # GPU model: pin it to this process's device.
        ddp_kwargs["device_ids"] = [_cuda_device_index]
        ddp_kwargs["output_device"] = _cuda_device_index
    return torch.nn.parallel.DistributedDataParallel(model, **ddp_kwargs)
395
+
396
+
397
def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
    """Broadcast an object from a source to all workers.

    Args:
        obj: Object to broadcast, must be serializable. Only the value given on
            rank ``src`` is used; other ranks' ``obj`` is ignored.
        src: Source rank for broadcast (default is primary)
        use_disk: If enabled, removes redundant CPU memory copies by writing to
            disk

    Returns:
        The deserialized object. Every rank, including ``src``, receives a
        copy produced by loading the broadcast bytes.
    """
    if get_rank() == src:
        # Source: serialize, then send the byte length followed by the bytes.
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data_view = buffer.getbuffer()
        length_tensor = torch.LongTensor([len(data_view)])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.ByteTensor(data_view)
        data_tensor = broadcast(data_tensor, src=src)
    else:
        # Receivers: learn the length first so the byte buffer can be allocated.
        length_tensor = torch.LongTensor([0])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
        data_tensor = broadcast(data_tensor, src=src)

    # weights_only=False: `obj` may be any serializable object, which the
    # weights-only unpickler (torch.load's default since PyTorch 2.6) rejects.
    # The bytes come from rank `src` of this same job.
    if use_disk:
        with tempfile.TemporaryFile("r+b") as f:
            f.write(data_tensor.numpy())
            # Drop our reference so the in-memory copy can be garbage
            # collected before loading from disk.
            del data_tensor
            f.seek(0)
            obj = torch.load(f, weights_only=False)
    else:
        buffer = io.BytesIO(data_tensor.numpy())
        obj = torch.load(buffer, weights_only=False)
    return obj
435
+
436
+
437
def all_gather_tensor(tensor: torch.Tensor, world_size=None):
    """
    All-gather ``tensor`` from every process. Returns the list of per-rank
    tensors, in rank order, on the input's original device.

    Args:
        tensor: contiguous tensor. torch all_gather does not support tensors
            of different shapes, so it must have the same shape on every rank.
        world_size: number of processes; defaults to `get_world_size()`.
    """
    if world_size is None:
        world_size = get_world_size()
    # NCCL won't gather a non-contiguous tensor. This function only checks
    # contiguity; it does not make a contiguous copy.
    assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!"
    tensor, orig_device = convert_to_distributed_tensor(tensor)
    tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
    # Blocking (synchronous) all_gather.
    dist.all_gather(tensor_all, tensor, async_op=False)
    return [convert_to_normal_tensor(gathered, orig_device) for gathered in tensor_all]
447
+
448
+
449
def all_gather_batch(tensors: List[torch.Tensor]):
    """
    All-gather every tensor in ``tensors`` and concatenate each one's per-rank
    results along dim 0. With a single process, the input list itself is
    returned.
    """
    world_size = get_world_size()
    if world_size == 1:
        # Single process: nothing to gather.
        return tensors
    gathered_per_input = [all_gather_tensor(t, world_size) for t in tensors]
    return [torch.cat(parts, dim=0) for parts in gathered_per_input]
467
+
468
+
469
class GatherLayer(autograd.Function):
    """
    Differentiable all_gather. Forward gathers ``x`` from all workers. Unlike
    torch.distributed.all_gather, gradients flow back to each worker's input.
    """

    @staticmethod
    def forward(ctx, x):
        gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        # Combine each output slot's gradient across all workers (default
        # all_reduce op), then keep the slot belonging to this worker's input.
        stacked = torch.stack(grads)
        dist.all_reduce(stacked)
        own_rank = dist.get_rank()
        return stacked[own_rank]
486
+
487
+
488
def all_gather_batch_with_grad(tensors):
    """
    All-gather every tensor in ``tensors`` through `GatherLayer`, so the
    results stay connected to the autograd graph. Each one's per-rank pieces
    are concatenated along dim 0. With a single process, the input list itself
    is returned.
    """
    world_size = get_world_size()
    if world_size == 1:
        # Single process: nothing to gather.
        return tensors
    gathered_per_input = [GatherLayer.apply(t) for t in tensors]
    return [torch.cat(parts, dim=0) for parts in gathered_per_input]
508
+
509
+
510
def unwrap_ddp_if_wrapped(model):
    """Return the module wrapped by DistributedDataParallel, or ``model`` itself if it is not wrapped."""
    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    return model.module if is_ddp else model
514
+
515
+
516
def create_new_process_group(group_size):
    """
    Creates process groups of a given `group_size` and returns the
    process group that the current GPU participates in.

    Ranks are split into consecutive blocks: group k holds ranks
    ``k * group_size`` through ``(k + 1) * group_size - 1``.

    `group_size` must divide the total number of GPUs (world_size).

    Modified from
    https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60

    Args:
        group_size (int): number of GPU's to collaborate for sync bn

    Returns:
        The process group that contains the current rank.
    """

    assert group_size > 0

    world_size = torch.distributed.get_world_size()
    if world_size <= 8:
        # Small jobs are assumed to be local debug runs: an oversized request
        # is capped to the world size instead of failing.
        if group_size > world_size:
            logging.warning(
                f"Requested group size [{group_size}] > world size [{world_size}]. "
                "Assuming local debug run and capping it to world size."
            )
            group_size = world_size
    assert world_size >= group_size
    assert world_size % group_size == 0

    my_group_num = torch.distributed.get_rank() // group_size
    group = None
    for group_num in range(world_size // group_size):
        group_ranks = list(range(group_num * group_size, (group_num + 1) * group_size))
        # Every process must go through creation of all subgroups, so we
        # cannot return early once our own group has been created.
        cur_group = torch.distributed.new_group(ranks=group_ranks)
        if group_num == my_group_num:
            group = cur_group

    assert group is not None
    return group
553
+
554
+
555
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and a process group is initialized."""
    return dist.is_available() and dist.is_initialized()