rslearn 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in those registries.
@@ -1,676 +0,0 @@
- """Soft MoE (Mixture of Experts) implementation.
-
- Mostly from
- https://raw.githubusercontent.com/lucidrains/soft-moe-pytorch/refs/heads/main/soft_moe_pytorch/soft_moe.py.
- """
-
- from collections.abc import Callable
- from typing import Any
-
- import torch
- import torch.distributed as dist
- import torch.nn.functional as F
- from einops import pack, rearrange, unpack
- from torch import Tensor, einsum, nn
- from torch.nn import Module
-
- from rslearn.models.moe.distributed import (
-     AllGather,
-     gather_sizes,
-     has_only_one_value,
-     split_by_rank,
- )
-
- # helper functions
-
-
- def exists(val: Any) -> bool:
-     """Check if a value exists (is not None).
-
-     Args:
-         val: The value to check.
-
-     Returns:
-         bool: True if the value is not None, False otherwise.
-     """
-     return val is not None
-
-
- def default(val: Any, d: Any) -> Any:
-     """Return the value if it exists, otherwise return the default.
-
-     Args:
-         val: The value to check.
-         d: The default value to return if val is None.
-
-     Returns:
-         Any: The value if it exists, otherwise the default.
-     """
-     return val if exists(val) else d
-
-
- def divisible_by(num: int, den: int) -> bool:
-     """Check if a number is divisible by another.
-
-     Args:
-         num: The numerator.
-         den: The denominator.
-
-     Returns:
-         bool: True if num is divisible by den, False otherwise.
-     """
-     return (num % den) == 0
-
-
- def chunk_num(num: int, chunks: int) -> list[int]:
-     """Divide a number into approximately equal chunks.
-
-     Args:
-         num: The number to divide.
-         chunks: The number of chunks to create.
-
-     Returns:
-         List[int]: List of chunk sizes that sum to num.
-     """
-     num_per_chunk, remainder = divmod(num, chunks)
-
-     out = []
-     for i in range(chunks):
-         n = num_per_chunk
-         out.append(n + int(i < remainder))
-
-     return out
-
-
- def pack_one(t: Tensor, pattern: str) -> tuple[Tensor, tuple[int, ...]]:
-     """Pack a single tensor using einops pattern.
-
-     Args:
-         t: The tensor to pack.
-         pattern: The einops pattern to use.
-
-     Returns:
-         Tuple[Tensor, Tuple[int, ...]]: Packed tensor and its shape.
-     """
-     return pack([t], pattern)
-
-
- def unpack_one(t: Tensor, ps: tuple[int, ...], pattern: str) -> Tensor:
-     """Unpack a single tensor using einops pattern.
-
-     Args:
-         t: The tensor to unpack.
-         ps: The packed shape.
-         pattern: The einops pattern to use.
-
-     Returns:
-         Tensor: The unpacked tensor.
-     """
-     return unpack(t, ps, pattern)[0]
-
-
- def l2norm(t: Tensor) -> Tensor:
-     """Apply L2 normalization to a tensor.
-
-     Args:
-         t: The tensor to normalize.
-
-     Returns:
-         Tensor: The L2 normalized tensor.
-     """
-     return F.normalize(t, dim=-1)
-
-
- def cumsum_exclusive(t: Tensor, dim: int = -3) -> Tensor:
-     """Compute exclusive cumulative sum along a dimension.
-
-     Args:
-         t: The input tensor.
-         dim: The dimension along which to compute the cumulative sum.
-
-     Returns:
-         Tensor: The exclusive cumulative sum.
-
-     Raises:
-         AssertionError: If dim is not negative.
-     """
-     assert dim < 0
-     num_pad_dims = -dim - 1
-     pre_padding = (0, 0) * num_pad_dims
-     return F.pad(t, (*pre_padding, 1, -1)).cumsum(dim=dim)
-
-
- def log(t: Tensor, eps: float = 1e-20) -> Tensor:
-     """Compute the natural logarithm with a minimum value.
-
-     Args:
-         t: The input tensor.
-         eps: The minimum value to clamp to.
-
-     Returns:
-         Tensor: The natural logarithm of the clamped tensor.
-     """
-     return torch.log(t.clamp(min=eps))
-
-
- def gumbel_noise(t: Tensor) -> Tensor:
-     """Generate Gumbel noise for the given tensor.
-
-     Args:
-         t: The input tensor.
-
-     Returns:
-         Tensor: Gumbel noise with the same shape as t.
-     """
-     noise = torch.zeros_like(t).uniform_(0, 1)
-     return -log(-log(noise))
-
-
- # norm
-
-
- class LayerNorm(nn.Module):
-     """Layer normalization module with learnable parameters.
-
-     This module applies layer normalization with learnable gamma and beta parameters.
-     """
-
-     def __init__(self, dim: int) -> None:
-         """Initialize the LayerNorm module.
-
-         Args:
-             dim: The dimension to normalize over.
-         """
-         super().__init__()
-         self.gamma = nn.Parameter(torch.ones(dim))
-         self.register_buffer("beta", torch.zeros(dim))
-
-     def forward(self, x: Tensor) -> Tensor:
-         """Forward pass of the layer normalization.
-
-         Args:
-             x: Input tensor of shape (..., dim).
-
-         Returns:
-             Tensor: Normalized tensor with the same shape as x.
-         """
-         return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
-
-
- class RMSNorm(Module):
-     """Root Mean Square normalization module.
-
-     This module applies RMS normalization with a learnable gamma parameter.
-     """
-
-     def __init__(self, dim: int) -> None:
-         """Initialize the RMSNorm module.
-
-         Args:
-             dim: The dimension to normalize over.
-         """
-         super().__init__()
-         self.scale = dim**0.5
-         self.gamma = nn.Parameter(torch.ones(dim))
-
-     def forward(self, x: Tensor) -> Tensor:
-         """Forward pass of the RMS normalization.
-
-         Args:
-             x: Input tensor of shape (..., dim).
-
-         Returns:
-             Tensor: Normalized tensor with the same shape as x.
-         """
-         return l2norm(x) * self.scale * self.gamma
-
-
- # expert
-
-
- def FeedForward(dim: int, mult: int = 4, dropout: float = 0.0) -> nn.Sequential:
-     """Create a feedforward neural network.
-
-     Args:
-         dim: The input and output dimension.
-         mult: The multiplier for the hidden dimension.
-         dropout: The dropout rate.
-
-     Returns:
-         nn.Sequential: A feedforward network with GELU activation.
-     """
-     dim_hidden = int(dim * mult)
-     return nn.Sequential(
-         nn.Linear(dim, dim_hidden),
-         nn.GELU(),
-         nn.Dropout(dropout),
-         nn.Linear(dim_hidden, dim),
-     )
-
-
- class GEGLU(Module):
-     """Gated Linear Unit with GELU activation.
-
-     This module implements a gated linear unit where the gate uses GELU activation.
-     """
-
-     def forward(self, x: Tensor) -> Tensor:
-         """Forward pass of the GEGLU module.
-
-         Args:
-             x: Input tensor of shape (..., 2 * dim).
-
-         Returns:
-             Tensor: Output tensor of shape (..., dim).
-         """
-         x, gate = x.chunk(2, dim=-1)
-         return x * F.gelu(gate)
-
-
- def GLUFeedForward(dim: int, mult: int = 4, dropout: float = 0.0) -> nn.Sequential:
-     """Create a feedforward neural network with GLU activation.
-
-     Args:
-         dim: The input and output dimension.
-         mult: The multiplier for the hidden dimension.
-         dropout: The dropout rate.
-
-     Returns:
-         nn.Sequential: A feedforward network with GLU activation.
-     """
-     dim_hidden = int(dim * mult * 2 / 3)
-
-     return nn.Sequential(
-         nn.Linear(dim, dim_hidden * 2),
-         GEGLU(),
-         nn.Dropout(dropout),
-         nn.Linear(dim_hidden, dim),
-     )
-
-
- # experts
-
-
- class Experts(nn.Module):
-     """A module that manages multiple expert networks for distributed training.
-
-     This module handles the distribution of experts across multiple devices
-     and manages the routing of inputs to the appropriate experts.
-     """
-
-     def __init__(
-         self,
-         experts: list[nn.Module],
-         is_distributed: bool | None = None,
-         offload_unused_experts_to_cpu: bool = True,
-     ) -> None:
-         """Initialize the Experts module.
-
-         Args:
-             experts: List of expert modules.
-             is_distributed: Whether to use distributed training. If None,
-                 automatically detected from torch.distributed.
-             offload_unused_experts_to_cpu: Whether to move unused experts to CPU
-                 to save GPU memory.
-         """
-         super().__init__()
-         self.num_experts = len(experts)
-         self.experts = nn.ModuleList(experts)
-
-         self.is_distributed = is_distributed
-         if not exists(self.is_distributed):
-             self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1
-
-         # whether to offload unused experts to cpu, will require optimizer handles conversion of gradients to right device when accumulating
-         self.offload_unused_experts_to_cpu = offload_unused_experts_to_cpu
-
-         self.all_gather = AllGather()
-         self.register_buffer("dummy", torch.ones(1), persistent=False)
-
-     @property
-     def device(self) -> torch.device:
-         """Get the device of the dummy buffer.
-
-         Returns:
-             torch.device: The device of the module.
-         """
-         return self.dummy.device
-
-     def all_experts_to_cpu_besides(
-         self, selection: int | slice | list[nn.Module]
-     ) -> None:
-         """Move all experts to CPU except those in the selection.
-
-         Args:
-             selection: The experts to keep on the current device. Can be an int,
-                 slice, or list of expert modules.
-         """
-         if not self.offload_unused_experts_to_cpu:
-             return
-
-         if isinstance(selection, int):
-             experts = [self.experts[selection]]
-         elif isinstance(selection, slice):
-             experts = self.experts[selection]
-         else:
-             experts = selection
-
-         experts_set = set(experts)
-
-         for expert in self.experts:
-             device = self.device if expert in experts_set else "cpu"
-             expert.to(device)
-
-     def forward(self, x: Tensor, is_distributed: bool | None = None) -> Tensor:
-         """Forward pass through the experts.
-
-         Args:
-             x: Input tensor of shape (batch, experts, seq_len, dim).
-             is_distributed: Whether to use distributed training. If None, uses
-                 the default setting.
-
-         Returns:
-             Tensor: Output tensor with the same shape as the input.
-
-         Note:
-             einops notation:
-             b - batch
-             r - rank (device / machines)
-             e - experts
-             n - sequence (number of tokens per expert)
-             d - feature dimension
-         """
-         is_distributed = default(is_distributed, self.is_distributed)
-         shape, num_experts = x.shape, self.num_experts
-
-         # for now naively all gather across batch dimension if distributed, optimize later
-
-         if is_distributed:
-             seq_sizes = gather_sizes(x, dim=-2)
-             assert has_only_one_value(seq_sizes), (
-                 "number of tokens per expert must be the same"
-             )
-
-             x, batch_sizes = self.all_gather(x)
-             total_batch_size = x.shape[0]
-
-             world_size = dist.get_world_size()
-             rank = dist.get_rank()
-         else:
-             world_size = 1
-             rank = 0
-
-         # the experts in use on the rank
-
-         if is_distributed:
-             if world_size <= num_experts:
-                 num_experts_across_ranks = chunk_num(num_experts, world_size)
-                 start_indices = cumsum_exclusive(
-                     torch.tensor(num_experts_across_ranks), dim=-1
-                 )
-
-                 num_experts_per_rank = num_experts_across_ranks[rank]
-                 num_experts_batches_across_ranks = [
-                     i * total_batch_size for i in num_experts_across_ranks
-                 ]
-
-                 expert_start_index = start_indices[rank].item()
-             else:
-                 num_batch_chunks = world_size // num_experts
-                 total_ranks_in_use = num_batch_chunks * num_experts
-
-                 expert_start_index = rank // num_batch_chunks
-
-                 batch_splits = chunk_num(total_batch_size, num_batch_chunks)
-                 num_experts_batches_across_ranks = list(batch_splits * num_experts)
-
-                 # for now, remaining machines just process nothing
-
-                 remain_ranks = world_size % num_experts
-                 num_experts_batches_across_ranks += [0] * remain_ranks
-
-                 num_experts_per_rank = int(rank < total_ranks_in_use)
-
-             assert len(num_experts_batches_across_ranks) == world_size
-
-             expert_slice = slice(
-                 expert_start_index, expert_start_index + num_experts_per_rank
-             )
-         else:
-             num_experts_per_rank = num_experts
-             expert_slice = slice(0, num_experts)
-
-         # if distributed, each machine only handles subset of experts and batch
-
-         x = rearrange(x, "b e n d -> e b n d")
-
-         if is_distributed:
-             x, expert_batch_packed_shape = pack_one(x, "* n d")
-             x_split = x.split(num_experts_batches_across_ranks, dim=0)
-             x = split_by_rank(x_split)
-
-             if num_experts_per_rank > 0:
-                 x = rearrange(x, "(e b) n d -> e b n d", e=num_experts_per_rank)
-             else:
-                 x = x.reshape(num_experts, *x.shape)
-
-         # get the experts in use
-
-         self.all_experts_to_cpu_besides(expert_slice)
-
-         experts = self.experts[expert_slice]
-
-         # route tokens to appropriate experts
-
-         outs_list = []
-         for expert, expert_input in zip(experts, x):
-             out = expert(expert_input)
-             outs_list.append(out)
-
-         if len(outs_list) > 0:
-             outs = torch.stack(outs_list)
-         else:
-             outs = torch.empty_like(x).requires_grad_()
-
-         # all gather across merged expert batches dimensions
-         # then split the batch dimension back
-
-         if is_distributed:
-             outs = rearrange(outs, "e b n d -> (e b) n d")
-             outs, _ = self.all_gather(outs)
-             outs = unpack_one(outs, expert_batch_packed_shape, "* n d")
-
-         outs = rearrange(outs, "e b n d -> b e n d")
-
-         if is_distributed:
-             if batch_sizes is not None:
-                 outs_split = outs.split(batch_sizes.tolist())
-                 outs = split_by_rank(outs_split)
-
-         assert outs.shape == shape
-         return outs
-
-
- # main class
-
-
- class SoftMoE(Module):
-     """Soft Mixture of Experts (MoE) module.
-
-     This module implements a soft mixture of experts where tokens are softly
-     assigned to experts using learned routing weights.
-     """
-
-     def __init__(
-         self,
-         dim: int,
-         *,
-         seq_len: int | None = None,
-         num_experts: int = 4,
-         num_slots: int | None = None,
-         expert_mult: int = 4,
-         dropout: float = 0.0,
-         geglu: bool = False,
-         is_distributed: bool | None = None,
-         offload_unused_experts_to_cpu: bool = True,
-         use_layernorm: bool = False,
-     ) -> None:
-         """Initialize the SoftMoE module.
-
-         Args:
-             dim: The input and output dimension.
-             seq_len: The sequence length. Must be provided if num_slots is not.
-             num_experts: The number of experts.
-             num_slots: The number of slots per expert. Must be provided if seq_len is not.
-             expert_mult: The multiplier for expert hidden dimensions.
-             dropout: The dropout rate.
-             geglu: Whether to use GLU activation in experts.
-             is_distributed: Whether to use distributed training.
-             offload_unused_experts_to_cpu: Whether to move unused experts to CPU.
-             use_layernorm: Whether to use LayerNorm instead of RMSNorm.
-
-         Raises:
-             AssertionError: If neither seq_len nor num_slots is provided, or if both are provided.
-         """
-         super().__init__()
-         assert exists(seq_len) ^ exists(num_slots), (
-             "either seq_len, or num_slots must be passed into SoftMoE"
-         )
-
-         if exists(seq_len):
-             if seq_len is not None:
-                 num_slots = default(num_slots, seq_len // num_experts)
-         elif exists(num_slots):
-             if num_slots is not None:
-                 seq_len = num_slots * num_experts
-         else:
-             raise ValueError("Either seq_len or num_slots must be provided")
-
-         norm_klass = LayerNorm if use_layernorm else RMSNorm
-         self.norm: Callable = norm_klass(dim)  # type: ignore
-
-         self.slot_norm: Callable = norm_klass(dim)  # type: ignore
-         self.slot_embeds = nn.Parameter(torch.randn(num_experts, num_slots, dim))
-
-         expert_klass = GLUFeedForward if geglu else FeedForward
-
-         self.experts = Experts(
-             experts=[
-                 expert_klass(dim=dim, mult=expert_mult, dropout=dropout)
-                 for _ in range(num_experts)
-             ],
-             is_distributed=is_distributed,
-             offload_unused_experts_to_cpu=offload_unused_experts_to_cpu,
-         )
-
-         self.num_experts = num_experts
-         self.num_slots = num_slots
-
-     def forward(
-         self,
-         x: Tensor,
-         mask: Tensor | None = None,
-         add_noise: bool = False,
-         noise_mult: float = 1.0,
-         weight_key: Tensor | None = None,
-         return_load_balance_loss: bool = True,
-         return_dispatch_weights: bool = True,
-         return_combine_weights: bool = True,
-     ) -> dict[str, Tensor]:
-         """Forward pass of the SoftMoE module.
-
-         Args:
-             x: Input tensor of shape (batch, seq_len, dim) or (batch, dim) for single token.
-             mask: Optional mask tensor of shape (batch, seq_len) to mask out padding tokens.
-             add_noise: Whether to add Gumbel noise to the routing logits.
-             noise_mult: Multiplier for the Gumbel noise.
-             weight_key: Tensor of shape (batch, seq_len, dim) with which to compute the dispatch and combine
-                 weights. If not specified, use the input tokens.
-             return_load_balance_loss: Whether to return the load balance loss.
-             return_dispatch_weights: Whether to return the dispatch weights along with the output.
-             return_combine_weights: Whether to return the combine weights along with the output.
-
-         Returns:
-             dict with key "outputs" (output tensor) and optionally "load_balance_loss",
-             "dispatch_weights", and "combine_weights".
-
-         Note:
-             einstein notation
-             b - batch
-             n - sequence length
-             e - number of experts
-             s - number of slots per expert
-             d - feature dimension
-         """
-         is_single_token = x.ndim == 2
-         is_image = x.ndim == 4
-
-         if is_image:
-             x = rearrange(x, "b d h w -> b h w d")
-             x, ps = pack([x], "b * d")  # type: ignore
-         elif is_single_token:
-             x = rearrange(x, "b d -> b 1 d")
-
-         # following Algorithm 1, with the normalization they proposed, but with scaling of both (the now popular rmsnorm + gamma)
-         x = self.norm(x)
-         slot_embeds = self.slot_norm(self.slot_embeds)
-
-         dispatch_logits = einsum("b n d, e s d -> b n e s", x, slot_embeds)
-         if weight_key is None:
-             combine_logits = dispatch_logits
-         else:
-             assert weight_key.shape == x.shape, (
-                 "weight_key must be (batch_size, seq_len, dim)"
-             )
-             combine_logits = einsum("b n d, e s d -> b n e s", weight_key, slot_embeds)
-
-         # noised dispatch and combine gate logits, with annealing if needed
-         if add_noise:
-             dispatch_logits = (
-                 dispatch_logits + gumbel_noise(dispatch_logits) * noise_mult
-             )
-             combine_logits = combine_logits + gumbel_noise(combine_logits) * noise_mult
-
-         # account for key padding mask
-         if exists(mask):
-             mask = rearrange(mask, "b n -> b n 1 1")
-             fill_value = -torch.finfo(dispatch_logits.dtype).max
-             dispatch_logits = dispatch_logits.masked_fill(~mask, fill_value)
-             combine_logits = combine_logits.masked_fill(~mask, fill_value)
-
-         # get dispatch and combine weights (softmax across right dimensions)
-         dispatch_weights = dispatch_logits.softmax(dim=1)
-
-         combine_weights = rearrange(combine_logits, "b n e s -> b n (e s)")
-         combine_weights = combine_weights.softmax(dim=-1)
-
-         # derive slots by weighted average of input tokens using the dispatch weights from above
-         slots = einsum("b n d, b n e s -> b e s d", x, dispatch_weights)
-
-         # route the slots per expert to each expert
-         out = self.experts(slots)
-
-         # combine back out
-         out = rearrange(out, " b e s d -> b (e s) d")
-         out = einsum("b s d, b n s -> b n d", out, combine_weights)
-
-         if is_image:
-             (out,) = unpack(out, ps, "b * d")  # type: ignore
-             out = rearrange(out, "b h w d -> b d h w")
-         elif is_single_token:
-             out = rearrange(out, "b 1 d -> b d")
-
-         # compute the load balance loss per layer if requested
-         info = {"outputs": out}
-         if return_load_balance_loss:
-             # penalize negative entropy of the expert combine weights
-             # this is negative, so be careful when adding it to the total loss
-             sizes = (self.num_experts, self.num_slots)
-             unflat = combine_weights.unflatten(dim=-1, sizes=sizes).sum(dim=-1)
-             distr = torch.distributions.Categorical(probs=unflat)
-             info["load_balance_loss"] = -distr.entropy().mean()
-         if return_dispatch_weights:
-             info["dispatch_weights"] = dispatch_weights
-         if return_combine_weights:
-             info["combine_weights"] = combine_weights
-         return info
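
For orientation, the sketch below shows how the removed SoftMoE class could be instantiated and called, based only on the signatures visible in the deleted file above. The import path and the dimensions are illustrative assumptions, not taken from rslearn's own usage, and the import only resolves against 0.0.3, where the file still exists.

# Hypothetical usage sketch based on the signatures in the removed file; path and sizes are assumptions.
import torch

from rslearn.models.moe.soft_moe import SoftMoE  # assumed path; present only in rslearn 0.0.3

# Exactly one of seq_len / num_slots may be given (they are asserted to be mutually exclusive);
# with seq_len=196 and num_experts=4, num_slots defaults to 196 // 4 = 49.
moe = SoftMoE(dim=256, seq_len=196, num_experts=4)

tokens = torch.randn(2, 196, 256)  # (batch, seq_len, dim)
info = moe(tokens)

out = info["outputs"]            # same shape as the input tokens
aux = info["load_balance_loss"]  # negative entropy of the combine weights; weight it carefully in a total loss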