liger-kernel-nightly 0.4.0.dev20241107194223__tar.gz → 0.4.0.dev20241108173850__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- {liger_kernel_nightly-0.4.0.dev20241107194223/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.0.dev20241108173850}/PKG-INFO +2 -2
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/README.md +1 -1
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/pyproject.toml +1 -1
- liger_kernel_nightly-0.4.0.dev20241108173850/src/liger_kernel/ops/group_norm.py +322 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/functional.py +2 -0
- liger_kernel_nightly-0.4.0.dev20241108173850/src/liger_kernel/transformers/group_norm.py +56 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850/src/liger_kernel_nightly.egg-info}/PKG-INFO +2 -2
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel_nightly.egg-info/SOURCES.txt +2 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/setup.cfg +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
{liger_kernel_nightly-0.4.0.dev20241107194223/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.0.dev20241108173850}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.4.0.dev20241107194223
+Version: 0.4.0.dev20241108173850
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -332,7 +332,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
 - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
 - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
 - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
-- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the
+- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.


 ### Experimental Kernels
{liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/README.md
@@ -285,7 +285,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
 - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
 - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
 - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
-- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the
+- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.


 ### Experimental Kernels
{liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "liger_kernel_nightly"
-version = "0.4.0.dev20241107194223"
+version = "0.4.0.dev20241108173850"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
liger_kernel_nightly-0.4.0.dev20241108173850/src/liger_kernel/ops/group_norm.py
@@ -0,0 +1,322 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import compare_version, ensure_contiguous
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+MAX_FUSED_SIZE = 65536
+
+
+@triton.jit
+def _group_norm_forward_kernel(
+    Y_ptr,  # pointer to output, shape (n_rows, n_groups, hidden_size)
+    Y_row_stride,  # stride of each row in output
+    Y_col_stride,  # stride of each column in output
+    X_ptr,  # pointer to input, shape (n_rows, n_groups, hidden_size)
+    X_row_stride,  # stride of each row in input
+    X_col_stride,  # stride of each column in input
+    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
+    Mean_row_stride,  # stride of each row in mean
+    Mean_col_stride,  # stride of each column in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
+    RSTD_row_stride,  # stride of each row in rstd
+    RSTD_col_stride,  # stride of each column in rstd
+    W_ptr,  # pointer to W
+    B_ptr,  # pointer to B
+    hidden_size,  # hidden size of X
+    channels_per_group,  # the number of channels per group
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    References:
+    https://nn.labml.ai/normalization/group_norm/index.html
+    """
+    batch_idx = tl.program_id(0)
+    group_idx = tl.program_id(1)
+
+    X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
+    Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride
+
+    block_range = tl.arange(0, BLOCK_SIZE)
+
+    # Compute mean and variance using the online algorithm
+    s = 0.0
+    squared_sum = 0.0
+    for i in tl.range(0, hidden_size, BLOCK_SIZE):
+        hidden_size_offsets = i + block_range
+        mask = hidden_size_offsets < hidden_size
+        X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
+        s += tl.sum(X)
+        # X**2
+        squared_sum += tl.sum(X * X)
+
+    m = s / hidden_size
+
+    # variance = E[X**2] - E[X]**2
+    variance = (squared_sum / hidden_size) - (m * m)
+
+    # 1/std
+    rstd = rsqrt(variance + eps)
+
+    # Normalize
+    hidden_size_per_channel = hidden_size // channels_per_group
+    for channel_idx in tl.range(
+        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
+    ):
+        W = tl.load(W_ptr + channel_idx)
+        B = tl.load(B_ptr + channel_idx)
+        for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size_per_channel
+            X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
+            Y = (X - m) * rstd * W + B
+            tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
+
+        X_ptr += hidden_size_per_channel
+        Y_ptr += hidden_size_per_channel
+
+    tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
+    tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
+
+
+@triton.jit
+def _group_norm_backward_kernel(
+    X_ptr,  # pointer to input, shape (n_rows, n_channels, hidden_size)
+    X_row_stride,  # stride of each row in input
+    X_col_stride,  # stride of each column in input
+    W_ptr,  # pointer to weights, shape (n_channels)
+    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
+    Mean_ptr_row_stride,  # stride of each column in mean
+    Mean_ptr_col_stride,  # stride of each column in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
+    DX_ptr,  # pointer to input grad, shape (n_rows, n_groups, hidden_size)
+    DW_ptr,  # pointer to weights grad, shape (n_channels)
+    DB_ptr,  # pointer to bias grad, shape (n_channels)
+    UPSTREAM_ptr,  # pointer to output grad, shape (n_rows, n_channels, hidden_size)
+    hidden_size: tl.constexpr,  # hidden size
+    channels_per_group: tl.constexpr,  # number of groups in group norm
+    BLOCK_SIZE: tl.constexpr,
+    dtype: tl.constexpr,
+):
+    """
+    References:
+    https://nn.labml.ai/normalization/group_norm/index.html
+    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
+
+    The backprop equations are the same for group_norm and layer_norm
+    the only difference here is that we load the Mean, Rstd corresponding to the
+    group we're computing gradients for and the mean and rstd are computed over n-channels
+    so the total number of elements we compute the mean over is num_channels_per_group * hidden_size
+
+    We also need to load the Weights corresponding to the current channel to compute the gradients.
+    """
+    batch_idx = tl.program_id(0)
+    group_idx = tl.program_id(1)
+
+    # Move the pointers to the correct batch
+    X_ptr += batch_idx * X_row_stride
+    DX_ptr += batch_idx * X_row_stride
+    UPSTREAM_ptr += batch_idx * X_row_stride
+
+    # Mean and rstd are the same shape so have the same strides
+    mean = tl.load(
+        Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
+    )
+    rstd = tl.load(
+        RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
+    )
+
+    c1 = 0.0
+    c2 = 0.0
+    block_range = tl.arange(0, BLOCK_SIZE)
+
+    # We need to compute the sum terms of the backprop equations across all channels in the group
+    for channel_idx in range(
+        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
+    ):
+        dW = 0.0
+        dB = 0.0
+        # Move the pointers to the correct channel
+        W = tl.load(W_ptr + channel_idx)
+        for i in tl.range(0, hidden_size, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size
+            X = tl.load(
+                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+            UPSTREAM_grad = tl.load(
+                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+
+            x_hat = (X - mean) * rstd
+            dW += tl.sum(UPSTREAM_grad * x_hat)
+            dB += tl.sum(UPSTREAM_grad)
+
+            wdy = W * UPSTREAM_grad
+            c1 += tl.sum(x_hat * wdy)
+            c2 += tl.sum(wdy)
+
+        # Need to ensure additions to the same channel are atomic
+        tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
+        tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))
+
+    N = hidden_size * channels_per_group
+    c1 = c1 / N
+    c2 = c2 / N
+
+    for channel_idx in tl.range(
+        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
+    ):
+        # Move the pointers to the correct channel
+        W = tl.load(W_ptr + channel_idx)
+        for i in range(0, hidden_size, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size
+            X = tl.load(
+                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+            UPSTREAM_grad = tl.load(
+                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+
+            x_hat = (X - mean) * rstd
+            wdy = W * UPSTREAM_grad
+            dx = (wdy - (x_hat * c1 + c2)) * rstd
+            tl.store(
+                DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
+            )
+
+
+def group_norm_forward(X, num_channels, num_groups, W, B, eps):
+    shape = X.shape
+    batch_size = shape[0]
+    channels_per_group = num_channels // num_groups
+    # Reshape X so that the mean and std are computed across the groups
+    X = X.view(batch_size, num_groups, -1).contiguous()
+    hidden_size = X.shape[-1]
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
+    Y = torch.empty(
+        (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
+    )
+    Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
+    RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
+
+    _group_norm_forward_kernel[(batch_size, num_groups)](
+        Y,
+        Y.stride(0),
+        Y.stride(1),
+        X,
+        X.stride(0),
+        X.stride(1),
+        Mean,
+        Mean.stride(0),
+        Mean.stride(1),
+        RSTD,
+        RSTD.stride(0),
+        RSTD.stride(1),
+        W,
+        B,
+        hidden_size,
+        channels_per_group,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    # Return tensors in the original shape
+    return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE
+
+
+def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
+    shape = dY.shape
+    batch_size = shape[0]
+    hidden_size = dY.shape[-1]
+    channels_per_group = num_channels // num_groups
+    dY = dY.view(batch_size, num_groups, -1)
+    DX = torch.empty(
+        (batch_size, num_groups, hidden_size * channels_per_group),
+        dtype=X.dtype,
+        device=X.device,
+    )
+    DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
+    DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
+    triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
+    _group_norm_backward_kernel[(batch_size, num_groups)](
+        X,
+        X.stride(0),
+        X.stride(1),
+        W,
+        Mean,
+        Mean.stride(0),
+        Mean.stride(1),
+        RSTD,
+        DX,
+        DW,
+        DB,
+        dY,
+        hidden_size,
+        channels_per_group,
+        BLOCK_SIZE=BLOCK_SIZE,
+        dtype=triton_dtype,
+    )
+
+    # Return tensors in the original shape
+    return DX.view(*shape), DW, DB
+
+
+class LigerGroupNormFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        X,
+        affine_scaling_weight,
+        affine_shifting_bias,
+        num_channels,
+        num_groups,
+        eps,
+    ):
+        Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
+            X,
+            num_channels,
+            num_groups,
+            affine_scaling_weight,
+            affine_shifting_bias,
+            eps,
+        )
+        ctx.num_channels = num_channels
+        ctx.num_groups = num_groups
+        ctx.save_for_backward(
+            X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
+        )
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        X, W, B, Mean, RSTD = ctx.saved_tensors
+        DX, DW, DB = group_norm_backward(
+            dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
+        )
+        return DX, DW, DB, None, None, None
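The new kernel can be exercised directly through the autograd function added above. The following is a minimal numerical sanity check, not part of the release: it assumes a CUDA device with Triton installed, and the shapes, eps, and tolerances are illustrative only. The argument order follows the forward() signature in the diff, (X, affine_scaling_weight, affine_shifting_bias, num_channels, num_groups, eps), and the reference is torch.nn.functional.group_norm.

import torch
import torch.nn.functional as F

from liger_kernel.ops.group_norm import LigerGroupNormFunction

# Illustrative shapes: (batch, channels, hidden), split into 2 groups of 4 channels.
batch, channels, groups, hidden = 4, 8, 2, 128
x = torch.randn(batch, channels, hidden, device="cuda", requires_grad=True)
w = torch.randn(channels, device="cuda", requires_grad=True)
b = torch.randn(channels, device="cuda", requires_grad=True)

# Forward pass of the Triton kernel vs. the PyTorch reference.
y_liger = LigerGroupNormFunction.apply(x, w, b, channels, groups, 1e-6)
y_torch = F.group_norm(x, groups, weight=w, bias=b, eps=1e-6)
print(torch.allclose(y_liger, y_torch, atol=1e-4, rtol=1e-4))

# Backward pass: gradients are produced for the input, weight, and bias.
y_liger.sum().backward()
print(x.grad.shape, w.grad.shape, b.grad.shape)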
{liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel/transformers/functional.py
@@ -4,6 +4,7 @@ from liger_kernel.ops.fused_linear_cross_entropy import (
 )
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
+from liger_kernel.ops.group_norm import LigerGroupNormFunction
 from liger_kernel.ops.jsd import LigerJSDFunction
 from liger_kernel.ops.kl_div import LigerKLDivLossFunction
 from liger_kernel.ops.layer_norm import LigerLayerNormFunction
@@ -21,3 +22,4 @@ liger_layer_norm = LigerLayerNormFunction.apply
 liger_kl_div = LigerKLDivLossFunction.apply
 liger_jsd = LigerJSDFunction.apply
 liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+liger_group_norm = LigerGroupNormFunction.apply
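With the change above, the op is also exposed as liger_group_norm in liger_kernel.transformers.functional, mirroring the other functional aliases in that module. A short sketch of calling it, assuming a CUDA device with Triton available; the shapes and eps below are illustrative:

import torch

from liger_kernel.transformers.functional import liger_group_norm

x = torch.randn(2, 4, 64, device="cuda")
weight = torch.ones(4, device="cuda")
bias = torch.zeros(4, device="cuda")

# Same positional arguments as LigerGroupNormFunction.apply:
# (X, W, B, num_channels, num_groups, eps)
out = liger_group_norm(x, weight, bias, 4, 2, 1e-6)
print(out.shape)  # torch.Size([2, 4, 64])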
liger_kernel_nightly-0.4.0.dev20241108173850/src/liger_kernel/transformers/group_norm.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.group_norm import LigerGroupNormFunction
+
+
+class LigerGroupNorm(nn.Module):
+    def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones"):
+        """
+        A Group Normalization layer.
+        Args:
+            num_channels (int): Number of channels in the input tensor.
+            num_groups (int): Number of groups to divide the channels into.
+            eps (float, optional): A value added to the denominator for numerical stability. Default: 1e-6.
+            bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``.
+            init_fn (str, optional): Initialization function for the learnable parameters. Default: "ones".
+        """
+        super().__init__()
+        assert init_fn in [
+            "ones",
+            "zeros",
+        ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
+
+        assert (
+            num_channels % num_groups == 0
+        ), f"Number of channels {num_channels} must be divisible by num_groups {num_groups}"
+        self.num_channels = num_channels
+        self.num_groups = num_groups
+        self.eps = eps
+        self.weight = nn.Parameter(
+            torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels)
+        )
+        self.bias = nn.Parameter(
+            torch.randn(num_channels) if bias else torch.zeros(num_channels)
+        )
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        # hidden_states: (batch_size, num_channels, *)
+        assert (
+            hidden_states.dim() >= 3
+        ), f"Input must have atleast 3 dimensions, got {hidden_states.dim()}"
+        assert (
+            hidden_states.size(1) == self.num_channels
+        ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}"
+        return LigerGroupNormFunction.apply(
+            hidden_states,
+            self.weight,
+            self.bias,
+            self.num_channels,
+            self.num_groups,
+            self.variance_epsilon,
+        )
+
+    def extra_repr(self):
+        return f"{self.hidden_size}, num_channels={self.num_channels}, num_groups={self.num_groups}, eps={self.eps}"
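For module-level use, the new LigerGroupNorm is a GroupNorm layer backed by the Triton kernels above. A short usage sketch follows, assuming a CUDA device with Triton installed; sizes are illustrative. Note that the constructor takes (num_channels, num_groups), the reverse of torch.nn.GroupNorm(num_groups, num_channels), and the input must have at least three dimensions with channels on dim 1.

import torch

from liger_kernel.transformers.group_norm import LigerGroupNorm

norm = LigerGroupNorm(num_channels=8, num_groups=2, eps=1e-6, bias=True).cuda()
x = torch.randn(4, 8, 256, device="cuda")  # (batch, channels, hidden)
y = norm(x)
print(y.shape)  # torch.Size([4, 8, 256])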
{liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850/src/liger_kernel_nightly.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.4.0.dev20241107194223
+Version: 0.4.0.dev20241108173850
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -332,7 +332,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
 - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
 - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
 - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
-- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the
+- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.


 ### Experimental Kernels
{liger_kernel_nightly-0.4.0.dev20241107194223 → liger_kernel_nightly-0.4.0.dev20241108173850}/src/liger_kernel_nightly.egg-info/SOURCES.txt
@@ -8,6 +8,7 @@ src/liger_kernel/ops/cross_entropy.py
 src/liger_kernel/ops/fused_linear_cross_entropy.py
 src/liger_kernel/ops/fused_linear_jsd.py
 src/liger_kernel/ops/geglu.py
+src/liger_kernel/ops/group_norm.py
 src/liger_kernel/ops/jsd.py
 src/liger_kernel/ops/kl_div.py
 src/liger_kernel/ops/layer_norm.py
@@ -24,6 +25,7 @@ src/liger_kernel/transformers/functional.py
 src/liger_kernel/transformers/fused_linear_cross_entropy.py
 src/liger_kernel/transformers/fused_linear_jsd.py
 src/liger_kernel/transformers/geglu.py
+src/liger_kernel/transformers/group_norm.py
 src/liger_kernel/transformers/jsd.py
 src/liger_kernel/transformers/kl_div.py
 src/liger_kernel/transformers/layer_norm.py