liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py CHANGED
@@ -4,23 +4,25 @@ import torch
4
4
  import triton
5
5
  import triton.language as tl
6
6
 
7
- from liger_kernel.ops.utils import (
8
- calculate_settings,
9
- compare_version,
10
- ensure_contiguous,
11
- )
12
-
13
- if compare_version("triton", operator.ge, "3.0.0"):
14
- from triton.language.extra.libdevice import tanh
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import compare_version
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+ from liger_kernel.utils import is_npu_available
11
+
12
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
13
+ try:
14
+ # typical import path with dispatch available
15
+ from triton.language.extra.libdevice import tanh
16
+ except ModuleNotFoundError:
17
+ # for working with NGC containers
18
+ from triton.language.extra.cuda.libdevice import tanh
15
19
  else:
16
20
  from triton.language.math import tanh
17
21
 
18
22
 
19
23
  @triton.jit
20
- def _geglu_tanh_forward_kernel(
21
- a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
22
- ):
23
- program_id = tl.program_id(0)
24
+ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
25
+ program_id = tl.program_id(0).to(tl.int64)
24
26
 
25
27
  # locate start index
26
28
  a += program_id * stride
@@ -39,15 +41,13 @@ def _geglu_tanh_forward_kernel(
39
41
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
40
42
  tanh_result = tanh(tanh_arg)
41
43
  geglu_a = 0.5 * a_row * (1 + tanh_result)
42
- c_row = geglu_a * b_row
44
+ c_row = geglu_a.cast(b_row.dtype) * b_row
43
45
  tl.store(c + col_offsets, c_row, mask=mask)
44
46
 
45
47
 
46
48
  @triton.jit
47
- def _geglu_tanh_backward_kernel(
48
- dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
49
- ):
50
- program_id = tl.program_id(0)
49
+ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
50
+ program_id = tl.program_id(0).to(tl.int64)
51
51
 
52
52
  # locate start index
53
53
  dc += program_id * stride
@@ -75,66 +75,68 @@ def _geglu_tanh_backward_kernel(
75
75
  # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
76
76
  term1 = 0.5 * (1 + tanh_result)
77
77
  tanh_sq = tanh_result * tanh_result
78
- term2 = (
79
- 0.5
80
- * a_row
81
- * (1 - tanh_sq)
82
- * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
83
- )
78
+ term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
84
79
  da_row = dc_row * b_row * (term1 + term2)
85
80
 
86
81
  tl.store(a + col_offsets, da_row, mask=mask)
87
82
  tl.store(b + col_offsets, db_row, mask=mask)
88
83
 
89
84
 
85
+ def geglu_forward(a, b):
86
+ ori_shape = a.shape
87
+
88
+ n_cols = ori_shape[-1]
89
+ a = a.view(-1, n_cols)
90
+ b = b.view(-1, n_cols)
91
+ c = torch.empty_like(a)
92
+ n_rows = a.shape[0]
93
+
94
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
95
+
96
+ _geglu_tanh_forward_kernel[(n_rows,)](
97
+ a,
98
+ b,
99
+ c,
100
+ c.stride(-2),
101
+ n_cols=n_cols,
102
+ BLOCK_SIZE=BLOCK_SIZE,
103
+ num_warps=num_warps,
104
+ )
105
+ return a, b, c.view(*ori_shape)
106
+
107
+
108
+ def geglu_backward(a, b, dc):
109
+ ori_shape = dc.shape
110
+ n_cols = ori_shape[-1]
111
+ dc = dc.view(-1, n_cols)
112
+ n_rows = dc.shape[0]
113
+
114
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
115
+
116
+ _geglu_tanh_backward_kernel[(n_rows,)](
117
+ dc,
118
+ a,
119
+ b,
120
+ dc.stride(-2),
121
+ n_cols=n_cols,
122
+ BLOCK_SIZE=BLOCK_SIZE,
123
+ num_warps=num_warps,
124
+ )
125
+
126
+ return a.view(*ori_shape), b.view(*ori_shape)
127
+
128
+
90
129
  class LigerGELUMulFunction(torch.autograd.Function):
91
130
  @staticmethod
92
131
  @ensure_contiguous
93
132
  def forward(ctx, a, b):
94
- ori_shape = a.shape
95
-
96
- n_cols = ori_shape[-1]
97
- a = a.view(-1, n_cols)
98
- b = b.view(-1, n_cols)
99
- c = torch.zeros_like(a)
100
- n_rows = a.shape[0]
101
-
102
- BLOCK_SIZE, num_warps = calculate_settings(n_cols)
103
-
104
- _geglu_tanh_forward_kernel[(n_rows,)](
105
- a,
106
- b,
107
- c,
108
- c.stride(-2),
109
- n_cols=n_cols,
110
- BLOCK_SIZE=BLOCK_SIZE,
111
- num_warps=num_warps,
112
- )
113
-
133
+ a, b, c = geglu_forward(a, b)
114
134
  ctx.save_for_backward(a, b)
115
-
116
- return c.view(*ori_shape)
135
+ return c
117
136
 
118
137
  @staticmethod
119
138
  @ensure_contiguous
120
139
  def backward(ctx, dc):
121
-
122
- ori_shape = dc.shape
123
- n_cols = ori_shape[-1]
124
- dc = dc.view(-1, n_cols)
125
140
  a, b = ctx.saved_tensors
126
- n_rows = dc.shape[0]
127
-
128
- BLOCK_SIZE, num_warps = calculate_settings(n_cols)
129
-
130
- _geglu_tanh_backward_kernel[(n_rows,)](
131
- dc,
132
- a,
133
- b,
134
- dc.stride(-2),
135
- n_cols=n_cols,
136
- BLOCK_SIZE=BLOCK_SIZE,
137
- num_warps=num_warps,
138
- )
139
-
140
- return a.view(*ori_shape), b.view(*ori_shape)
141
+ a, b = geglu_backward(a, b, dc)
142
+ return a, b
@@ -0,0 +1,306 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import compare_version
8
+ from liger_kernel.ops.utils import ensure_contiguous
9
+ from liger_kernel.utils import is_npu_available
10
+
11
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
12
+ try:
13
+ # typical import path with dispatch available
14
+ from triton.language.extra.libdevice import rsqrt
15
+ except ModuleNotFoundError:
16
+ # for working with NGC containers
17
+ from triton.language.extra.cuda.libdevice import rsqrt
18
+ else:
19
+ from triton.language.math import rsqrt
20
+
21
+ MAX_FUSED_SIZE = 65536
22
+
23
+
24
+ @triton.jit
25
+ def _group_norm_forward_kernel(
26
+ Y_ptr, # pointer to output, shape (n_rows, n_groups, hidden_size)
27
+ Y_row_stride, # stride of each row in output
28
+ Y_col_stride, # stride of each column in output
29
+ X_ptr, # pointer to input, shape (n_rows, n_groups, hidden_size)
30
+ X_row_stride, # stride of each row in input
31
+ X_col_stride, # stride of each column in input
32
+ Mean_ptr, # pointer to mean, shape (n_rows, n_groups)
33
+ Mean_row_stride, # stride of each row in mean
34
+ Mean_col_stride, # stride of each column in mean
35
+ RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups)
36
+ RSTD_row_stride, # stride of each row in rstd
37
+ RSTD_col_stride, # stride of each column in rstd
38
+ W_ptr, # pointer to W
39
+ B_ptr, # pointer to B
40
+ hidden_size, # hidden size of X
41
+ channels_per_group, # the number of channels per group
42
+ eps,
43
+ BLOCK_SIZE: tl.constexpr,
44
+ ):
45
+ """
46
+ References:
47
+ https://nn.labml.ai/normalization/group_norm/index.html
48
+ """
49
+ batch_idx = tl.program_id(0)
50
+ group_idx = tl.program_id(1)
51
+
52
+ X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
53
+ Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride
54
+
55
+ block_range = tl.arange(0, BLOCK_SIZE)
56
+
57
+ # Compute mean and variance using the online algorithm
58
+ s = 0.0
59
+ squared_sum = 0.0
60
+ for i in tl.range(0, hidden_size, BLOCK_SIZE):
61
+ hidden_size_offsets = i + block_range
62
+ mask = hidden_size_offsets < hidden_size
63
+ X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
64
+ s += tl.sum(X)
65
+ # X**2
66
+ squared_sum += tl.sum(X * X)
67
+
68
+ m = s / hidden_size
69
+
70
+ # variance = E[X**2] - E[X]**2
71
+ variance = (squared_sum / hidden_size) - (m * m)
72
+
73
+ # 1/std
74
+ rstd = rsqrt(variance + eps)
75
+
76
+ # Normalize
77
+ hidden_size_per_channel = hidden_size // channels_per_group
78
+ for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
79
+ W = tl.load(W_ptr + channel_idx)
80
+ B = tl.load(B_ptr + channel_idx)
81
+ for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
82
+ hidden_size_offsets = i + block_range
83
+ mask = hidden_size_offsets < hidden_size_per_channel
84
+ X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
85
+ Y = (X - m) * rstd * W + B
86
+ tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
87
+
88
+ X_ptr += hidden_size_per_channel
89
+ Y_ptr += hidden_size_per_channel
90
+
91
+ tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
92
+ tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
93
+
94
+
95
+ @triton.jit
96
+ def _group_norm_backward_kernel(
97
+ X_ptr, # pointer to input, shape (n_rows, n_channels, hidden_size)
98
+ X_row_stride, # stride of each row in input
99
+ X_col_stride, # stride of each column in input
100
+ W_ptr, # pointer to weights, shape (n_channels)
101
+ Mean_ptr, # pointer to mean, shape (n_rows, n_groups)
102
+ Mean_ptr_row_stride, # stride of each column in mean
103
+ Mean_ptr_col_stride, # stride of each column in mean
104
+ RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups)
105
+ DX_ptr, # pointer to input grad, shape (n_rows, n_groups, hidden_size)
106
+ DW_ptr, # pointer to weights grad, shape (n_channels)
107
+ DB_ptr, # pointer to bias grad, shape (n_channels)
108
+ UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size)
109
+ hidden_size: tl.constexpr, # hidden size
110
+ channels_per_group: tl.constexpr, # number of groups in group norm
111
+ BLOCK_SIZE: tl.constexpr,
112
+ dtype: tl.constexpr,
113
+ ):
114
+ """
115
+ References:
116
+ https://nn.labml.ai/normalization/group_norm/index.html
117
+ https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
118
+
119
+ The backprop equations are the same for group_norm and layer_norm
120
+ the only difference here is that we load the Mean, Rstd corresponding to the
121
+ group we're computing gradients for and the mean and rstd are computed over n-channels
122
+ so the total number of elements we compute the mean over is num_channels_per_group * hidden_size
123
+
124
+ We also need to load the Weights corresponding to the current channel to compute the gradients.
125
+ """
126
+ batch_idx = tl.program_id(0)
127
+ group_idx = tl.program_id(1)
128
+
129
+ # Move the pointers to the correct batch
130
+ X_ptr += batch_idx * X_row_stride
131
+ DX_ptr += batch_idx * X_row_stride
132
+ UPSTREAM_ptr += batch_idx * X_row_stride
133
+
134
+ # Mean and rstd are the same shape so have the same strides
135
+ mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
136
+ rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
137
+
138
+ c1 = 0.0
139
+ c2 = 0.0
140
+ block_range = tl.arange(0, BLOCK_SIZE)
141
+
142
+ # We need to compute the sum terms of the backprop equations across all channels in the group
143
+ for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
144
+ dW = 0.0
145
+ dB = 0.0
146
+ # Move the pointers to the correct channel
147
+ W = tl.load(W_ptr + channel_idx)
148
+ for i in tl.range(0, hidden_size, BLOCK_SIZE):
149
+ hidden_size_offsets = i + block_range
150
+ mask = hidden_size_offsets < hidden_size
151
+ X = tl.load(
152
+ X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
153
+ mask=mask,
154
+ other=0.0,
155
+ )
156
+ UPSTREAM_grad = tl.load(
157
+ UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
158
+ mask=mask,
159
+ other=0.0,
160
+ )
161
+
162
+ x_hat = (X - mean) * rstd
163
+ dW += tl.sum(UPSTREAM_grad * x_hat)
164
+ dB += tl.sum(UPSTREAM_grad)
165
+
166
+ wdy = W * UPSTREAM_grad
167
+ c1 += tl.sum(x_hat * wdy)
168
+ c2 += tl.sum(wdy)
169
+
170
+ # Need to ensure additions to the same channel are atomic
171
+ tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
172
+ tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))
173
+
174
+ N = hidden_size * channels_per_group
175
+ c1 = c1 / N
176
+ c2 = c2 / N
177
+
178
+ for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
179
+ # Move the pointers to the correct channel
180
+ W = tl.load(W_ptr + channel_idx)
181
+ for i in range(0, hidden_size, BLOCK_SIZE):
182
+ hidden_size_offsets = i + block_range
183
+ mask = hidden_size_offsets < hidden_size
184
+ X = tl.load(
185
+ X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
186
+ mask=mask,
187
+ other=0.0,
188
+ )
189
+ UPSTREAM_grad = tl.load(
190
+ UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
191
+ mask=mask,
192
+ other=0.0,
193
+ )
194
+
195
+ x_hat = (X - mean) * rstd
196
+ wdy = W * UPSTREAM_grad
197
+ dx = (wdy - (x_hat * c1 + c2)) * rstd
198
+ tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
199
+
200
+
201
+ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
202
+ shape = X.shape
203
+ batch_size = shape[0]
204
+ channels_per_group = num_channels // num_groups
205
+ # Reshape X so that the mean and std are computed across the groups
206
+ X = X.view(batch_size, num_groups, -1).contiguous()
207
+ hidden_size = X.shape[-1]
208
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
209
+ Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
210
+ Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
211
+ RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
212
+
213
+ _group_norm_forward_kernel[(batch_size, num_groups)](
214
+ Y,
215
+ Y.stride(0),
216
+ Y.stride(1),
217
+ X,
218
+ X.stride(0),
219
+ X.stride(1),
220
+ Mean,
221
+ Mean.stride(0),
222
+ Mean.stride(1),
223
+ RSTD,
224
+ RSTD.stride(0),
225
+ RSTD.stride(1),
226
+ W,
227
+ B,
228
+ hidden_size,
229
+ channels_per_group,
230
+ eps,
231
+ BLOCK_SIZE=BLOCK_SIZE,
232
+ )
233
+ # Return tensors in the original shape
234
+ return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE
235
+
236
+
237
+ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
238
+ shape = dY.shape
239
+ batch_size = shape[0]
240
+ hidden_size = dY.shape[-1]
241
+ channels_per_group = num_channels // num_groups
242
+ dY = dY.view(batch_size, num_groups, -1)
243
+ DX = torch.empty(
244
+ (batch_size, num_groups, hidden_size * channels_per_group),
245
+ dtype=X.dtype,
246
+ device=X.device,
247
+ )
248
+ DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
249
+ DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
250
+ triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
251
+
252
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
253
+ _group_norm_backward_kernel[(batch_size, num_groups)](
254
+ X,
255
+ X.stride(0),
256
+ X.stride(1),
257
+ W,
258
+ Mean,
259
+ Mean.stride(0),
260
+ Mean.stride(1),
261
+ RSTD,
262
+ DX,
263
+ DW,
264
+ DB,
265
+ dY,
266
+ hidden_size,
267
+ channels_per_group,
268
+ BLOCK_SIZE=BLOCK_SIZE,
269
+ dtype=triton_dtype,
270
+ )
271
+
272
+ # Return tensors in the original shape
273
+ return DX.view(*shape), DW, DB
274
+
275
+
276
+ class LigerGroupNormFunction(torch.autograd.Function):
277
+ @staticmethod
278
+ @ensure_contiguous
279
+ def forward(
280
+ ctx,
281
+ X,
282
+ affine_scaling_weight,
283
+ affine_shifting_bias,
284
+ num_channels,
285
+ num_groups,
286
+ eps,
287
+ ):
288
+ Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
289
+ X,
290
+ num_channels,
291
+ num_groups,
292
+ affine_scaling_weight,
293
+ affine_shifting_bias,
294
+ eps,
295
+ )
296
+ ctx.num_channels = num_channels
297
+ ctx.num_groups = num_groups
298
+ ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
299
+ return Y
300
+
301
+ @staticmethod
302
+ @ensure_contiguous
303
+ def backward(ctx, dY):
304
+ X, W, B, Mean, RSTD = ctx.saved_tensors
305
+ DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
306
+ return DX, DW, DB, None, None, None