d9d 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/kernel/cce/main.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
2
|
+
# TODO: currently this implementation diverges only in out_grad contiguity fix
|
|
3
|
+
# TODO: proposed in cce.py (grep FIX) - we should contribute this to main repo
|
|
4
|
+
import platform
|
|
5
|
+
import warnings
|
|
6
|
+
from typing import TYPE_CHECKING, Literal, overload
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
|
|
11
|
+
from cut_cross_entropy.cce_utils import CCEPreset, CCEPresets, LinearCrossEntropyImpl
|
|
12
|
+
from cut_cross_entropy.constants import IGNORE_INDEX
|
|
13
|
+
from cut_cross_entropy.doc import (
|
|
14
|
+
CCE_OPTS_DOC,
|
|
15
|
+
DTENSOR_NOTE,
|
|
16
|
+
IMPL_DOC,
|
|
17
|
+
LINEAR_CROSS_ENTROPY_DOC,
|
|
18
|
+
add_doc_end,
|
|
19
|
+
add_doc_start,
|
|
20
|
+
)
|
|
21
|
+
from cut_cross_entropy.torch_compile import torch_compile_linear_cross_entropy
|
|
22
|
+
from cut_cross_entropy.utils import (
|
|
23
|
+
CCEWarning,
|
|
24
|
+
is_torch_greater_or_equal_2_5,
|
|
25
|
+
is_triton_3_2,
|
|
26
|
+
maybe_type_as,
|
|
27
|
+
to_full_tensor,
|
|
28
|
+
)
|
|
29
|
+
from cut_cross_entropy.vocab_parallel import VocabParallelOptions
|
|
30
|
+
|
|
31
|
+
# Show each CCEWarning only once instead of on every call.
# NOTE(review): `module` is a regex matched against the module the warning is
# attributed to. Warnings raised below with stacklevel=2 are attributed to the
# *caller's* module, so this filter may not apply to them. Confirm intent.
warnings.filterwarnings("once", category=CCEWarning, module="cut_cross_entropy")

# Queried once at import time; selects which implementation can be the default.
PLATFORM_SYSTEM = platform.system()

if TYPE_CHECKING or PLATFORM_SYSTEM != "Darwin":
    # The Triton-based CCE kernel is only imported off macOS; type checkers
    # always see the real symbol.
    from .cce import cce_linear_cross_entropy

    LCE_IMPL_DEFAULT = LinearCrossEntropyImpl.CCE
else:
    # On macOS the CCE kernel is unavailable: keep a None placeholder and
    # default to the torch.compile implementation instead.
    cce_linear_cross_entropy = None
    LCE_IMPL_DEFAULT = LinearCrossEntropyImpl.TORCH_COMPILE

if TYPE_CHECKING or is_torch_greater_or_equal_2_5():
    # Makes `torch.distributed.tensor.DTensor` reachable for the DTensor input
    # check in `linear_cross_entropy`, which is guarded by the same version test.
    import torch.distributed.tensor


