liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +307 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +63 -0
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +383 -114
- liger_kernel/ops/dyt.py +160 -0
- liger_kernel/ops/experimental/embedding.py +141 -0
- liger_kernel/ops/experimental/mm_int8int2.py +349 -0
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
- liger_kernel/ops/fused_linear_jsd.py +228 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +66 -64
- liger_kernel/ops/group_norm.py +306 -0
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +201 -0
- liger_kernel/ops/kl_div.py +262 -0
- liger_kernel/ops/layer_norm.py +320 -0
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +484 -88
- liger_kernel/ops/rope.py +122 -117
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +68 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +82 -3
- liger_kernel/transformers/__init__.py +218 -6
- liger_kernel/transformers/auto_model.py +38 -0
- liger_kernel/transformers/cross_entropy.py +52 -7
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +26 -0
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +301 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
- liger_kernel/transformers/fused_linear_jsd.py +95 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +6 -7
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +70 -0
- liger_kernel/transformers/kl_div.py +12 -0
- liger_kernel/transformers/layer_norm.py +24 -0
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +261 -0
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +332 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +221 -41
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +145 -0
- liger_kernel/transformers/model/mixtral.py +293 -0
- liger_kernel/transformers/model/mllama.py +269 -0
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +433 -0
- liger_kernel/transformers/model/phi3.py +120 -0
- liger_kernel/transformers/model/qwen2.py +259 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +159 -0
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2816 -21
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +75 -5
- liger_kernel/transformers/rope.py +47 -3
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +62 -6
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -5
- liger_kernel/utils.py +96 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py
CHANGED
```diff
@@ -4,23 +4,25 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
-
-if compare_version("triton", operator.ge, "3.0.0"):
-    from triton.language.extra.libdevice import tanh
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
+
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
 else:
     from triton.language.math import tanh
 
 
 @triton.jit
-def _geglu_tanh_forward_kernel(
-    a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
-    program_id = tl.program_id(0)
+def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
+    program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
     a += program_id * stride
@@ -39,15 +41,13 @@ def _geglu_tanh_forward_kernel(
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
     tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
-    c_row = geglu_a * b_row
+    c_row = geglu_a.cast(b_row.dtype) * b_row
     tl.store(c + col_offsets, c_row, mask=mask)
 
 
 @triton.jit
-def _geglu_tanh_backward_kernel(
-    dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
-    program_id = tl.program_id(0)
+def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
+    program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
     dc += program_id * stride
@@ -75,66 +75,68 @@ def _geglu_tanh_backward_kernel(
     # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
     term1 = 0.5 * (1 + tanh_result)
     tanh_sq = tanh_result * tanh_result
-    term2 = (
-        0.5
-        * a_row
-        * (1 - tanh_sq)
-        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
-    )
+    term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
     da_row = dc_row * b_row * (term1 + term2)
 
     tl.store(a + col_offsets, da_row, mask=mask)
     tl.store(b + col_offsets, db_row, mask=mask)
 
 
+def geglu_forward(a, b):
+    ori_shape = a.shape
+
+    n_cols = ori_shape[-1]
+    a = a.view(-1, n_cols)
+    b = b.view(-1, n_cols)
+    c = torch.empty_like(a)
+    n_rows = a.shape[0]
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    _geglu_tanh_forward_kernel[(n_rows,)](
+        a,
+        b,
+        c,
+        c.stride(-2),
+        n_cols=n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return a, b, c.view(*ori_shape)
+
+
+def geglu_backward(a, b, dc):
+    ori_shape = dc.shape
+    n_cols = ori_shape[-1]
+    dc = dc.view(-1, n_cols)
+    n_rows = dc.shape[0]
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    _geglu_tanh_backward_kernel[(n_rows,)](
+        dc,
+        a,
+        b,
+        dc.stride(-2),
+        n_cols=n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+
+    return a.view(*ori_shape), b.view(*ori_shape)
+
+
 class LigerGELUMulFunction(torch.autograd.Function):
     @staticmethod
     @ensure_contiguous
     def forward(ctx, a, b):
-        ori_shape = a.shape
-
-        n_cols = ori_shape[-1]
-        a = a.view(-1, n_cols)
-        b = b.view(-1, n_cols)
-        c = torch.zeros_like(a)
-        n_rows = a.shape[0]
-
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
-        _geglu_tanh_forward_kernel[(n_rows,)](
-            a,
-            b,
-            c,
-            c.stride(-2),
-            n_cols=n_cols,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
-        )
-
+        a, b, c = geglu_forward(a, b)
         ctx.save_for_backward(a, b)
-
-        return c.view(*ori_shape)
+        return c
 
     @staticmethod
     @ensure_contiguous
     def backward(ctx, dc):
-
-        ori_shape = dc.shape
-        n_cols = ori_shape[-1]
-        dc = dc.view(-1, n_cols)
         a, b = ctx.saved_tensors
-        n_rows = dc.shape[0]
-
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
-        _geglu_tanh_backward_kernel[(n_rows,)](
-            dc,
-            a,
-            b,
-            dc.stride(-2),
-            n_cols=n_cols,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
-        )
-
-        return a.view(*ori_shape), b.view(*ori_shape)
+        a, b = geglu_backward(a, b, dc)
+        return a, b
```
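For orientation: `_geglu_tanh_forward_kernel` computes GEGLU, `c = GELU(a) * b`, using the tanh approximation `GELU(a) ≈ 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a**3)))`, and the backward kernel applies the corresponding analytic derivative (`term1 + term2` above). A minimal eager-PyTorch sketch of the same computation — not part of the package, just a reference one might use to sanity-check the fused path, assuming a GPU with a working Triton install:

```python
import torch
import torch.nn.functional as F


def geglu_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Same math as _geglu_tanh_forward_kernel: GELU(a) * b with PyTorch's
    # built-in tanh approximation of GELU.
    return F.gelu(a, approximate="tanh") * b


# Illustrative parity check; shapes and tolerances are assumptions, not the
# package's own test settings.
if torch.cuda.is_available():
    from liger_kernel.ops.geglu import LigerGELUMulFunction

    a = torch.randn(8, 4096, device="cuda", requires_grad=True)
    b = torch.randn(8, 4096, device="cuda", requires_grad=True)
    torch.testing.assert_close(
        LigerGELUMulFunction.apply(a, b),
        geglu_reference(a, b),
        rtol=1e-4,
        atol=1e-4,
    )
```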
liger_kernel/ops/group_norm.py
ADDED
```diff
@@ -0,0 +1,306 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
+
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+MAX_FUSED_SIZE = 65536
+
+
+@triton.jit
+def _group_norm_forward_kernel(
+    Y_ptr,  # pointer to output, shape (n_rows, n_groups, hidden_size)
+    Y_row_stride,  # stride of each row in output
+    Y_col_stride,  # stride of each column in output
+    X_ptr,  # pointer to input, shape (n_rows, n_groups, hidden_size)
+    X_row_stride,  # stride of each row in input
+    X_col_stride,  # stride of each column in input
+    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
+    Mean_row_stride,  # stride of each row in mean
+    Mean_col_stride,  # stride of each column in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
+    RSTD_row_stride,  # stride of each row in rstd
+    RSTD_col_stride,  # stride of each column in rstd
+    W_ptr,  # pointer to W
+    B_ptr,  # pointer to B
+    hidden_size,  # hidden size of X
+    channels_per_group,  # the number of channels per group
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    References:
+    https://nn.labml.ai/normalization/group_norm/index.html
+    """
+    batch_idx = tl.program_id(0)
+    group_idx = tl.program_id(1)
+
+    X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
+    Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride
+
+    block_range = tl.arange(0, BLOCK_SIZE)
+
+    # Compute mean and variance using the online algorithm
+    s = 0.0
+    squared_sum = 0.0
+    for i in tl.range(0, hidden_size, BLOCK_SIZE):
+        hidden_size_offsets = i + block_range
+        mask = hidden_size_offsets < hidden_size
+        X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
+        s += tl.sum(X)
+        # X**2
+        squared_sum += tl.sum(X * X)
+
+    m = s / hidden_size
+
+    # variance = E[X**2] - E[X]**2
+    variance = (squared_sum / hidden_size) - (m * m)
+
+    # 1/std
+    rstd = rsqrt(variance + eps)
+
+    # Normalize
+    hidden_size_per_channel = hidden_size // channels_per_group
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
+        W = tl.load(W_ptr + channel_idx)
+        B = tl.load(B_ptr + channel_idx)
+        for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size_per_channel
+            X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
+            Y = (X - m) * rstd * W + B
+            tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
+
+        X_ptr += hidden_size_per_channel
+        Y_ptr += hidden_size_per_channel
+
+    tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
+    tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
+
+
+@triton.jit
+def _group_norm_backward_kernel(
+    X_ptr,  # pointer to input, shape (n_rows, n_channels, hidden_size)
+    X_row_stride,  # stride of each row in input
+    X_col_stride,  # stride of each column in input
+    W_ptr,  # pointer to weights, shape (n_channels)
+    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
+    Mean_ptr_row_stride,  # stride of each column in mean
+    Mean_ptr_col_stride,  # stride of each column in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
+    DX_ptr,  # pointer to input grad, shape (n_rows, n_groups, hidden_size)
+    DW_ptr,  # pointer to weights grad, shape (n_channels)
+    DB_ptr,  # pointer to bias grad, shape (n_channels)
+    UPSTREAM_ptr,  # pointer to output grad, shape (n_rows, n_channels, hidden_size)
+    hidden_size: tl.constexpr,  # hidden size
+    channels_per_group: tl.constexpr,  # number of groups in group norm
+    BLOCK_SIZE: tl.constexpr,
+    dtype: tl.constexpr,
+):
+    """
+    References:
+    https://nn.labml.ai/normalization/group_norm/index.html
+    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
+
+    The backprop equations are the same for group_norm and layer_norm
+    the only difference here is that we load the Mean, Rstd corresponding to the
+    group we're computing gradients for and the mean and rstd are computed over n-channels
+    so the total number of elements we compute the mean over is num_channels_per_group * hidden_size
+
+    We also need to load the Weights corresponding to the current channel to compute the gradients.
+    """
+    batch_idx = tl.program_id(0)
+    group_idx = tl.program_id(1)
+
+    # Move the pointers to the correct batch
+    X_ptr += batch_idx * X_row_stride
+    DX_ptr += batch_idx * X_row_stride
+    UPSTREAM_ptr += batch_idx * X_row_stride
+
+    # Mean and rstd are the same shape so have the same strides
+    mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+    rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+
+    c1 = 0.0
+    c2 = 0.0
+    block_range = tl.arange(0, BLOCK_SIZE)
+
+    # We need to compute the sum terms of the backprop equations across all channels in the group
+    for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
+        dW = 0.0
+        dB = 0.0
+        # Move the pointers to the correct channel
+        W = tl.load(W_ptr + channel_idx)
+        for i in tl.range(0, hidden_size, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size
+            X = tl.load(
+                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+            UPSTREAM_grad = tl.load(
+                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+
+            x_hat = (X - mean) * rstd
+            dW += tl.sum(UPSTREAM_grad * x_hat)
+            dB += tl.sum(UPSTREAM_grad)
+
+            wdy = W * UPSTREAM_grad
+            c1 += tl.sum(x_hat * wdy)
+            c2 += tl.sum(wdy)
+
+        # Need to ensure additions to the same channel are atomic
+        tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
+        tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))
+
+    N = hidden_size * channels_per_group
+    c1 = c1 / N
+    c2 = c2 / N
+
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
+        # Move the pointers to the correct channel
+        W = tl.load(W_ptr + channel_idx)
+        for i in range(0, hidden_size, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size
+            X = tl.load(
+                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+            UPSTREAM_grad = tl.load(
+                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+
+            x_hat = (X - mean) * rstd
+            wdy = W * UPSTREAM_grad
+            dx = (wdy - (x_hat * c1 + c2)) * rstd
+            tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
+
+
+def group_norm_forward(X, num_channels, num_groups, W, B, eps):
+    shape = X.shape
+    batch_size = shape[0]
+    channels_per_group = num_channels // num_groups
+    # Reshape X so that the mean and std are computed across the groups
+    X = X.view(batch_size, num_groups, -1).contiguous()
+    hidden_size = X.shape[-1]
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
+    Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
+    Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
+    RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
+
+    _group_norm_forward_kernel[(batch_size, num_groups)](
+        Y,
+        Y.stride(0),
+        Y.stride(1),
+        X,
+        X.stride(0),
+        X.stride(1),
+        Mean,
+        Mean.stride(0),
+        Mean.stride(1),
+        RSTD,
+        RSTD.stride(0),
+        RSTD.stride(1),
+        W,
+        B,
+        hidden_size,
+        channels_per_group,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    # Return tensors in the original shape
+    return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE
+
+
+def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
+    shape = dY.shape
+    batch_size = shape[0]
+    hidden_size = dY.shape[-1]
+    channels_per_group = num_channels // num_groups
+    dY = dY.view(batch_size, num_groups, -1)
+    DX = torch.empty(
+        (batch_size, num_groups, hidden_size * channels_per_group),
+        dtype=X.dtype,
+        device=X.device,
+    )
+    DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
+    DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
+    triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
+    _group_norm_backward_kernel[(batch_size, num_groups)](
+        X,
+        X.stride(0),
+        X.stride(1),
+        W,
+        Mean,
+        Mean.stride(0),
+        Mean.stride(1),
+        RSTD,
+        DX,
+        DW,
+        DB,
+        dY,
+        hidden_size,
+        channels_per_group,
+        BLOCK_SIZE=BLOCK_SIZE,
+        dtype=triton_dtype,
+    )
+
+    # Return tensors in the original shape
+    return DX.view(*shape), DW, DB
+
+
+class LigerGroupNormFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        X,
+        affine_scaling_weight,
+        affine_shifting_bias,
+        num_channels,
+        num_groups,
+        eps,
+    ):
+        Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
+            X,
+            num_channels,
+            num_groups,
+            affine_scaling_weight,
+            affine_shifting_bias,
+            eps,
+        )
+        ctx.num_channels = num_channels
+        ctx.num_groups = num_groups
+        ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        X, W, B, Mean, RSTD = ctx.saved_tensors
+        DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
+        return DX, DW, DB, None, None, None
```
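The backward kernel above reuses the layer-norm backprop identity per group. In the kernel's own notation, with `x_hat = (X - mean) * rstd` and `wdy = W * UPSTREAM_grad`, the input gradient is `dx = (wdy - (x_hat * c1 + c2)) * rstd`, where `c1 = sum(x_hat * wdy) / N`, `c2 = sum(wdy) / N`, and `N = channels_per_group * hidden_size`; the first loop accumulates `c1`/`c2` (plus the per-channel `dW`/`dB`), and the second loop applies the formula. A minimal parity sketch against PyTorch's own group norm — not from the package; device, shapes, and tolerances are assumptions:

```python
import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    from liger_kernel.ops.group_norm import LigerGroupNormFunction

    batch, num_channels, num_groups, hidden = 2, 8, 4, 256
    x = torch.randn(batch, num_channels, hidden, device="cuda", requires_grad=True)
    w = torch.randn(num_channels, device="cuda", requires_grad=True)
    b = torch.randn(num_channels, device="cuda", requires_grad=True)

    # Argument order follows LigerGroupNormFunction.forward defined above.
    y_liger = LigerGroupNormFunction.apply(x, w, b, num_channels, num_groups, 1e-6)
    y_torch = F.group_norm(x, num_groups, weight=w, bias=b, eps=1e-6)
    torch.testing.assert_close(y_liger, y_torch, rtol=1e-4, atol=1e-4)

    # Gradients flow through the custom backward kernel.
    y_liger.sum().backward()
```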