megatron-core 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megatron-core might be problematic. Click here for more details.
- megatron/core/__init__.py +12 -0
- megatron/core/enums.py +7 -0
- megatron/core/package_info.py +23 -0
- megatron/core/parallel_state.py +570 -0
- megatron/core/pipeline_parallel/__init__.py +1 -0
- megatron/core/pipeline_parallel/p2p_communication.py +456 -0
- megatron/core/pipeline_parallel/schedules.py +1050 -0
- megatron/core/tensor_parallel/__init__.py +65 -0
- megatron/core/tensor_parallel/cross_entropy.py +143 -0
- megatron/core/tensor_parallel/data.py +105 -0
- megatron/core/tensor_parallel/layers.py +716 -0
- megatron/core/tensor_parallel/mappings.py +279 -0
- megatron/core/tensor_parallel/random.py +253 -0
- megatron/core/tensor_parallel/utils.py +108 -0
- megatron/core/utils.py +137 -0
- megatron_core-0.1.0.dist-info/LICENSE +376 -0
- megatron_core-0.1.0.dist-info/METADATA +35 -0
- megatron_core-0.1.0.dist-info/RECORD +20 -0
- megatron_core-0.1.0.dist-info/WHEEL +5 -0
- megatron_core-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,456 @@
|
|
|
1
|
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
|
|
3
|
+
from functools import reduce
|
|
4
|
+
import operator
|
|
5
|
+
from typing import Optional, List, Union, Callable, Tuple
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from megatron import core
|
|
10
|
+
from megatron.core.parallel_state import (
|
|
11
|
+
get_pipeline_model_parallel_group,
|
|
12
|
+
get_pipeline_model_parallel_prev_rank,
|
|
13
|
+
get_pipeline_model_parallel_next_rank,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Types
# Any `tensor_shape` argument in this module may be given either as a plain
# list of ints or as a torch.Size.
Shape = Union[List[int], torch.Size]
|
|
18
|
+
|
|
19
|
+
def _communicate_shapes(tensor_send_next, tensor_send_prev,
                        recv_prev, recv_next,
                        use_ring_exchange_p2p=False):
    """Communicate tensor shapes between stages. Used to communicate
    tensor shapes before the actual tensor communication happens.
    This is required when the sequence lengths across micro batches
    are not uniform.

    Each shape travels as an int64 tensor of exactly three elements (the
    receive buffers are allocated with length 3), so tensors being sent
    are expected to be 3-D.

    Takes the following arguments:
        tensor_send_next: tensor whose shape is sent to the next rank (no
                          shape sent if set to None).
        tensor_send_prev: tensor whose shape is sent to the prev rank (no
                          shape sent if set to None).
        recv_prev: boolean for whether a shape should be received from
                   the previous rank.
        recv_next: boolean for whether a shape should be received from
                   the next rank.
        use_ring_exchange_p2p: if True, use the custom
                   torch.distributed.ring_exchange kernel (requires a
                   custom-built torch); otherwise use
                   torch.distributed.batch_isend_irecv(). Defaults to False.
    Returns:
        (recv_prev_shape, recv_next_shape): each is a list of three ints.
        A direction that was not received is reported as [0, 0, 0].
    """

    recv_prev_shape_tensor = None
    recv_next_shape_tensor = None
    send_prev_shape_tensor = None
    send_next_shape_tensor = None
    if recv_prev:
        recv_prev_shape_tensor = torch.empty((3),
                                             device=torch.cuda.current_device(),
                                             dtype=torch.int64)
    if recv_next:
        recv_next_shape_tensor = torch.empty((3),
                                             device=torch.cuda.current_device(),
                                             dtype=torch.int64)
    if tensor_send_prev is not None:
        send_prev_shape_tensor = torch.tensor(tensor_send_prev.size(),
                                              device=torch.cuda.current_device(),
                                              dtype=torch.int64)
    if tensor_send_next is not None:
        send_next_shape_tensor = torch.tensor(tensor_send_next.size(),
                                              device=torch.cuda.current_device(),
                                              dtype=torch.int64)

    # The pipeline group/rank helpers come from the module-level import of
    # megatron.core.parallel_state; there is no `mpu` name in this module.
    if use_ring_exchange_p2p:
        torch.distributed.ring_exchange(tensor_send_prev=send_prev_shape_tensor,
                                        tensor_recv_prev=recv_prev_shape_tensor,
                                        tensor_send_next=send_next_shape_tensor,
                                        tensor_recv_next=recv_next_shape_tensor,
                                        group=get_pipeline_model_parallel_group())
    else:
        ops = []
        if send_prev_shape_tensor is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, send_prev_shape_tensor,
                get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if recv_prev_shape_tensor is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, recv_prev_shape_tensor,
                get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if send_next_shape_tensor is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, send_next_shape_tensor,
                get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if recv_next_shape_tensor is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, recv_next_shape_tensor,
                get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()

        # To protect against race condition when using batch_isend_irecv().
        # should take this out once the bug with batch_isend_irecv is resolved.
        torch.cuda.synchronize()

    recv_prev_shape = [0, 0, 0]
    if recv_prev_shape_tensor is not None:
        recv_prev_shape = recv_prev_shape_tensor.tolist()

    recv_next_shape = [0, 0, 0]
    if recv_next_shape_tensor is not None:
        recv_next_shape = recv_next_shape_tensor.tolist()

    return recv_prev_shape, recv_next_shape
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _allocate_recv_tensor(flag_name, recv_shape, tensor_shape, dtype):
    """Validate receive arguments and allocate one empty receive buffer.

    Args:
        flag_name: "recv_prev" or "recv_next"; used only in error messages.
        recv_shape: shape of the buffer to allocate (either tensor_shape or
            a shape obtained from _communicate_shapes).
        tensor_shape: the caller-supplied tensor_shape; must not be None.
        dtype: element type of the buffer; must not be None.

    Returns:
        An uninitialized tensor on the current CUDA device with
        requires_grad=True.

    Raises:
        RuntimeError: if dtype or tensor_shape is None.
    """
    if dtype is None:
        raise RuntimeError(f"dtype must be provided if {flag_name} is True")
    if tensor_shape is None:
        raise RuntimeError(
            f"tensor_shape must be specified if {flag_name} is True. "
            "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
        )
    # requires_grad=True presumably lets autograd track the received
    # activations/gradients on this stage -- TODO confirm with schedules.
    return torch.empty(recv_shape,
                       requires_grad=True,
                       device=torch.cuda.current_device(),
                       dtype=dtype)


def _communicate(*, tensor_send_next: Optional[torch.Tensor],
                 tensor_send_prev: Optional[torch.Tensor],
                 recv_prev: bool,
                 recv_next: bool,
                 tensor_shape: Shape,
                 dtype: Optional[torch.dtype],
                 variable_seq_lengths: bool = False,
                 use_ring_exchange_p2p: bool = False,
                 ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Arguments:
        tensor_send_next (torch.Tensor, optional):
            Tensor to send to next rank (no tensor sent if None)

        tensor_send_prev (torch.Tensor, optional):
            Tensor to send to prev rank (no tensor sent if None)

        recv_prev (boolean, required):
            whether tensor should be received from previous rank.

        recv_next (boolean, required):
            whether tensor should be received from next rank.

        tensor_shape (List[int] or torch.Size, required):
            shape of tensor to receive (this method assumes that all
            tensors sent and received in a single function call are
            the same shape). Must be non-None whenever recv_prev or
            recv_next is True, even with variable_seq_lengths (in which
            case the received shapes are used for allocation instead).

        dtype (torch.dtype, required if either recv_{prev,next} is True):
            this must be the type of the tensors that will be
            received, will typically be params_dtype, but in the case
            of fp32 residual connections might be torch.float.

        variable_seq_lengths (bool, optional, default=False):
            Support for variable sequence lengths across
            microbatches. Setting this communicates the size of
            tensors during pipeline parallelism communication, because
            of this extra overhead it should only be set if the
            sequence length is not constant during training.

        use_ring_exchange_p2p (bool, optional, default = False):
            Use custom ring_exchange kernel instead of
            torch.distributed.batch_isend_irecv(). Requires custom
            built torch with torch.distributed.ring_exchange.
            Also applies to the shape exchange when
            variable_seq_lengths is set.

    Returns:
        tuple containing

        - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
        - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.

    Raises:
        RuntimeError: if a receive is requested without dtype or tensor_shape.
    """

    if not variable_seq_lengths:
        recv_prev_shape = tensor_shape
        recv_next_shape = tensor_shape
    else:
        # Forward use_ring_exchange_p2p so the shape exchange uses the same
        # transport as the tensor exchange below.
        recv_prev_shape, recv_next_shape = _communicate_shapes(
            tensor_send_next,
            tensor_send_prev,
            recv_prev,
            recv_next,
            use_ring_exchange_p2p)

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if recv_prev:
        tensor_recv_prev = _allocate_recv_tensor(
            "recv_prev", recv_prev_shape, tensor_shape, dtype)
    if recv_next:
        tensor_recv_next = _allocate_recv_tensor(
            "recv_next", recv_next_shape, tensor_shape, dtype)

    # Send tensors in both the forward and backward directions as appropriate.
    if use_ring_exchange_p2p:
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=get_pipeline_model_parallel_group())
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
        # To protect against race condition when using batch_isend_irecv().
        # User should assert that we have a modern enough PyTorch to not need this
        torch.cuda.synchronize()

    return tensor_recv_prev, tensor_recv_next
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def recv_forward(tensor_shape: Shape,
                 dtype: torch.dtype,
                 timers: Optional[Callable] = None) -> Optional[torch.Tensor]:
    """ Receive tensor from previous rank in pipeline (forward receive).

    The first pipeline stage has no previous rank, so it receives nothing
    and returns None.

    timers: optional profiling callable; if given, it is called as
        timers('forward-recv', log_level=2).start() before the receive and
        timers('forward-recv').stop() after it.

    See _communicate for argument details.
    """
    if core.parallel_state.is_pipeline_first_stage():
        return None

    if timers is not None:
        timers('forward-recv', log_level=2).start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype=dtype)
    if timers is not None:
        timers('forward-recv').stop()
    return input_tensor
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def recv_backward(tensor_shape: Shape,
                  dtype: torch.dtype,
                  timers: Optional[Callable] = None) -> Optional[torch.Tensor]:
    """Receive tensor from next rank in pipeline (backward receive).

    The last pipeline stage has no next rank, so it receives nothing and
    returns None.

    timers: optional profiling callable; if given, it is called as
        timers('backward-recv', log_level=2).start() before the receive and
        timers('backward-recv').stop() after it.

    See _communicate for argument details.
    """
    if core.parallel_state.is_pipeline_last_stage():
        return None

    if timers is not None:
        timers('backward-recv', log_level=2).start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        tensor_shape=tensor_shape,
        dtype=dtype)
    if timers is not None:
        timers('backward-recv').stop()
    return output_tensor_grad
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def send_forward(output_tensor: torch.Tensor,
                 timers: Callable = None) -> None:
    """Forward send: push ``output_tensor`` to the next pipeline rank.

    This is a no-op on the last pipeline stage, which has no successor.
    When ``timers`` is supplied, the transfer is bracketed by
    timers('forward-send', log_level=2).start() and
    timers('forward-send').stop().

    See _communicate for argument details.
    """
    if core.parallel_state.is_pipeline_last_stage():
        return

    timed = timers is not None
    if timed:
        timers('forward-send', log_level=2).start()
    # Send-only exchange: nothing is received, so no shape/dtype is needed.
    _communicate(tensor_send_next=output_tensor,
                 tensor_send_prev=None,
                 recv_prev=False,
                 recv_next=False,
                 tensor_shape=None,
                 dtype=None)
    if timed:
        timers('forward-send').stop()
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def send_backward(input_tensor_grad: torch.Tensor,
                  timers: Callable = None) -> None:
    """Backward send: push ``input_tensor_grad`` to the previous pipeline rank.

    This is a no-op on the first pipeline stage, which has no predecessor.
    When ``timers`` is supplied, the transfer is bracketed by
    timers('backward-send', log_level=2).start() and
    timers('backward-send').stop().

    See _communicate for argument details.
    """
    if core.parallel_state.is_pipeline_first_stage():
        return

    timed = timers is not None
    if timed:
        timers('backward-send', log_level=2).start()
    # Send-only exchange: nothing is received, so no shape/dtype is needed.
    _communicate(tensor_send_next=None,
                 tensor_send_prev=input_tensor_grad,
                 recv_prev=False,
                 recv_next=False,
                 tensor_shape=None,
                 dtype=None)
    if timed:
        timers('backward-send').stop()
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def send_forward_recv_backward(output_tensor: torch.Tensor,
                               tensor_shape: Shape,
                               dtype: torch.dtype,
                               timers: Optional[Callable] = None
                               ) -> Optional[torch.Tensor]:
    """Batched send and recv with next rank in pipeline.

    Sends ``output_tensor`` to the next rank and receives a tensor from it
    in one batched exchange. The last pipeline stage has no next rank, so
    it communicates nothing and returns None.

    timers: optional profiling callable; if given, the exchange is timed
        under the 'forward-send-backward-recv' label.

    See _communicate for argument details.
    """
    if core.parallel_state.is_pipeline_last_stage():
        return None

    if timers is not None:
        timers('forward-send-backward-recv', log_level=2).start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        tensor_shape=tensor_shape,
        dtype=dtype)
    if timers is not None:
        timers('forward-send-backward-recv').stop()
    return output_tensor_grad
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
                               tensor_shape: Shape,
                               dtype: torch.dtype,
                               timers: Optional[Callable] = None
                               ) -> Optional[torch.Tensor]:
    """Batched send and recv with previous rank in pipeline.

    Sends ``input_tensor_grad`` to the previous rank and receives a tensor
    from it in one batched exchange. The first pipeline stage has no
    previous rank, so it communicates nothing and returns None.

    timers: optional profiling callable; if given, the exchange is timed
        under the 'backward-send-forward-recv' label.

    See _communicate for argument details.
    """
    if core.parallel_state.is_pipeline_first_stage():
        return None

    if timers is not None:
        timers('backward-send-forward-recv', log_level=2).start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype=dtype)
    if timers is not None:
        timers('backward-send-forward-recv').stop()
    return input_tensor
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def send_forward_recv_forward(output_tensor: torch.Tensor,
                              recv_prev: bool,
                              tensor_shape: Shape,
                              dtype: torch.dtype,
                              timers: Optional[Callable] = None
                              ) -> Optional[torch.Tensor]:
    """Batched recv from previous rank and send to next rank in pipeline.

    This helper does not consult the pipeline stage itself; the caller
    decides whether to receive via ``recv_prev``. Returns the tensor
    received from the previous rank, or None when ``recv_prev`` is False.

    timers: optional profiling callable; if given, the exchange is timed
        under the 'forward-send-forward-recv' label.

    See _communicate for argument details.
    """
    if timers is not None:
        timers('forward-send-forward-recv', log_level=2).start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype=dtype)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def send_backward_recv_backward(input_tensor_grad: torch.Tensor,
                                recv_next: bool,
                                tensor_shape: Shape,
                                dtype: torch.dtype,
                                timers: Optional[Callable] = None
                                ) -> Optional[torch.Tensor]:
    """Batched recv from next rank and send to previous rank in pipeline.

    This helper does not consult the pipeline stage itself; the caller
    decides whether to receive via ``recv_next``. Returns the tensor
    received from the next rank, or None when ``recv_next`` is False.

    timers: optional profiling callable; if given, the exchange is timed
        under the 'backward-send-backward-recv' label.

    See _communicate for argument details.
    """
    if timers is not None:
        timers('backward-send-backward-recv', log_level=2).start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        dtype=dtype)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def send_forward_backward_recv_forward_backward(
        output_tensor: torch.Tensor,
        input_tensor_grad: torch.Tensor,
        recv_prev: bool,
        recv_next: bool,
        tensor_shape: Shape,
        dtype: torch.dtype,
        timers: Optional[Callable] = None
        ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Batched send and recv with previous and next ranks in pipeline.

    Sends ``output_tensor`` to the next rank and ``input_tensor_grad`` to
    the previous rank, and receives from whichever neighbours ``recv_prev``
    and ``recv_next`` request, all in one batched exchange. This helper does
    not consult the pipeline stage itself.

    Returns:
        (input_tensor, output_tensor_grad): the tensors received from the
        previous and next rank respectively; each is None when the
        corresponding recv flag is False.

    timers: optional profiling callable; if given, the exchange is timed
        under the 'forward-backward-send-forward-backward-recv' label.

    See _communicate for argument details.
    """
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv',
               log_level=2).start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        dtype=dtype)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
|