# Template for the error raised when `e` or `targets` is a DTensor; formatted
# with the offending argument name via `.format(name=...)`.
is_d_tensor_error_message = (
    "Received {name} as a torch.distributed.tensor.DTensor. This is not supported. "
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Typing overload: with return_lse=False (the default) only the loss tensor is returned.
@overload
def linear_cross_entropy(
    e: torch.Tensor,
    c: torch.Tensor,
    targets: torch.Tensor,
    bias: torch.Tensor | None = None,
    ignore_index: int = IGNORE_INDEX,
    softcap: float | None = None,
    reduction: str = "mean",
    shift: bool | int = 0,
    return_lse: Literal[False] = False,
    filter_eps: float | str | None = "auto",
    accum_e_fp32: bool = False,
    accum_c_fp32: bool = False,
    filter_e_grad: bool = True,
    filter_c_grad: bool = True,
    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
    vocab_parallel_options: VocabParallelOptions | None = None,
) -> torch.Tensor: ...


# Typing overload: with return_lse=True a (loss, lse) tuple is returned.
@overload
def linear_cross_entropy(
    e: torch.Tensor,
    c: torch.Tensor,
    targets: torch.Tensor,
    bias: torch.Tensor | None = None,
    ignore_index: int = IGNORE_INDEX,
    softcap: float | None = None,
    reduction: str = "mean",
    shift: bool | int = 0,
    return_lse: Literal[True] = True,
    filter_eps: float | str | None = "auto",
    accum_e_fp32: bool = False,
    accum_c_fp32: bool = False,
    filter_e_grad: bool = True,
    filter_c_grad: bool = True,
    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
    vocab_parallel_options: VocabParallelOptions | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ...


# Typing overload: fallback for a non-literal `return_lse` bool; either result shape is possible.
@overload
def linear_cross_entropy(
    e: torch.Tensor,
    c: torch.Tensor,
    targets: torch.Tensor,
    bias: torch.Tensor | None = None,
    ignore_index: int = IGNORE_INDEX,
    softcap: float | None = None,
    reduction: str = "mean",
    shift: bool | int = 0,
    return_lse: bool = False,
    filter_eps: float | str | None = "auto",
    accum_e_fp32: bool = False,
    accum_c_fp32: bool = False,
    filter_e_grad: bool = True,
    filter_c_grad: bool = True,
    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
    vocab_parallel_options: VocabParallelOptions | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ...
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@add_doc_start(LINEAR_CROSS_ENTROPY_DOC)
@add_doc_start(*(doc_str + " Only valid for the cce implementation." for doc_str in CCE_OPTS_DOC))
@add_doc_start(IMPL_DOC)
@add_doc_end(DTENSOR_NOTE)
def linear_cross_entropy(
    e: torch.Tensor,
    c: torch.Tensor,
    targets: torch.Tensor,
    bias: torch.Tensor | None = None,
    ignore_index: int = IGNORE_INDEX,
    softcap: float | None = None,
    reduction: str = "mean",
    shift: bool | int = 0,
    return_lse: bool = False,
    filter_eps: float | str | None = "auto",
    accum_e_fp32: bool = False,
    accum_c_fp32: bool = False,
    filter_e_grad: bool = True,
    filter_c_grad: bool = True,
    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
    vocab_parallel_options: VocabParallelOptions | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    :param vocab_parallel_options: Used to enable vocab parallelism."""

    # Embeddings and targets must be plain (local) tensors; the classifier and
    # bias are allowed to be DTensors and are materialised just below.
    if is_torch_greater_or_equal_2_5():
        for arg_name, arg_value in (("e", e), ("targets", targets)):
            if isinstance(arg_value, torch.distributed.tensor.DTensor):
                raise ValueError(is_d_tensor_error_message.format(name=arg_name))

    c = maybe_type_as(to_full_tensor(c), e)
    bias = maybe_type_as(to_full_tensor(bias), e)

    # Normalise the implementation selector to its lower-case string name.
    impl_name = impl.name.lower() if isinstance(impl, LinearCrossEntropyImpl) else impl

    # An integer shift must index a valid position along the last target axis.
    if isinstance(shift, int):
        seq_len = targets.size(-1)
        if not 0 <= shift < seq_len:
            raise ValueError(f"Shift must be in the range [0, {seq_len}). Got {shift}.")

    # With vocab parallelism, `c` holds only the [start, stop) slice of the vocab.
    if vocab_parallel_options is not None:
        local_vocab = vocab_parallel_options.stop - vocab_parallel_options.start
        if c.size(0) != local_vocab:
            raise ValueError(f"Expected c.size(0) to be {local_vocab}, got {c.size(0)}.")

    if bias is not None and bias.size(0) != c.size(0):
        raise ValueError(
            f"Bias has a different number of elements than c. {bias.size(0)} vs. {c.size(0)}."
        )

    if impl_name in CCEPresets.names:
        if platform.system() == "Darwin":
            raise RuntimeError(
                "CCE does not support MacOS. Please use torch_compile when running on MacOS instead."
            )

        if is_triton_3_2():
            # Issued directly from this frame so stacklevel=2 points at our caller.
            warnings.warn(
                "There is a known issue with CCE and Triton 3.2 (the version that ships with PyTorch 2.6)"
                " that can result in incorrect gradients. If possible, please verify that you"
                " are not impacted by this bug by trying a newer triton version (i.e. by installing PyTorch>2.6).",
                CCEWarning,
                stacklevel=2,
            )

        preset_overrides = CCEPreset(
            filter_eps=filter_eps,
            accum_e_fp32=accum_e_fp32,
            accum_c_fp32=accum_c_fp32,
            filter_e_grad=filter_e_grad,
            filter_c_grad=filter_c_grad,
        )
        cce_opts = CCEPresets.build_for_impl(impl_name, preset_overrides)

        # Non-None here: on macOS we already raised above.
        assert cce_linear_cross_entropy is not None
        loss, lse = cce_linear_cross_entropy(
            e,
            c,
            targets,
            bias,
            ignore_index,
            softcap,
            reduction,
            shift,
            **cce_opts,
            vocab_parallel_options=vocab_parallel_options,
            return_lse=return_lse,
        )
    elif impl_name == "torch_compile":
        loss, lse = torch_compile_linear_cross_entropy(
            e,
            c,
            targets,
            bias,
            ignore_index,
            softcap,
            reduction,
            shift,
            vocab_parallel_options=vocab_parallel_options,
            return_lse=return_lse,
        )
    else:
        raise NotImplementedError(f"{impl_name} is not implemented.")

    if not return_lse:
        return loss

    assert lse is not None
    return loss, lse
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
class LinearCrossEntropy(nn.Module):
    """Module wrapper around :func:`linear_cross_entropy`.

    All loss options are fixed at construction time; :meth:`forward` only takes
    the tensors and forwards everything to :func:`linear_cross_entropy`. See that
    function for the meaning of each option.

    :param vocab_parallel_options: Optional vocab-parallel configuration forwarded
        to :func:`linear_cross_entropy`. ``None`` (the default) disables vocab
        parallelism, which matches the module's previous behaviour.
    """

    def __init__(
        self,
        ignore_index: int = IGNORE_INDEX,
        softcap: float | None = None,
        reduction: str = "mean",
        shift: bool | int = 0,
        filter_eps: float | str | None = "auto",
        accum_e_fp32: bool = False,
        accum_c_fp32: bool = False,
        filter_e_grad: bool = True,
        filter_c_grad: bool = True,
        impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
        return_lse: bool = False,
        vocab_parallel_options: VocabParallelOptions | None = None,
    ):
        super().__init__()
        self.ignore_index = ignore_index
        self.softcap = softcap
        self.reduction = reduction
        self.filter_eps = filter_eps
        self.shift = shift

        self.accum_e_fp32 = accum_e_fp32
        self.accum_c_fp32 = accum_c_fp32

        self.filter_e_grad = filter_e_grad
        self.filter_c_grad = filter_c_grad

        self.impl = impl
        self.return_lse = return_lse

        # Previously the module had no way to pass this through, so vocab-parallel
        # losses were only reachable via the functional API.
        self.vocab_parallel_options = vocab_parallel_options

    def forward(
        self,
        e: torch.Tensor,
        c: torch.Tensor,
        targets: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Compute the linear cross-entropy loss (and LSE if ``return_lse``).

        :param e: Embeddings passed to :func:`linear_cross_entropy`.
        :param c: Classifier weights passed to :func:`linear_cross_entropy`.
        :param targets: Target indices passed to :func:`linear_cross_entropy`.
        :param bias: Optional classifier bias.
        :returns: The loss, or ``(loss, lse)`` when the module was built with
            ``return_lse=True``.
        """
        return linear_cross_entropy(
            e,
            c,
            targets,
            bias=bias,
            ignore_index=self.ignore_index,
            softcap=self.softcap,
            reduction=self.reduction,
            shift=self.shift,
            filter_eps=self.filter_eps,
            accum_e_fp32=self.accum_e_fp32,
            accum_c_fp32=self.accum_c_fp32,
            filter_e_grad=self.filter_e_grad,
            filter_c_grad=self.filter_c_grad,
            impl=self.impl,
            return_lse=self.return_lse,
            vocab_parallel_options=self.vocab_parallel_options,
        )
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from grouped_gemm import backend
|
|
5
|
+
from torch.autograd import Function
|
|
6
|
+
|
|
7
|
+
from d9d.core.autograd import GLOBAL_GRAD_CONTEXT, GradDirection
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class GroupedGemm(Function):
    """
    Autograd function for a grouped GEMM whose backward pass can skip the input and/or
    weight gradient.

    Each gradient is produced only if autograd requests it AND
    ``GLOBAL_GRAD_CONTEXT.check_direction`` approves the direction supplied in ``forward``.
    ``batch_sizes``, the direction tags and ``trans_b`` never receive gradients.
    """

    @staticmethod
    def forward(
        ctx: Any,
        a: torch.Tensor,
        b: torch.Tensor,
        batch_sizes: torch.Tensor,
        a_grad_direction: GradDirection | None,
        b_grad_direction: GradDirection | None,
        trans_b: bool
    ) -> torch.Tensor:
        # Both operands and the group sizes are needed again in backward.
        ctx.save_for_backward(a, b, batch_sizes)
        ctx.a_grad_direction, ctx.b_grad_direction = a_grad_direction, b_grad_direction
        ctx.trans_b = trans_b
        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)

    @staticmethod
    def backward(
        ctx: Any, grad: torch.Tensor
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, None, None, None, None]:
        grad_out = grad.contiguous()
        lhs_saved, rhs_saved, group_sizes = ctx.saved_tensors
        transposed = ctx.trans_b

        # Both directions are always queried, in a-then-b order.
        allow_a = GLOBAL_GRAD_CONTEXT.check_direction(ctx.a_grad_direction)
        allow_b = GLOBAL_GRAD_CONTEXT.check_direction(ctx.b_grad_direction)
        wants_a, wants_b = ctx.needs_input_grad[:2]

        # d(a): grad_out times b with b's transposition flipped relative to forward.
        lhs_grad = (
            backend.gmm(grad_out, rhs_saved, group_sizes, trans_a=False, trans_b=not transposed)
            if wants_a and allow_a
            else None
        )

        # d(b): (grad_out^T @ a) when b was transposed in forward, else (a^T @ grad_out).
        if wants_b and allow_b:
            first, second = (grad_out, lhs_saved) if transposed else (lhs_saved, grad_out)
            rhs_grad = backend.gmm(first, second, group_sizes, trans_a=True, trans_b=False)
        else:
            rhs_grad = None

        return lhs_grad, rhs_grad, None, None, None, None
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def gmm(
    a: torch.Tensor,
    b: torch.Tensor,
    batch_sizes: torch.Tensor,
    a_grad_direction: GradDirection | None,
    b_grad_direction: GradDirection | None,
    trans_b: bool = False,
) -> torch.Tensor:
    """
    Grouped GEMM (Generalized Matrix Multiplication) with per-operand gradient control.

    Thin entry point over ``GroupedGemm.apply``.

    Args:
        a: Left-hand side tensor.
        b: Right-hand side tensor.
        batch_sizes: Sizes of the groups along ``a``'s leading dimension.
        a_grad_direction: Gradient category tag for ``a`` (e.g. ``GradDirection.inputs``).
        b_grad_direction: Gradient category tag for ``b`` (e.g. ``GradDirection.weight``).
        trans_b: Whether ``b`` is transposed inside the multiplication.

    Returns:
        The grouped matrix-multiplication result.
    """
    directions = (a_grad_direction, b_grad_direction)
    return GroupedGemm.apply(a, b, batch_sizes, *directions, trans_b)
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/fusions/fused_indices_converter.py
|
|
2
|
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
|
3
|
+
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import triton
|
|
8
|
+
import triton.language as tl
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Assign a block to a row([1,topk]), generate a local routing map([1,num_of_local_experts])
@triton.jit
def _indices_to_multihot_kernel(
    indices_ptr,
    probs_in_indices_ptr,
    multihot_indices_ptr,  # bool
    probs_in_multihot_ptr,
    position_map_ptr,
    num_of_local_experts: tl.constexpr,
    num_of_local_experts_next_power_of_2: tl.constexpr,
    topk: tl.constexpr,
    topk_next_power_of_2: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # not referenced in the kernel body
):
    '''
    Triton kernel for converting indices to multihot representation.

    One program instance handles one token row (program_id(0) is the row index).

    Input:
        indices: [num_of_tokens, topk]
        probs_in_indices: [num_of_tokens, topk]
    Output:
        multihot_indices: [num_of_tokens, num_of_local_experts]
        probs_in_multihot: [num_of_tokens, num_of_local_experts]
        position_map: [num_of_tokens, num_of_local_experts]
            For each selected expert, the topk slot it came from; -1 elsewhere.
            Kept so the backward pass can scatter values back to [num_of_tokens, topk].

    Index entries equal to -1, or >= num_of_local_experts, are not written.

    Assume that topk = 2 , num_of_local_experts = 4, num_of_tokens = 2,
    then the kernel can process the following conversion:

    Input Example:
        indices = [
                [0, 1],
                [1, 2]
            ]
        probs_in_indices = [
                [0.1, 0.2],
                [0.3, 0.4]
            ]
    Output Example:
        multihot_indices = [
                [1, 1, 0, 0],
                [0, 1, 1, 0]
            ]
        probs_in_multihot = [
                [0.1, 0.2, 0.0, 0.0],
                [0.0, 0.3, 0.4, 0.0]
            ]
        position_map = [
                [0, 1, -1, -1],
                [-1, 0, 1, -1]
            ]
    '''
    # Prepare the [0, topk) row
    # (arange needs a power-of-2 length; lanes >= topk are marked -1 and masked off)
    topk_row = tl.arange(0, topk_next_power_of_2)
    topk_row = tl.where(topk_row < topk, topk_row, -1)
    topk_row_mask = topk_row != -1
    # Prepare the [0, num_of_local_experts) row
    num_exp_row = tl.arange(0, num_of_local_experts_next_power_of_2)
    num_exp_row = tl.where(num_exp_row < num_of_local_experts, num_exp_row, -1)
    num_exp_row_mask = num_exp_row != -1

    # Load a [1, topk] row from the indices buffer
    row_idx = tl.program_id(0)
    indices_row = tl.load(indices_ptr + row_idx * topk + topk_row, mask=topk_row_mask)
    # Padding lanes get the -1 sentinel so they are treated as invalid below.
    indices_row = tl.where(topk_row_mask, indices_row, -1)
    probs_row = tl.load(probs_in_indices_ptr + row_idx * topk + topk_row, mask=topk_row_mask)

    # Get the position of the each index in the indices_row, which is saved for backwards
    position_row = tl.where(indices_row != -1, topk_row, -1)
    # Mask of the valid indices
    # NOTE(review): presumably indices >= num_of_local_experts refer to non-local
    # experts. Negative values other than -1 are NOT masked and would be written
    # outside this row. Confirm that callers only use -1 as the "no expert" sentinel.
    mask = (indices_row != -1) & (indices_row < num_of_local_experts)

    row_idx_offset = row_idx * num_of_local_experts
    # Store to initialize
    tl.store(multihot_indices_ptr + row_idx_offset + num_exp_row, 0, mask=num_exp_row_mask)
    tl.store(probs_in_multihot_ptr + row_idx_offset + num_exp_row, 0, mask=num_exp_row_mask)
    tl.store(position_map_ptr + row_idx_offset + num_exp_row, -1, mask=num_exp_row_mask)
    # Use barrier to make sure the initialization is done
    # (the scatter stores below must land after the zero/-1 fill of the same row)
    tl.debug_barrier()
    # Store the indices and probs_in_indices
    tl.store(multihot_indices_ptr + row_idx_offset + indices_row, 1, mask)
    tl.store(probs_in_multihot_ptr + row_idx_offset + indices_row, probs_row, mask)
    # Store the position of the position_row for backwards
    tl.store(position_map_ptr + row_idx_offset + indices_row, position_row, mask)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# Assign a block to a row([1,topk]), generate a probs_indices([1,topk])
@triton.jit
def _multihot_to_indices_kernel(
    probs_in_multihot_ptr,
    position_map_ptr,
    probs_indices_ptr,
    num_of_local_experts: tl.constexpr,
    num_of_local_experts_next_power_of_2: tl.constexpr,
    topk: tl.constexpr,
    topk_next_power_of_2: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # not referenced in the kernel body
):
    '''
    Triton kernel for converting multihot representation to indices.

    This is the inverse scatter of _indices_to_multihot_kernel. Each value in a
    [num_of_local_experts] row is written to the topk slot recorded in
    position_map. Slots that receive no value stay 0. One program instance
    handles one token row.
    Presumably this is used by IndicesToMultihot's backward pass, which is not
    fully visible here.

    Input:
        probs_in_multihot: [num_of_tokens, num_of_local_experts]
        position_map: [num_of_tokens, num_of_local_experts]
            topk slot for each selected expert, -1 for unselected experts
    Output:
        probs_indices: [num_of_tokens, topk]

    Assume that topk = 2 , num_of_local_experts = 4, num_of_tokens = 2,
    then the kernel can process the following conversion:

    Input Example:
        probs_in_multihot = [
                [0.7, 0.8, 0.0, 0.0],
                [0.0, 0.1, 0.9, 0.0]
            ]
        position_map = [
                [0, 1, -1, -1],
                [-1, 0, 1, -1]
            ]
    Output Example:
        probs_indices = [
                [0.7, 0.8],
                [0.1, 0.9]
            ]
    '''
    # Prepare the [0, topk) row
    topk_row = tl.arange(0, topk_next_power_of_2)
    topk_row = tl.where(topk_row < topk, topk_row, -1)
    topk_row_mask = topk_row != -1
    # Prepare the [0, num_of_local_experts) row
    num_exp_row = tl.arange(0, num_of_local_experts_next_power_of_2)
    num_exp_row = tl.where(num_exp_row < num_of_local_experts, num_exp_row, -1)
    num_exp_row_mask = num_exp_row != -1

    # Load a [1, num_of_local_experts] row from the local routing map
    row_idx = tl.program_id(0)
    ptr_offset = row_idx * num_of_local_experts + num_exp_row
    probs_in_multihot_row = tl.load(probs_in_multihot_ptr + ptr_offset, mask=num_exp_row_mask)

    # Get the original position of the valid value in the the indices
    position_map_row = tl.load(position_map_ptr + ptr_offset, mask=num_exp_row_mask)
    # Padding lanes are forced to -1 so they are excluded from the scatter.
    position_map_row = tl.where(num_exp_row_mask, position_map_row, -1)
    mask = position_map_row != -1

    # Store to initialize
    tl.store(probs_indices_ptr + row_idx * topk + topk_row, 0, mask=topk_row_mask)
    # Use barrier to make sure the initialization is done
    tl.debug_barrier()
    # Restore the indices and probs_indices
    tl.store(probs_indices_ptr + row_idx * topk + position_map_row, probs_in_multihot_row, mask)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class IndicesToMultihot(torch.autograd.Function):
    """Convert moe topk indices to multihot representation.

    This class implements a custom forward and backward propagation
    operation for efficiently converting indices to multihot
    representation using Triton kernels.
    It is an experimental feature and may change in future versions.
    """

    @staticmethod
    def _next_power_of_2(value):
        """Return the smallest power of two that is >= ``value``.

        The Triton kernels build their rows with ``tl.arange``, whose extent
        must be a power of two, so they receive padded widths and mask off the
        tail. Integer bit arithmetic avoids the float rounding of
        ``2 ** ceil(log2(value))``.

        Raises:
            ValueError: if ``value`` is smaller than 1.
        """
        value = int(value)
        if value < 1:
            raise ValueError(f"expected a positive size, got {value}")
        return 1 << (value - 1).bit_length()

    @staticmethod
    def forward(ctx, indices, probs_indices, num_of_local_experts):
        """Forward function for IndicesToMultihot.

        Convert indices to multihot representation.

        Args:
            indices: [num_of_tokens, topk]
            probs_indices: [num_of_tokens, topk]
            num_of_local_experts: int

        Returns:
            multihot_indices: [num_of_tokens, num_of_local_experts], bool
            probs_in_multihot: [num_of_tokens, num_of_local_experts], same
                dtype as ``probs_indices``
        """
        num_of_tokens = indices.shape[0]
        assert (
            indices.shape == probs_indices.shape
        ), "indices and probs_indices must have the same shape"
        topk = indices.shape[1]

        # The kernels address rows as ``row * width + col`` (the backward
        # kernel visibly does; the forward kernel presumably does as well), so
        # strided views must be densified first. No-op for dense tensors.
        indices = indices.contiguous()
        probs_indices = probs_indices.contiguous()

        # Allocate on the inputs' device rather than ``"cuda"`` (the *current*
        # CUDA device), which can differ in multi-GPU processes.
        device = indices.device
        out_shape = (num_of_tokens, num_of_local_experts)
        multihot_indices = torch.empty(out_shape, dtype=torch.bool, device=device)
        probs_in_multihot = torch.empty(
            out_shape, dtype=probs_indices.dtype, device=device
        )
        # Records where each multihot entry came from; reused by backward.
        position_map = torch.empty(out_shape, dtype=torch.int32, device=device)

        # Padded widths for tl.arange inside the kernel.
        topk_next_power_of_2 = IndicesToMultihot._next_power_of_2(topk)
        num_of_local_experts_next_power_of_2 = IndicesToMultihot._next_power_of_2(
            num_of_local_experts
        )
        # One program per token row.
        grid = (num_of_tokens,)
        _indices_to_multihot_kernel[grid](
            indices,
            probs_indices,
            multihot_indices,
            probs_in_multihot,
            position_map,
            num_of_local_experts,
            num_of_local_experts_next_power_of_2,
            topk,
            topk_next_power_of_2,
            BLOCK_SIZE=32,  # use only 1 warp per block
            num_warps=1,
        )

        ctx.save_for_backward(position_map)
        ctx.num_of_tokens = num_of_tokens
        ctx.num_of_local_experts = num_of_local_experts
        ctx.topk = topk
        return multihot_indices, probs_in_multihot

    @staticmethod
    def backward(ctx, grad_multihot_indices, grad_probs_in_multihot):
        """Backward function for IndicesToMultihot.

        Convert multihot probs representation to indices.
        The gradient w.r.t. ``multihot_indices`` is ignored, and neither
        ``indices`` nor ``num_of_local_experts`` receives a gradient.

        Args:
            grad_multihot_indices: [num_of_tokens, num_of_local_experts]
            grad_probs_in_multihot: [num_of_tokens, num_of_local_experts]

        Returns:
            One gradient per forward input:
            (None, grad_probs_indices [num_of_tokens, topk], None)
        """
        (position_map,) = ctx.saved_tensors
        num_of_tokens = ctx.num_of_tokens
        num_of_local_experts = ctx.num_of_local_experts
        topk = ctx.topk

        # Allocate on the gradient's device rather than the current CUDA device.
        grad_probs_indices = torch.empty(
            (num_of_tokens, topk),
            dtype=grad_probs_in_multihot.dtype,
            device=grad_probs_in_multihot.device,
        )
        # Padded widths for tl.arange inside the kernel.
        topk_next_power_of_2 = IndicesToMultihot._next_power_of_2(topk)
        num_of_local_experts_next_power_of_2 = IndicesToMultihot._next_power_of_2(
            num_of_local_experts
        )

        grid = (num_of_tokens,)
        _multihot_to_indices_kernel[grid](
            # if the grad_probs_in_multihot is all-one/all-zero,
            # overlapping stride will cause error without contiguous()
            grad_probs_in_multihot.contiguous(),
            position_map,
            grad_probs_indices,
            num_of_local_experts,
            num_of_local_experts_next_power_of_2,
            topk,
            topk_next_power_of_2,
            BLOCK_SIZE=32,  # use only 1 warp per block
            num_warps=1,
        )
        # Exactly one entry per forward input (indices, probs_indices,
        # num_of_local_experts).
        return None, grad_probs_indices, None
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def fused_indices_to_multihot(indices, probs_indices, num_of_local_experts):
    """Scatter per-token top-k MoE routing into dense per-expert tensors.

    Autograd-aware entry point for ``IndicesToMultihot``. The backward pass
    produces a gradient for ``probs_indices`` only.
    This function is an experimental feature and may change in future versions.

    Args:
        indices: [num_of_tokens, topk] top-k indices for each token.
        probs_indices: [num_of_tokens, topk] probabilities aligned with ``indices``.
        num_of_local_experts: width of the multihot outputs.

    Returns:
        A ``(multihot_indices, probs_in_multihot)`` pair, each of shape
        [num_of_tokens, num_of_local_experts].
    """
    multihot_indices, probs_in_multihot = IndicesToMultihot.apply(
        indices, probs_indices, num_of_local_experts
    )
    return multihot_indices, probs_in_multihot
|