rslearn 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +59 -0
- rslearn/data_sources/copernicus.py +4 -4
- rslearn/data_sources/earthdaily.py +21 -1
- rslearn/data_sources/gcp_public_data.py +3 -3
- rslearn/data_sources/utils.py +1 -17
- rslearn/main.py +10 -1
- rslearn/models/trunk.py +0 -144
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +319 -0
- rslearn/train/callbacks/gradients.py +54 -34
- rslearn/train/data_module.py +70 -41
- rslearn/train/dataset.py +232 -54
- rslearn/train/lightning_module.py +4 -0
- rslearn/train/prediction_writer.py +7 -0
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/per_pixel_regression.py +259 -0
- rslearn/train/tasks/regression.py +6 -4
- rslearn/train/tasks/segmentation.py +44 -14
- rslearn/train/transforms/mask.py +69 -0
- rslearn/utils/geometry.py +8 -8
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/METADATA +3 -3
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/RECORD +26 -24
- rslearn/models/moe/distributed.py +0 -262
- rslearn/models/moe/soft.py +0 -676
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/WHEEL +0 -0
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/top_level.txt +0 -0
rslearn/models/moe/soft.py
DELETED
@@ -1,676 +0,0 @@
"""Soft MoE (Mixture of Experts) implementation.

Mostly from
https://raw.githubusercontent.com/lucidrains/soft-moe-pytorch/refs/heads/main/soft_moe_pytorch/soft_moe.py.
"""

from collections.abc import Callable
from typing import Any

import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import pack, rearrange, unpack
from torch import Tensor, einsum, nn
from torch.nn import Module

from rslearn.models.moe.distributed import (
    AllGather,
    gather_sizes,
    has_only_one_value,
    split_by_rank,
)

# helper functions


def exists(val: Any) -> bool:
    """Check if a value exists (is not None).

    Args:
        val: The value to check.

    Returns:
        bool: True if the value is not None, False otherwise.
    """
    return val is not None


def default(val: Any, d: Any) -> Any:
    """Return the value if it exists, otherwise return the default.

    Args:
        val: The value to check.
        d: The default value to return if val is None.

    Returns:
        Any: The value if it exists, otherwise the default.
    """
    return val if exists(val) else d


def divisible_by(num: int, den: int) -> bool:
    """Check if a number is divisible by another.

    Args:
        num: The numerator.
        den: The denominator.

    Returns:
        bool: True if num is divisible by den, False otherwise.
    """
    return (num % den) == 0


def chunk_num(num: int, chunks: int) -> list[int]:
    """Divide a number into approximately equal chunks.

    Args:
        num: The number to divide.
        chunks: The number of chunks to create.

    Returns:
        List[int]: List of chunk sizes that sum to num.
    """
    num_per_chunk, remainder = divmod(num, chunks)

    out = []
    for i in range(chunks):
        n = num_per_chunk
        out.append(n + int(i < remainder))

    return out


def pack_one(t: Tensor, pattern: str) -> tuple[Tensor, tuple[int, ...]]:
    """Pack a single tensor using einops pattern.

    Args:
        t: The tensor to pack.
        pattern: The einops pattern to use.

    Returns:
        Tuple[Tensor, Tuple[int, ...]]: Packed tensor and its shape.
    """
    return pack([t], pattern)


def unpack_one(t: Tensor, ps: tuple[int, ...], pattern: str) -> Tensor:
    """Unpack a single tensor using einops pattern.

    Args:
        t: The tensor to unpack.
        ps: The packed shape.
        pattern: The einops pattern to use.

    Returns:
        Tensor: The unpacked tensor.
    """
    return unpack(t, ps, pattern)[0]


def l2norm(t: Tensor) -> Tensor:
    """Apply L2 normalization to a tensor.

    Args:
        t: The tensor to normalize.

    Returns:
        Tensor: The L2 normalized tensor.
    """
    return F.normalize(t, dim=-1)


def cumsum_exclusive(t: Tensor, dim: int = -3) -> Tensor:
    """Compute exclusive cumulative sum along a dimension.

    Args:
        t: The input tensor.
        dim: The dimension along which to compute the cumulative sum.

    Returns:
        Tensor: The exclusive cumulative sum.

    Raises:
        AssertionError: If dim is not negative.
    """
    assert dim < 0
    num_pad_dims = -dim - 1
    pre_padding = (0, 0) * num_pad_dims
    return F.pad(t, (*pre_padding, 1, -1)).cumsum(dim=dim)


def log(t: Tensor, eps: float = 1e-20) -> Tensor:
    """Compute the natural logarithm with a minimum value.

    Args:
        t: The input tensor.
        eps: The minimum value to clamp to.

    Returns:
        Tensor: The natural logarithm of the clamped tensor.
    """
    return torch.log(t.clamp(min=eps))


def gumbel_noise(t: Tensor) -> Tensor:
    """Generate Gumbel noise for the given tensor.

    Args:
        t: The input tensor.

    Returns:
        Tensor: Gumbel noise with the same shape as t.
    """
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


# norm


class LayerNorm(nn.Module):
    """Layer normalization module with learnable parameters.

    This module applies layer normalization with learnable gamma and beta parameters.
    """

    def __init__(self, dim: int) -> None:
        """Initialize the LayerNorm module.

        Args:
            dim: The dimension to normalize over.
        """
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the layer normalization.

        Args:
            x: Input tensor of shape (..., dim).

        Returns:
            Tensor: Normalized tensor with the same shape as x.
        """
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)


class RMSNorm(Module):
    """Root Mean Square normalization module.

    This module applies RMS normalization with a learnable gamma parameter.
    """

    def __init__(self, dim: int) -> None:
        """Initialize the RMSNorm module.

        Args:
            dim: The dimension to normalize over.
        """
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the RMS normalization.

        Args:
            x: Input tensor of shape (..., dim).

        Returns:
            Tensor: Normalized tensor with the same shape as x.
        """
        return l2norm(x) * self.scale * self.gamma


# expert


def FeedForward(dim: int, mult: int = 4, dropout: float = 0.0) -> nn.Sequential:
    """Create a feedforward neural network.

    Args:
        dim: The input and output dimension.
        mult: The multiplier for the hidden dimension.
        dropout: The dropout rate.

    Returns:
        nn.Sequential: A feedforward network with GELU activation.
    """
    dim_hidden = int(dim * mult)
    return nn.Sequential(
        nn.Linear(dim, dim_hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim),
    )


class GEGLU(Module):
    """Gated Linear Unit with GELU activation.

    This module implements a gated linear unit where the gate uses GELU activation.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the GEGLU module.

        Args:
            x: Input tensor of shape (..., 2 * dim).

        Returns:
            Tensor: Output tensor of shape (..., dim).
        """
        x, gate = x.chunk(2, dim=-1)
        return x * F.gelu(gate)


def GLUFeedForward(dim: int, mult: int = 4, dropout: float = 0.0) -> nn.Sequential:
    """Create a feedforward neural network with GLU activation.

    Args:
        dim: The input and output dimension.
        mult: The multiplier for the hidden dimension.
        dropout: The dropout rate.

    Returns:
        nn.Sequential: A feedforward network with GLU activation.
    """
    dim_hidden = int(dim * mult * 2 / 3)

    return nn.Sequential(
        nn.Linear(dim, dim_hidden * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim),
    )


# experts


class Experts(nn.Module):
    """A module that manages multiple expert networks for distributed training.

    This module handles the distribution of experts across multiple devices
    and manages the routing of inputs to the appropriate experts.
    """

    def __init__(
        self,
        experts: list[nn.Module],
        is_distributed: bool | None = None,
        offload_unused_experts_to_cpu: bool = True,
    ) -> None:
        """Initialize the Experts module.

        Args:
            experts: List of expert modules.
            is_distributed: Whether to use distributed training. If None,
                automatically detected from torch.distributed.
            offload_unused_experts_to_cpu: Whether to move unused experts to CPU
                to save GPU memory.
        """
        super().__init__()
        self.num_experts = len(experts)
        self.experts = nn.ModuleList(experts)

        self.is_distributed = is_distributed
        if not exists(self.is_distributed):
            self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1

        # whether to offload unused experts to cpu, will require optimizer handles conversion of gradients to right device when accumulating
        self.offload_unused_experts_to_cpu = offload_unused_experts_to_cpu

        self.all_gather = AllGather()
        self.register_buffer("dummy", torch.ones(1), persistent=False)

    @property
    def device(self) -> torch.device:
        """Get the device of the dummy buffer.

        Returns:
            torch.device: The device of the module.
        """
        return self.dummy.device

    def all_experts_to_cpu_besides(
        self, selection: int | slice | list[nn.Module]
    ) -> None:
        """Move all experts to CPU except those in the selection.

        Args:
            selection: The experts to keep on the current device. Can be an int,
                slice, or list of expert modules.
        """
        if not self.offload_unused_experts_to_cpu:
            return

        if isinstance(selection, int):
            experts = [self.experts[selection]]
        elif isinstance(selection, slice):
            experts = self.experts[selection]
        else:
            experts = selection

        experts_set = set(experts)

        for expert in self.experts:
            device = self.device if expert in experts_set else "cpu"
            expert.to(device)

    def forward(self, x: Tensor, is_distributed: bool | None = None) -> Tensor:
        """Forward pass through the experts.

        Args:
            x: Input tensor of shape (batch, experts, seq_len, dim).
            is_distributed: Whether to use distributed training. If None, uses
                the default setting.

        Returns:
            Tensor: Output tensor with the same shape as the input.

        Note:
            einops notation:
            b - batch
            r - rank (device / machines)
            e - experts
            n - sequence (number of tokens per expert)
            d - feature dimension
        """
        is_distributed = default(is_distributed, self.is_distributed)
        shape, num_experts = x.shape, self.num_experts

        # for now naively all gather across batch dimension if distributed, optimize later

        if is_distributed:
            seq_sizes = gather_sizes(x, dim=-2)
            assert has_only_one_value(seq_sizes), (
                "number of tokens per expert must be the same"
            )

            x, batch_sizes = self.all_gather(x)
            total_batch_size = x.shape[0]

            world_size = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world_size = 1
            rank = 0

        # the experts in use on the rank

        if is_distributed:
            if world_size <= num_experts:
                num_experts_across_ranks = chunk_num(num_experts, world_size)
                start_indices = cumsum_exclusive(
                    torch.tensor(num_experts_across_ranks), dim=-1
                )

                num_experts_per_rank = num_experts_across_ranks[rank]
                num_experts_batches_across_ranks = [
                    i * total_batch_size for i in num_experts_across_ranks
                ]

                expert_start_index = start_indices[rank].item()
            else:
                num_batch_chunks = world_size // num_experts
                total_ranks_in_use = num_batch_chunks * num_experts

                expert_start_index = rank // num_batch_chunks

                batch_splits = chunk_num(total_batch_size, num_batch_chunks)
                num_experts_batches_across_ranks = list(batch_splits * num_experts)

                # for now, remaining machines just process nothing

                remain_ranks = world_size % num_experts
                num_experts_batches_across_ranks += [0] * remain_ranks

                num_experts_per_rank = int(rank < total_ranks_in_use)

            assert len(num_experts_batches_across_ranks) == world_size

            expert_slice = slice(
                expert_start_index, expert_start_index + num_experts_per_rank
            )
        else:
            num_experts_per_rank = num_experts
            expert_slice = slice(0, num_experts)

        # if distributed, each machine only handles subset of experts and batch

        x = rearrange(x, "b e n d -> e b n d")

        if is_distributed:
            x, expert_batch_packed_shape = pack_one(x, "* n d")
            x_split = x.split(num_experts_batches_across_ranks, dim=0)
            x = split_by_rank(x_split)

            if num_experts_per_rank > 0:
                x = rearrange(x, "(e b) n d -> e b n d", e=num_experts_per_rank)
            else:
                x = x.reshape(num_experts, *x.shape)

        # get the experts in use

        self.all_experts_to_cpu_besides(expert_slice)

        experts = self.experts[expert_slice]

        # route tokens to appropriate experts

        outs_list = []
        for expert, expert_input in zip(experts, x):
            out = expert(expert_input)
            outs_list.append(out)

        if len(outs_list) > 0:
            outs = torch.stack(outs_list)
        else:
            outs = torch.empty_like(x).requires_grad_()

        # all gather across merged expert batches dimensions
        # then split the batch dimension back

        if is_distributed:
            outs = rearrange(outs, "e b n d -> (e b) n d")
            outs, _ = self.all_gather(outs)
            outs = unpack_one(outs, expert_batch_packed_shape, "* n d")

        outs = rearrange(outs, "e b n d -> b e n d")

        if is_distributed:
            if batch_sizes is not None:
                outs_split = outs.split(batch_sizes.tolist())
                outs = split_by_rank(outs_split)

        assert outs.shape == shape
        return outs


# main class


class SoftMoE(Module):
    """Soft Mixture of Experts (MoE) module.

    This module implements a soft mixture of experts where tokens are softly
    assigned to experts using learned routing weights.
    """

    def __init__(
        self,
        dim: int,
        *,
        seq_len: int | None = None,
        num_experts: int = 4,
        num_slots: int | None = None,
        expert_mult: int = 4,
        dropout: float = 0.0,
        geglu: bool = False,
        is_distributed: bool | None = None,
        offload_unused_experts_to_cpu: bool = True,
        use_layernorm: bool = False,
    ) -> None:
        """Initialize the SoftMoE module.

        Args:
            dim: The input and output dimension.
            seq_len: The sequence length. Must be provided if num_slots is not.
            num_experts: The number of experts.
            num_slots: The number of slots per expert. Must be provided if seq_len is not.
            expert_mult: The multiplier for expert hidden dimensions.
            dropout: The dropout rate.
            geglu: Whether to use GLU activation in experts.
            is_distributed: Whether to use distributed training.
            offload_unused_experts_to_cpu: Whether to move unused experts to CPU.
            use_layernorm: Whether to use LayerNorm instead of RMSNorm.

        Raises:
            AssertionError: If neither seq_len nor num_slots is provided, or if both are provided.
        """
        super().__init__()
        assert exists(seq_len) ^ exists(num_slots), (
            "either seq_len, or num_slots must be passed into SoftMoE"
        )

        if exists(seq_len):
            if seq_len is not None:
                num_slots = default(num_slots, seq_len // num_experts)
        elif exists(num_slots):
            if num_slots is not None:
                seq_len = num_slots * num_experts
        else:
            raise ValueError("Either seq_len or num_slots must be provided")

        norm_klass = LayerNorm if use_layernorm else RMSNorm
        self.norm: Callable = norm_klass(dim)  # type: ignore

        self.slot_norm: Callable = norm_klass(dim)  # type: ignore
        self.slot_embeds = nn.Parameter(torch.randn(num_experts, num_slots, dim))

        expert_klass = GLUFeedForward if geglu else FeedForward

        self.experts = Experts(
            experts=[
                expert_klass(dim=dim, mult=expert_mult, dropout=dropout)
                for _ in range(num_experts)
            ],
            is_distributed=is_distributed,
            offload_unused_experts_to_cpu=offload_unused_experts_to_cpu,
        )

        self.num_experts = num_experts
        self.num_slots = num_slots

    def forward(
        self,
        x: Tensor,
        mask: Tensor | None = None,
        add_noise: bool = False,
        noise_mult: float = 1.0,
        weight_key: Tensor | None = None,
        return_load_balance_loss: bool = True,
        return_dispatch_weights: bool = True,
        return_combine_weights: bool = True,
    ) -> dict[str, Tensor]:
        """Forward pass of the SoftMoE module.

        Args:
            x: Input tensor of shape (batch, seq_len, dim) or (batch, dim) for single token.
            mask: Optional mask tensor of shape (batch, seq_len) to mask out padding tokens.
            add_noise: Whether to add Gumbel noise to the routing logits.
            noise_mult: Multiplier for the Gumbel noise.
            weight_key: Tensor of shape (batch, seq_len, dim) with which to compute the dispatch and combine
                weights. If not specified, use the input tokens.
            return_load_balance_loss: Whether to return the load balance loss.
            return_dispatch_weights: Whether to return the dispatch weights along with the output.
            return_combine_weights: Whether to return the combine weights along with the output.

        Returns:
            dict with key "outputs" (output tensor) and optionally "load_balance_loss",
            "dispatch_weights", and "combine_weights".

        Note:
            einstein notation
            b - batch
            n - sequence length
            e - number of experts
            s - number of slots per expert
            d - feature dimension
        """
        is_single_token = x.ndim == 2
        is_image = x.ndim == 4

        if is_image:
            x = rearrange(x, "b d h w -> b h w d")
            x, ps = pack([x], "b * d")  # type: ignore
        elif is_single_token:
            x = rearrange(x, "b d -> b 1 d")

        # following Algorithm 1, with the normalization they proposed, but with scaling of both (the now popular rmsnorm + gamma)
        x = self.norm(x)
        slot_embeds = self.slot_norm(self.slot_embeds)

        dispatch_logits = einsum("b n d, e s d -> b n e s", x, slot_embeds)
        if weight_key is None:
            combine_logits = dispatch_logits
        else:
            assert weight_key.shape == x.shape, (
                "weight_key must be (batch_size, seq_len, dim)"
            )
            combine_logits = einsum("b n d, e s d -> b n e s", weight_key, slot_embeds)

        # noised dispatch and combine gate logits, with annealing if needed
        if add_noise:
            dispatch_logits = (
                dispatch_logits + gumbel_noise(dispatch_logits) * noise_mult
            )
            combine_logits = combine_logits + gumbel_noise(combine_logits) * noise_mult

        # account for key padding mask
        if exists(mask):
            mask = rearrange(mask, "b n -> b n 1 1")
            fill_value = -torch.finfo(dispatch_logits.dtype).max
            dispatch_logits = dispatch_logits.masked_fill(~mask, fill_value)
            combine_logits = combine_logits.masked_fill(~mask, fill_value)

        # get dispatch and combine weights (softmax across right dimensions)
        dispatch_weights = dispatch_logits.softmax(dim=1)

        combine_weights = rearrange(combine_logits, "b n e s -> b n (e s)")
        combine_weights = combine_weights.softmax(dim=-1)

        # derive slots by weighted average of input tokens using the dispatch weights from above
        slots = einsum("b n d, b n e s -> b e s d", x, dispatch_weights)

        # route the slots per expert to each expert
        out = self.experts(slots)

        # combine back out
        out = rearrange(out, " b e s d -> b (e s) d")
        out = einsum("b s d, b n s -> b n d", out, combine_weights)

        if is_image:
            (out,) = unpack(out, ps, "b * d")  # type: ignore
            out = rearrange(out, "b h w d -> b d h w")
        elif is_single_token:
            out = rearrange(out, "b 1 d -> b d")

        # compute the load balance loss per layer if requested
        info = {"outputs": out}
        if return_load_balance_loss:
            # penalize negative entropy of the expert combine weights
            # this is negative, so be careful when adding it to the total loss
            sizes = (self.num_experts, self.num_slots)
            unflat = combine_weights.unflatten(dim=-1, sizes=sizes).sum(dim=-1)
            distr = torch.distributions.Categorical(probs=unflat)
            info["load_balance_loss"] = -distr.entropy().mean()
        if return_dispatch_weights:
            info["dispatch_weights"] = dispatch_weights
        if return_combine_weights:
            info["combine_weights"] = combine_weights
        return info
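For reference, here is a minimal usage sketch of the removed SoftMoE module, based solely on the constructor and forward signatures in the deleted file above; the dimensions and values are illustrative and not part of the package:

import torch

from rslearn.models.moe.soft import SoftMoE  # import path valid only in rslearn <= 0.0.3

# With seq_len given and num_slots omitted, num_slots defaults to
# seq_len // num_experts = 49 slots per expert.
moe = SoftMoE(dim=256, seq_len=196, num_experts=4)

tokens = torch.randn(2, 196, 256)  # (batch, seq_len, dim)
info = moe(tokens)

out = info["outputs"]             # (2, 196, 256): same shape as the input
aux = info["load_balance_loss"]   # scalar; negative entropy of per-expert combine weights

As the inline comments in the deleted code warn, load_balance_loss is the negative mean entropy of the combine weights summed over slots, so it is itself negative and is meant to be scaled appropriately before being added to the training loss.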