liger-kernel 0.1.1__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liger_kernel/env_report.py ADDED
@@ -0,0 +1,46 @@
+ import platform
+ import sys
+
+
+ def print_env_report():
+     """
+     Prints a report of the environment. Useful for debugging and reproducibility.
+     Usage:
+     ```
+     python -m liger_kernel.env_report
+     ```
+     """
+     print("Environment Report:")
+     print("-------------------")
+     print(f"Operating System: {platform.platform()}")
+     print(f"Python version: {sys.version.split()[0]}")
+
+     try:
+         import torch
+
+         print(f"PyTorch version: {torch.__version__}")
+         cuda_version = (
+             torch.version.cuda if torch.cuda.is_available() else "Not available"
+         )
+         print(f"CUDA version: {cuda_version}")
+     except ImportError:
+         print("PyTorch: Not installed")
+         print("CUDA version: Unable to query")
+
+     try:
+         import triton
+
+         print(f"Triton version: {triton.__version__}")
+     except ImportError:
+         print("Triton: Not installed")
+
+     try:
+         import transformers
+
+         print(f"Transformers version: {transformers.__version__}")
+     except ImportError:
+         print("Transformers: Not installed")
+
+
+ if __name__ == "__main__":
+     print_env_report()
@@ -56,7 +56,7 @@ def liger_cross_entropy_kernel(
  # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
  # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

- # 3. [Oneline softmax] first pass: find max + sum
+ # 3. [Online softmax] first pass: find max + sum
  m = float("-inf") # m is the max value. use the notation from the paper
  d = 0.0 # d is the sum. use the notation from the paper
  ori_X_y = tl.load(
@@ -73,10 +73,10 @@ def liger_cross_entropy_kernel(
  d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
  m = m_new

- # 4. [Oneline softmax] second pass: calculate the gradients
+ # 4. [Online softmax] second pass: calculate the gradients
  # dx_y = (softmax(x_y) - 1) / N
  # dx_i = softmax(x_i) / N, i != y
- # N is the number of non ingored elements in the batch
+ # N is the number of non ignored elements in the batch
  for i in range(0, n_cols, BLOCK_SIZE):
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
  X_block = tl.load(
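The comments above describe the online-softmax trick the kernel relies on: a single sweep maintains a running max `m` and running sum `d`, and a second pass turns the stored logits into gradients. A small plain-Python sketch of the first-pass recurrence (illustrative only, not part of the package):

```
import math


def online_max_and_sum(x, block_size):
    """One sweep over x in blocks, keeping a running max m and a running
    sum d of exp(x_i - m), as in Algorithm 3 of the linked paper."""
    m, d = float("-inf"), 0.0
    for i in range(0, len(x), block_size):
        block = x[i : i + block_size]
        m_new = max(m, max(block))
        # rescale the old sum to the new max, then add this block's contribution
        d = d * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in block)
        m = m_new
    return m, d  # softmax(x_i) = exp(x_i - m) / d


# softmax of [1.0, 2.0, 3.0, 4.0] processed two elements at a time
m, d = online_max_and_sum([1.0, 2.0, 3.0, 4.0], block_size=2)
```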
@@ -86,7 +86,7 @@ def liger_cross_entropy_kernel(
  tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

  # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
- # ttps://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
  tl.debug_barrier()

  # 5. Calculate the loss
@@ -196,7 +196,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  ignore_index=ignore_index,
  BLOCK_SIZE=BLOCK_SIZE,
  # TODO: 32 seems to give the best performance
- # Performance is quite sentitive to num_warps
+ # Performance is quite sensitive to num_warps
  num_warps=32,
  )

@@ -11,7 +11,7 @@ MAX_FUSED_SIZE = 65536 // 2

  class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  @staticmethod
- def forward(ctx, _input, linear, target, ignore_index):
+ def forward(ctx, _input, weight, target, bias=None, ignore_index=-100):
  """
  Fusing the last linear layer with cross-entropy loss
  Reference: https://github.com/mgmalek/efficient_cross_entropy
@@ -23,7 +23,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):

  _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
  target: (B*T) where each value is in [0, V-1]
- linear: linear projection matrix of shape V x H.
+ weight: (V, H) where V is the number of classes
+ bias: (V) where V is the number of classes
  ignore_index: the index to ignore in the target
  """
  dtype = (
@@ -36,12 +37,12 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  # inputs have shape: BT x H
  # materialized activations will have shape: BT x V
  # the increase in memory = BT x V
- # reduction can be achieved by paritioning the number of tokens BT into smaller chunks.
+ # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
  # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
  # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
  # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
  BT, H = _input.shape
- V = linear.shape[0]
+ V = weight.shape[0]
  BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

  inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
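The chunking arithmetic in the comments above can be checked directly; a standalone sketch (plain Python, reusing the example figures from the comment):

```
def chunking_plan(BT, V, H):
    # ceiling divisions, matching triton.cdiv
    inc_factor = (V + H - 1) // H  # how much larger BT x V is than BT x H
    chunk_size = (BT + inc_factor - 1) // inc_factor
    num_chunks = (BT + chunk_size - 1) // chunk_size
    return inc_factor, chunk_size, num_chunks


# Example from the comment: BT = 4096 * 4, V = 32000, H = 4096
print(chunking_plan(4096 * 4, 32000, 4096))  # (8, 2048, 8)
```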
@@ -50,9 +51,9 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  ) # (BT + inc_factor - 1) // inc_factor
  num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size

- grad_linear = torch.zeros_like(linear, device=device)
+ grad_weight = torch.zeros_like(weight, device=device)
  grad_input = torch.zeros_like(_input, device=device)
-
+ grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
  # we use fp32 for loss accumulator
  loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)

@@ -64,7 +65,9 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  _input_chunk = _input[start_idx:end_idx] # chunk_size x H

  # when doing matmul, use the original precision
- logits_chunk = _input_chunk @ linear.t() # chunk_size x V
+ logits_chunk = _input_chunk @ weight.t() # chunk_size x V
+ if bias is not None:
+ logits_chunk = logits_chunk + bias
  target_chunk = target[start_idx:end_idx] # chunk_size,

  n_rows = logits_chunk.shape[0]
@@ -95,39 +98,52 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  num_warps=32,
  )

- # gradient of logits_chunk is computed inplace by the above triton kernel.
+ # gradient of logits_chunk is computed in-place by the above triton kernel.
  # Following HuggingFace model source code, we do the forward and backward
  # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
  # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
  # Propagating to lm_head's backward, we'll switch back to the original dtype.
  logits_chunk = logits_chunk.to(dtype)

- # gradient of logits_chunk is computed inplace by the above triton kernel and is of shape: chunk_size x V
+ # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
  # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
  # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
  # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
  # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
- grad_logits_chunk = logits_chunk * (n_non_ignore / total_n_non_ignore)
- grad_input[start_idx:end_idx] = grad_logits_chunk @ linear
-
+ grad_logits_chunk = logits_chunk * (
+ n_non_ignore / total_n_non_ignore
+ ) # chunk_size x V
+ grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
  torch.addmm(
- input=grad_linear,
+ input=grad_weight,
  mat1=logits_chunk.t(),
  mat2=_input_chunk,
- out=grad_linear,
+ out=grad_weight,
  alpha=n_non_ignore / total_n_non_ignore,
  beta=1.0,
  )

+ if bias is not None:
+ torch.add(
+ input=grad_bias,
+ other=logits_chunk.sum(dim=0),
+ out=grad_bias,
+ alpha=n_non_ignore / total_n_non_ignore,
+ )
+
  loss = torch.sum(loss_1d) / total_n_non_ignore

  # downcast to dtype and store for backward
- ctx.save_for_backward(grad_input.detach(), grad_linear.detach())
+ ctx.save_for_backward(
+ grad_input.detach(),
+ grad_weight.detach(),
+ grad_bias.detach() if bias is not None else None,
+ )
  return loss

  @staticmethod
  def backward(ctx, grad_output):
- (grad_input, grad_linear) = ctx.saved_tensors
+ (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
  # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
  if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
  # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
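For readability, the per-chunk accumulation that the `torch.addmm`/`torch.add` calls in the forward pass above implement boils down to the following schematic PyTorch equivalent (a sketch, not the actual execution path). Here `dlogits` stands for the in-place cross-entropy gradient of the chunk, and `n_non_ignore / total_n_non_ignore` rescales each chunk so the final gradients are averaged over all non-ignored tokens:

```
import torch


def accumulate_chunk(grad_weight, grad_bias, dlogits, input_chunk,
                     n_non_ignore, total_n_non_ignore):
    scale = n_non_ignore / total_n_non_ignore
    # grad_weight += scale * dlogits^T @ x  (what the in-place torch.addmm does)
    grad_weight.add_(dlogits.t() @ input_chunk, alpha=scale)
    if grad_bias is not None:
        # grad_bias += scale * column-sum of dlogits  (what the in-place torch.add does)
        grad_bias.add_(dlogits.sum(dim=0), alpha=scale)
```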
@@ -145,17 +161,30 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  num_warps=32,
  )

- # handle grad_linear
- V, H = grad_linear.shape
+ # handle grad_weight
+ V, H = grad_weight.shape
  n_rows = V

  element_mul[(n_rows,)](
- grad_linear,
- grad_linear.stride(-2),
+ grad_weight,
+ grad_weight.stride(-2),
  grad_output,
  H,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=32,
  )

- return (grad_input, grad_linear, None, None)
+ if grad_bias is not None:
+ V = grad_bias.shape[0]
+ n_rows = V
+
+ element_mul[(n_rows,)](
+ grad_bias,
+ grad_bias.stride(-1),
+ grad_output,
+ 1,
+ BLOCK_SIZE=BLOCK_SIZE,
+ num_warps=32,
+ )
+
+ return (grad_input, grad_weight, None, grad_bias, None)
liger_kernel/ops/geglu.py CHANGED
@@ -11,7 +11,12 @@ from liger_kernel.ops.utils import (
  )

  if compare_version("triton", operator.ge, "3.0.0"):
- from triton.language.extra.libdevice import tanh
+ try:
+ # typical import path with dispatch available
+ from triton.language.extra.libdevice import tanh
+ except ModuleNotFoundError:
+ # for working with NGC containers
+ from triton.language.extra.cuda.libdevice import tanh
  else:
  from triton.language.math import tanh

@@ -1,8 +1,29 @@
+ import operator
+
  import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+ from liger_kernel.ops.utils import (
+ calculate_settings,
+ compare_version,
+ ensure_contiguous,
+ )
+
+ if compare_version("triton", operator.ge, "3.0.0"):
+ try:
+ # typical import path with dispatch available
+ from triton.language.extra.libdevice import rsqrt
+ except ModuleNotFoundError:
+ # for working with NGC containers
+ from triton.language.extra.cuda.libdevice import rsqrt
+ else:
+ from triton.language.math import rsqrt
+
+
+ _CASTING_MODE_NONE = tl.constexpr(-1)
+ _CASTING_MODE_LLAMA = tl.constexpr(0)
+ _CASTING_MODE_GEMMA = tl.constexpr(1)


  @triton.jit
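The version-dependent import guard used in the two hunks above relies on `compare_version` from `liger_kernel.ops.utils`, which is not shown in this diff. As a purely illustrative stand-in (an assumption about its behavior, not the library's actual code), something with the same call shape could be:

```
import operator
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def compare_version(package: str, op, target: str) -> bool:
    """Return op(installed_version, target), e.g. operator.ge for "at least"."""
    try:
        installed = Version(version(package))
    except PackageNotFoundError:
        return False
    return op(installed, Version(target))


# mirrors the guard above
assert isinstance(compare_version("triton", operator.ge, "3.0.0"), bool)
```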
@@ -17,10 +38,12 @@ def _rms_norm_forward(
  r_row_stride,
  n_cols,
  eps,
+ offset,
+ casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
  BLOCK_SIZE: tl.constexpr,
  ):
  """
- y_i = (x_i / (RMS)) * wi, RMS = sqrt(sum(x_i^2) / N)
+ y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

  Reference:
  1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
@@ -37,17 +60,33 @@ def _rms_norm_forward(
  r_ptr += row_idx * r_row_stride

  X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+ X_row_dtype = X_row.dtype
  W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

+ # On Llama, only inv_rms is computed on fp32
+ if casting_mode == _CASTING_MODE_LLAMA:
+ X_row = X_row.to(tl.float32)
+
+ # Gemma computes everything on fp32, and then casts back the output to the original dtype
+ if casting_mode == _CASTING_MODE_GEMMA:
+ W_row = W_row.to(tl.float32)
+ X_row = X_row.to(tl.float32)
+
  mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
- inv_rms = tl.math.rsqrt(mean_square + eps)
+ inv_rms = rsqrt(mean_square + eps)

  # We can save time by caching rms with minimal memory overhead
  # because rms is much smaller compared to X_row, as rms is for each row.
  # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
  tl.store(r_ptr, inv_rms)

- Y_row = X_row * inv_rms * W_row
+ X_row = X_row * inv_rms
+
+ # On Llama, the multiplication with the weight is done on the original dtype
+ if casting_mode == _CASTING_MODE_LLAMA:
+ X_row = X_row.to(X_row_dtype)
+
+ Y_row = X_row * (offset + W_row)

  tl.store(Y_ptr + col_offsets, Y_row, mask=mask)

@@ -66,10 +105,12 @@ def _rms_norm_backward(
  dW_row_stride,
  n_cols,
  eps,
+ offset,
+ casting_mode: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
  ):
  """
- dx = (1 / RMS) * [dy * w - (1 / N) * (1 / RMS^2) * ((dy * w) dot x) * x]. * means element-wise multiplication, whileas dot means dot product
+ dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
  dw = sum(dy * (x / RMS)). summation over BxT dimension
  """

@@ -85,33 +126,95 @@ def _rms_norm_backward(
  dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
  X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
  W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+ original_x_dtype = X_row.dtype

  # Get cached rms
  inv_rms_row = tl.load(r_ptr)

- dX_row = (inv_rms_row) * (
- dY_row * W_row
- - (1 / n_cols)
- * inv_rms_row
- * inv_rms_row
- * tl.sum(dY_row * W_row * X_row, axis=0)
- * X_row
- )
- tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
+ W_row = W_row + offset
+
+ # Different backward graphs for different casting modes
+ if casting_mode == _CASTING_MODE_LLAMA:
+ X_row = X_row.to(tl.float32)
+ m = (dY_row * W_row).to(tl.float32)
+ dX_row = inv_rms_row * m
+
+ dX_row += (inv_rms_row) * (
+ -(1 / n_cols)
+ * inv_rms_row
+ * inv_rms_row
+ * tl.sum(m * X_row, axis=0)
+ * X_row
+ )
+
+ if casting_mode == _CASTING_MODE_GEMMA:
+ dY_row, W_row, X_row = (
+ dY_row.to(tl.float32),
+ W_row.to(tl.float32),
+ X_row.to(tl.float32),
+ )
+ dX_row = inv_rms_row * dY_row * W_row
+
+ dX_row += (inv_rms_row) * (
+ -(1 / n_cols)
+ * inv_rms_row
+ * inv_rms_row
+ * tl.sum(dY_row * W_row * X_row, axis=0)
+ * X_row
+ )

  # calculate the gradient of W
- dW_row = dY_row * X_row * inv_rms_row
+ if casting_mode == _CASTING_MODE_LLAMA:
+ dW_row = dY_row * (X_row * inv_rms_row).to(original_x_dtype)
+ else:
+ # here X_row is already in fp32 (see previous if block)
+ dW_row = dY_row * (X_row * inv_rms_row)
+
+ tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
  tl.store(dW_ptr + col_offsets, dW_row, mask=mask)


+ _str_to_casting_mode = {
+ "llama": _CASTING_MODE_LLAMA.value,
+ "gemma": _CASTING_MODE_GEMMA.value,
+ "none": _CASTING_MODE_NONE.value,
+ }
+
+
  class LigerRMSNormFunction(torch.autograd.Function):
+ """
+ Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
+ weight tensor `W`, with an optional offset and casting mode.
+
+ Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
+ uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
+ `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
+
+ In addition, different models cast their inputs at different places during RMSNorm computation. For
+ example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
+ inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
+ support the following casting modes (they match HuggingFace Transformers' implementations):
+ - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
+ - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
+ - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+ """
+
  @staticmethod
  @ensure_contiguous
- def forward(ctx, X, W, eps):
+ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
  """
  X: (B, T, H) or (BxT, H)
  W: (H,)
  """
+ if not isinstance(casting_mode, int):
+ assert (
+ casting_mode in _str_to_casting_mode
+ ), f"Invalid casting mode: {casting_mode}"
+ casting_mode = _str_to_casting_mode[casting_mode]
+ else:
+ assert (
+ casting_mode in _str_to_casting_mode.values()
+ ), f"Invalid casting mode: {casting_mode}"

  shape = X.shape
  dim = shape[-1]
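As a plain-PyTorch reference for the semantics described in the docstring above (a sketch of the offset and the 'llama'/'gemma'/'none' casting modes, not the Triton kernel itself; the `eps` default here is arbitrary):

```
import torch


def rms_norm_reference(X, W, eps=1e-6, offset=0.0, casting_mode="llama"):
    dtype = X.dtype
    if casting_mode == "gemma":
        # everything in fp32, cast the output back at the end
        Xf, Wf = X.float(), W.float()
        inv_rms = torch.rsqrt(Xf.pow(2).mean(-1, keepdim=True) + eps)
        return (Xf * inv_rms * (Wf + offset)).to(dtype)
    if casting_mode == "llama":
        # only the inverse RMS is computed in fp32; the weight multiply stays in the input dtype
        inv_rms = torch.rsqrt(X.float().pow(2).mean(-1, keepdim=True) + eps)
        return (X.float() * inv_rms).to(dtype) * (W + offset)
    # "none": keep the original dtype throughout
    inv_rms = torch.rsqrt(X.pow(2).mean(-1, keepdim=True) + eps)
    return X * inv_rms * (W + offset)
```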
@@ -121,7 +224,13 @@ class LigerRMSNormFunction(torch.autograd.Function):

  Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
  # r is to cache (1/rms) for each row
- r = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+ # r is always computed/stored in fp32 if we are using Llama or Gemma casting mode
+ r_dtype = (
+ torch.float32
+ if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
+ else X.dtype
+ )
+ r = torch.empty(n_rows, dtype=r_dtype, device=X.device)

  # Check constraints.
  assert (
@@ -139,10 +248,14 @@ class LigerRMSNormFunction(torch.autograd.Function):
  r.stride(0),
  n_cols,
  eps,
+ offset,
+ casting_mode,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=num_warps,
  )
  ctx.eps = eps
+ ctx.offset = offset
+ ctx.casting_mode = casting_mode
  ctx.BLOCK_SIZE = BLOCK_SIZE
  ctx.num_warps = num_warps

@@ -161,7 +274,14 @@ class LigerRMSNormFunction(torch.autograd.Function):
  dY = dY.view(-1, dim)
  X, W, r = ctx.saved_tensors
  n_rows, n_cols = dY.shape
- dW = torch.zeros_like(X)
+ dW = torch.empty_like(
+ X,
+ dtype=(
+ torch.float32
+ if ctx.casting_mode == _CASTING_MODE_GEMMA.value
+ else W.dtype
+ ),
+ )

  # Here we use dY to store the value of dX to save memory
  _rms_norm_backward[(n_rows,)](
@@ -177,9 +297,11 @@ class LigerRMSNormFunction(torch.autograd.Function):
  dW.stride(0),
  n_cols,
  ctx.eps,
+ ctx.offset,
+ ctx.casting_mode,
  BLOCK_SIZE=ctx.BLOCK_SIZE,
  num_warps=ctx.num_warps,
  )
  dX = dY.view(*shape)
- dW = torch.sum(dW, dim=0)
- return dX, dW, None
+ dW = torch.sum(dW, dim=0).to(W.dtype)
+ return dX, dW, None, None, None
liger_kernel/ops/rope.py CHANGED
@@ -13,8 +13,8 @@ def _triton_rope(
  cos_row_stride,
  sin,
  sin_row_stride,
+ sl,
  bs: tl.constexpr,
- sl: tl.constexpr,
  n_qh: tl.constexpr,
  n_kh: tl.constexpr,
  hd: tl.constexpr,
@@ -168,8 +168,8 @@ class LigerRopeFunction(torch.autograd.Function):
  cos.stride(-2),
  sin,
  sin.stride(-2),
- batch_size,
  seq_len,
+ batch_size,
  n_q_head,
  n_kv_head,
  head_dim,
@@ -219,8 +219,8 @@ class LigerRopeFunction(torch.autograd.Function):
  cos.stride(-2),
  sin,
  sin.stride(-2),
- batch_size,
  seq_len,
+ batch_size,
  n_q_head,
  n_kv_head,
  head_dim,
@@ -1,6 +1,12 @@
+ from liger_kernel.transformers.auto_model import ( # noqa: F401
+ AutoLigerKernelForCausalLM,
+ )
  from liger_kernel.transformers.monkey_patch import ( # noqa: F401
  apply_liger_kernel_to_gemma,
+ apply_liger_kernel_to_gemma2,
  apply_liger_kernel_to_llama,
  apply_liger_kernel_to_mistral,
  apply_liger_kernel_to_mixtral,
+ apply_liger_kernel_to_phi3,
+ apply_liger_kernel_to_qwen2,
  )
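The `apply_liger_kernel_to_*` functions exported here monkey-patch the corresponding HuggingFace model modules and are meant to be called before the model is created; a hedged usage sketch (the checkpoint path is a placeholder):

```
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch the HF Llama modules in place (the exact set of kernels swapped in is
# defined in liger_kernel.transformers.monkey_patch), then load the model as usual.
apply_liger_kernel_to_llama()
model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")
```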
liger_kernel/transformers/auto_model.py ADDED
@@ -0,0 +1,33 @@
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
+
+
+ def _get_model_config(model_dir, **model_init_kwargs):
+     config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
+     return config
+
+
+ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
+     """
+     This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model
+     if applicable.
+     """
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         model_config = _get_model_config(pretrained_model_name_or_path, **kwargs)
+
+         # Determine the model type and apply the Liger Kernel if applicable
+         # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
+         model_type = model_config.model_type
+         _apply_liger_kernel(model_type, **kwargs)
+
+         # Retain only the keyword args present in the model configuration
+         for k in list(kwargs.keys()):
+             if k not in model_config.__dict__:
+                 del kwargs[k]
+
+         return super().from_pretrained(
+             pretrained_model_name_or_path, *model_args, **kwargs
+         )
@@ -9,7 +9,7 @@ class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
  def __init__(self, *args, **kwargs):
  super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs)

- def forward(self, lin_weight, _input, target):
+ def forward(self, lin_weight, _input, target, bias=None):
  return LigerFusedLinearCrossEntropyFunction.apply(
- _input, lin_weight, target, self.ignore_index
+ _input, lin_weight, target, bias, self.ignore_index
  )
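A minimal usage sketch for the updated signature (the import path is inferred from the class and module names rather than shown in this diff, the shapes are placeholders, and the Triton kernels underneath need a CUDA device):

```
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

B, T, H, V = 2, 8, 64, 128
device = "cuda"
_input = torch.randn(B * T, H, device=device, requires_grad=True)  # hidden states (B*T, H)
lin_weight = torch.randn(V, H, device=device, requires_grad=True)  # lm_head weight (V, H)
bias = torch.randn(V, device=device, requires_grad=True)           # optional bias (V,), new in this version
target = torch.randint(0, V, (B * T,), device=device)              # labels in [0, V-1]

loss_fn = LigerFusedLinearCrossEntropyLoss()
loss = loss_fn(lin_weight, _input, target, bias)  # note: the weight comes before the input
loss.backward()
```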
@@ -13,8 +13,10 @@ class LigerGEGLUMLP(nn.Module):
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  # TODO: support exact GELU
- if config.hidden_act not in ["gelu_pytorch_tanh"]:
- raise ValueError(f"Activation function {config.hidden_act} not supported.")
+ # Right now Gemma 1, 1.1 and 2 models are all using `gelu_pytorch_tanh`
+ # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175
+ # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/activations.py#L46
+ # So we can safely assume we use the tanh approximation form all the time

  def forward(self, x):