megatron-core 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,456 @@
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+
+ from functools import reduce
+ import operator
+ from typing import Optional, List, Union, Callable, Tuple
+
+ import torch
+
+ from megatron import core
+ from megatron.core.parallel_state import (
+     get_pipeline_model_parallel_group,
+     get_pipeline_model_parallel_prev_rank,
+     get_pipeline_model_parallel_next_rank,
+ )
+
+ # Types
+ Shape = Union[List[int], torch.Size]
+
+ def _communicate_shapes(tensor_send_next, tensor_send_prev,
+                         recv_prev, recv_next,
+                         use_ring_exchange_p2p):
+     """Communicate tensor shapes between stages. Used to communicate
+     tensor shapes before the actual tensor communication happens.
+     This is required when the sequence lengths across micro batches
+     are not uniform.
+
+     Takes the following arguments:
+         tensor_send_next: tensor to send to next rank (no tensor sent if
+                           set to None).
+         tensor_send_prev: tensor to send to prev rank (no tensor sent if
+                           set to None).
+         recv_prev: boolean for whether tensor should be received from
+                    previous rank.
+         recv_next: boolean for whether tensor should be received from
+                    next rank.
+     Returns:
+         (recv_prev_shape, recv_next_shape)
+     """
+
+     recv_prev_shape_tensor = None
+     recv_next_shape_tensor = None
+     send_prev_shape_tensor = None
+     send_next_shape_tensor = None
+     if recv_prev:
+         recv_prev_shape_tensor = torch.empty((3,),
+                                              device=torch.cuda.current_device(),
+                                              dtype=torch.int64)
+     if recv_next:
+         recv_next_shape_tensor = torch.empty((3,),
+                                              device=torch.cuda.current_device(),
+                                              dtype=torch.int64)
+     if tensor_send_prev is not None:
+         send_prev_shape_tensor = torch.tensor(tensor_send_prev.size(),
+                                               device=torch.cuda.current_device(),
+                                               dtype=torch.int64)
+     if tensor_send_next is not None:
+         send_next_shape_tensor = torch.tensor(tensor_send_next.size(),
+                                               device=torch.cuda.current_device(),
+                                               dtype=torch.int64)
+
+     if use_ring_exchange_p2p:
+         torch.distributed.ring_exchange(tensor_send_prev=send_prev_shape_tensor,
+                                         tensor_recv_prev=recv_prev_shape_tensor,
+                                         tensor_send_next=send_next_shape_tensor,
+                                         tensor_recv_next=recv_next_shape_tensor,
+                                         group=get_pipeline_model_parallel_group())
+     else:
+         ops = []
+         if send_prev_shape_tensor is not None:
+             send_prev_op = torch.distributed.P2POp(
+                 torch.distributed.isend, send_prev_shape_tensor,
+                 get_pipeline_model_parallel_prev_rank())
+             ops.append(send_prev_op)
+         if recv_prev_shape_tensor is not None:
+             recv_prev_op = torch.distributed.P2POp(
+                 torch.distributed.irecv, recv_prev_shape_tensor,
+                 get_pipeline_model_parallel_prev_rank())
+             ops.append(recv_prev_op)
+         if send_next_shape_tensor is not None:
+             send_next_op = torch.distributed.P2POp(
+                 torch.distributed.isend, send_next_shape_tensor,
+                 get_pipeline_model_parallel_next_rank())
+             ops.append(send_next_op)
+         if recv_next_shape_tensor is not None:
+             recv_next_op = torch.distributed.P2POp(
+                 torch.distributed.irecv, recv_next_shape_tensor,
+                 get_pipeline_model_parallel_next_rank())
+             ops.append(recv_next_op)
+         if len(ops) > 0:
+             reqs = torch.distributed.batch_isend_irecv(ops)
+             for req in reqs:
+                 req.wait()
+
+         # To protect against race condition when using batch_isend_irecv().
+         # Should take this out once the bug with batch_isend_irecv is resolved.
+         torch.cuda.synchronize()
+
+     recv_prev_shape = [0, 0, 0]
+     if recv_prev_shape_tensor is not None:
+         recv_prev_shape = recv_prev_shape_tensor.tolist()
+
+     recv_next_shape = [0, 0, 0]
+     if recv_next_shape_tensor is not None:
+         recv_next_shape = recv_next_shape_tensor.tolist()
+
+     return recv_prev_shape, recv_next_shape
+
+
+ def _communicate(*, tensor_send_next: Optional[torch.Tensor],
+                     tensor_send_prev: Optional[torch.Tensor],
+                     recv_prev: bool,
+                     recv_next: bool,
+                     tensor_shape: Shape,
+                     dtype: Optional[torch.dtype],
+                     variable_seq_lengths: bool = False,
+                     use_ring_exchange_p2p: bool = False,
+                     ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Communicate tensors between stages. Used as helper method in other
+     communication methods that are used in megatron/schedules.py.
+
+     Arguments:
+         tensor_send_next (torch.Tensor, optional):
+             Tensor to send to next rank (no tensor sent if None).
+
+         tensor_send_prev (torch.Tensor, optional):
+             Tensor to send to prev rank (no tensor sent if None).
+
+         recv_prev (boolean, required):
+             whether tensor should be received from previous rank.
+
+         recv_next (boolean, required):
+             whether tensor should be received from next rank.
+
+         tensor_shape (List[int] or torch.Size, required):
+             shape of tensor to receive (this method assumes that all
+             tensors sent and received in a single function call are
+             the same shape).
+
+         dtype (torch.dtype, required if either recv_{prev,next} is True):
+             this must be the type of the tensors that will be
+             received; it will typically be params_dtype, but in the case
+             of fp32 residual connections it might be torch.float.
+
+         variable_seq_lengths (bool, optional, default=False):
+             Support for variable sequence lengths across
+             microbatches. Setting this communicates the sizes of
+             tensors during pipeline-parallel communication; because
+             of this extra overhead, it should only be set if the
+             sequence length is not constant during training.
+
+         use_ring_exchange_p2p (bool, optional, default=False):
+             Use a custom ring_exchange kernel instead of
+             torch.distributed.batch_isend_irecv(). Requires a
+             custom-built torch with torch.distributed.ring_exchange.
+
+
+     Returns:
+         tuple containing
+
+         - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
+         - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
+
+     """
+
+     # Create placeholder tensors for receive in forward and backward directions
+     # if needed.
+     tensor_recv_prev = None
+     tensor_recv_next = None
+
+     if not variable_seq_lengths:
+         recv_prev_shape = tensor_shape
+         recv_next_shape = tensor_shape
+     else:
+         recv_prev_shape, recv_next_shape = \
+             _communicate_shapes(tensor_send_next,
+                                 tensor_send_prev,
+                                 recv_prev,
+                                 recv_next, use_ring_exchange_p2p)
+
+     if recv_prev:
+         if dtype is None:
+             raise RuntimeError("dtype must be provided if recv_prev is True")
+         if tensor_shape is None:
+             raise RuntimeError(
+                 "tensor_shape must be specified if recv_prev is True. "
+                 "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
+             )
+         tensor_recv_prev = torch.empty(recv_prev_shape,
+                                        requires_grad=True,
+                                        device=torch.cuda.current_device(),
+                                        dtype=dtype)
+     if recv_next:
+         if dtype is None:
+             raise RuntimeError("dtype must be provided if recv_next is True")
+         if tensor_shape is None:
+             raise RuntimeError(
+                 "tensor_shape must be specified if recv_next is True. "
+                 "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
+             )
+         tensor_recv_next = torch.empty(recv_next_shape,
+                                        requires_grad=True,
+                                        device=torch.cuda.current_device(),
+                                        dtype=dtype)
+
+     # Send tensors in both the forward and backward directions as appropriate.
+     if use_ring_exchange_p2p:
+         torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
+                                         tensor_recv_prev=tensor_recv_prev,
+                                         tensor_send_next=tensor_send_next,
+                                         tensor_recv_next=tensor_recv_next,
+                                         group=get_pipeline_model_parallel_group())
+     else:
+         ops = []
+         if tensor_send_prev is not None:
+             send_prev_op = torch.distributed.P2POp(
+                 torch.distributed.isend, tensor_send_prev,
+                 get_pipeline_model_parallel_prev_rank())
+             ops.append(send_prev_op)
+         if tensor_recv_prev is not None:
+             recv_prev_op = torch.distributed.P2POp(
+                 torch.distributed.irecv, tensor_recv_prev,
+                 get_pipeline_model_parallel_prev_rank())
+             ops.append(recv_prev_op)
+         if tensor_send_next is not None:
+             send_next_op = torch.distributed.P2POp(
+                 torch.distributed.isend, tensor_send_next,
+                 get_pipeline_model_parallel_next_rank())
+             ops.append(send_next_op)
+         if tensor_recv_next is not None:
+             recv_next_op = torch.distributed.P2POp(
+                 torch.distributed.irecv, tensor_recv_next,
+                 get_pipeline_model_parallel_next_rank())
+             ops.append(recv_next_op)
+         if len(ops) > 0:
+             reqs = torch.distributed.batch_isend_irecv(ops)
+             for req in reqs:
+                 req.wait()
+         # To protect against race condition when using batch_isend_irecv().
+         # User should assert that we have a modern enough PyTorch to not need this.
+         torch.cuda.synchronize()
+
+     return tensor_recv_prev, tensor_recv_next
+
+
+ def recv_forward(tensor_shape: Shape,
+                  dtype: torch.dtype,
+                  timers: Callable = None) -> torch.Tensor:
+     """Receive tensor from previous rank in pipeline (forward receive).
+
+
+     See _communicate for argument details.
+     """
+
+     if core.parallel_state.is_pipeline_first_stage():
+         input_tensor = None
+     else:
+         if timers is not None:
+             timers('forward-recv', log_level=2).start()
+         input_tensor, _ = _communicate(
+             tensor_send_next=None,
+             tensor_send_prev=None,
+             recv_prev=True,
+             recv_next=False,
+             tensor_shape=tensor_shape,
+             dtype=dtype)
+         if timers is not None:
+             timers('forward-recv').stop()
+     return input_tensor
+
+
+ def recv_backward(tensor_shape: Shape,
+                   dtype: torch.dtype,
+                   timers: Callable = None) -> torch.Tensor:
+     """Receive tensor from next rank in pipeline (backward receive).
+
+     See _communicate for argument details.
+     """
+     if core.parallel_state.is_pipeline_last_stage():
+         output_tensor_grad = None
+     else:
+         if timers is not None:
+             timers('backward-recv', log_level=2).start()
+         _, output_tensor_grad = _communicate(
+             tensor_send_next=None,
+             tensor_send_prev=None,
+             recv_prev=False,
+             recv_next=True,
+             tensor_shape=tensor_shape,
+             dtype=dtype)
+         if timers is not None:
+             timers('backward-recv').stop()
+     return output_tensor_grad
+
+
+ def send_forward(output_tensor: torch.Tensor,
+                  timers: Callable = None) -> None:
+     """Send tensor to next rank in pipeline (forward send).
+
+     See _communicate for argument details.
+     """
+
+     if not core.parallel_state.is_pipeline_last_stage():
+         if timers is not None:
+             timers('forward-send', log_level=2).start()
+         _communicate(
+             tensor_send_next=output_tensor,
+             tensor_send_prev=None,
+             recv_prev=False,
+             recv_next=False,
+             tensor_shape=None,
+             dtype=None)
+         if timers is not None:
+             timers('forward-send').stop()
+
+
+ def send_backward(input_tensor_grad: torch.Tensor,
+                   timers: Callable = None) -> None:
+     """Send tensor to previous rank in pipeline (backward send).
+
+     See _communicate for argument details.
+     """
+     if not core.parallel_state.is_pipeline_first_stage():
+         if timers is not None:
+             timers('backward-send', log_level=2).start()
+         _communicate(
+             tensor_send_next=None,
+             tensor_send_prev=input_tensor_grad,
+             recv_prev=False,
+             recv_next=False,
+             tensor_shape=None,
+             dtype=None)
+         if timers is not None:
+             timers('backward-send').stop()
+
+
+ def send_forward_recv_backward(output_tensor: torch.Tensor,
+                                tensor_shape: Shape,
+                                dtype: torch.dtype,
+                                timers: Callable = None) -> torch.Tensor:
+     """Batched send and recv with next rank in pipeline.
+
+     See _communicate for argument details.
+     """
+     if core.parallel_state.is_pipeline_last_stage():
+         output_tensor_grad = None
+     else:
+         if timers is not None:
+             timers('forward-send-backward-recv', log_level=2).start()
+         _, output_tensor_grad = _communicate(
+             tensor_send_next=output_tensor,
+             tensor_send_prev=None,
+             recv_prev=False,
+             recv_next=True,
+             tensor_shape=tensor_shape,
+             dtype=dtype)
+         if timers is not None:
+             timers('forward-send-backward-recv').stop()
+     return output_tensor_grad
+
+
+ def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
+                                tensor_shape: Shape,
+                                dtype: torch.dtype,
+                                timers: Callable = None) -> torch.Tensor:
+     """Batched send and recv with previous rank in pipeline.
+
+     See _communicate for argument details.
+     """
+     if core.parallel_state.is_pipeline_first_stage():
+         input_tensor = None
+     else:
+         if timers is not None:
+             timers('backward-send-forward-recv', log_level=2).start()
+         input_tensor, _ = _communicate(
+             tensor_send_next=None,
+             tensor_send_prev=input_tensor_grad,
+             recv_prev=True,
+             recv_next=False,
+             tensor_shape=tensor_shape,
+             dtype=dtype)
+         if timers is not None:
+             timers('backward-send-forward-recv').stop()
+     return input_tensor
+
+
+ def send_forward_recv_forward(output_tensor: torch.Tensor,
+                               recv_prev: bool,
+                               tensor_shape: Shape,
+                               dtype: torch.dtype,
+                               timers: Callable = None) -> torch.Tensor:
+     """Batched recv from previous rank and send to next rank in pipeline.
+
+     See _communicate for argument details.
+     """
+     if timers is not None:
+         timers('forward-send-forward-recv', log_level=2).start()
+     input_tensor, _ = _communicate(
+         tensor_send_next=output_tensor,
+         tensor_send_prev=None,
+         recv_prev=recv_prev,
+         recv_next=False,
+         tensor_shape=tensor_shape,
+         dtype=dtype)
+     if timers is not None:
+         timers('forward-send-forward-recv').stop()
+     return input_tensor
+
+
+ def send_backward_recv_backward(input_tensor_grad: torch.Tensor,
+                                 recv_next: bool,
+                                 tensor_shape: Shape,
+                                 dtype: torch.dtype,
+                                 timers: Callable = None) -> torch.Tensor:
+     """Batched recv from next rank and send to previous rank in pipeline.
+
+     See _communicate for argument details.
+     """
+     if timers is not None:
+         timers('backward-send-backward-recv', log_level=2).start()
+     _, output_tensor_grad = _communicate(
+         tensor_send_next=None,
+         tensor_send_prev=input_tensor_grad,
+         recv_prev=False,
+         recv_next=recv_next,
+         tensor_shape=tensor_shape,
+         dtype=dtype)
+     if timers is not None:
+         timers('backward-send-backward-recv').stop()
+     return output_tensor_grad
+
+
+ def send_forward_backward_recv_forward_backward(
+         output_tensor: torch.Tensor,
+         input_tensor_grad: torch.Tensor,
+         recv_prev: bool,
+         recv_next: bool,
+         tensor_shape: Shape,
+         dtype: torch.dtype,
+         timers: Callable = None) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Batched send and recv with previous and next ranks in pipeline.
+
+     See _communicate for argument details.
+     """
+     if timers is not None:
+         timers('forward-backward-send-forward-backward-recv',
+                log_level=2).start()
+     input_tensor, output_tensor_grad = _communicate(
+         tensor_send_next=output_tensor,
+         tensor_send_prev=input_tensor_grad,
+         recv_prev=recv_prev,
+         recv_next=recv_next,
+         tensor_shape=tensor_shape,
+         dtype=dtype)
+     if timers is not None:
+         timers('forward-backward-send-forward-backward-recv').stop()
+     return input_tensor, output_tensor_grad
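
Taken together, these helpers are the point-to-point plumbing that a pipeline schedule builds on: recv_forward/send_forward move activations downstream, recv_backward/send_backward move gradients upstream, and the fused variants overlap a send with the matching receive. The sketch below is illustrative only and not part of the released file; it assumes torch.distributed and megatron.core.parallel_state are already initialized, that the module is importable as megatron.core.pipeline_parallel.p2p_communication (the module path is not shown in this diff), and that forward_step and backward_step are hypothetical stand-ins for the caller's own compute functions.

# Illustrative usage sketch -- not part of megatron-core 0.1.0.
# Assumes torch.distributed and megatron.core.parallel_state are initialized,
# and that `model`, `forward_step`, and `backward_step` are caller-supplied.
import torch
from megatron.core.pipeline_parallel import p2p_communication as p2p  # assumed path

def run_one_microbatch(model, forward_step, backward_step,
                       seq_length, micro_batch_size, hidden_size,
                       dtype=torch.float16):
    # Every stage sends/receives tensors of the same (s, b, h) shape.
    tensor_shape = [seq_length, micro_batch_size, hidden_size]

    # Forward: receive activations from the previous stage (None on the
    # first stage), run this stage, and pass the output downstream.
    input_tensor = p2p.recv_forward(tensor_shape, dtype)
    output_tensor = forward_step(model, input_tensor)
    p2p.send_forward(output_tensor)

    # Backward: receive the gradient of the loss w.r.t. this stage's output
    # from the next stage (None on the last stage), backprop through this
    # stage, and send the gradient w.r.t. this stage's input upstream.
    output_tensor_grad = p2p.recv_backward(tensor_shape, dtype)
    input_tensor_grad = backward_step(input_tensor, output_tensor,
                                      output_tensor_grad)
    p2p.send_backward(input_tensor_grad)

A real schedule interleaves many such micro-batches and uses the fused helpers (send_forward_recv_backward, send_backward_recv_forward, and the combined forward-backward variant) to keep communication overlapped, but the per-stage call pattern is the same.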