liger-kernel 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. liger_kernel/chunked_loss/README.md +25 -0
  2. liger_kernel/chunked_loss/__init__.py +3 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +18 -8
  4. liger_kernel/chunked_loss/dpo_loss.py +20 -10
  5. liger_kernel/chunked_loss/functional.py +4 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
  7. liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
  8. liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
  9. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  10. liger_kernel/chunked_loss/grpo_loss.py +160 -0
  11. liger_kernel/chunked_loss/jsd_loss.py +154 -0
  12. liger_kernel/chunked_loss/kto_loss.py +172 -0
  13. liger_kernel/chunked_loss/orpo_loss.py +8 -9
  14. liger_kernel/chunked_loss/simpo_loss.py +22 -8
  15. liger_kernel/env_report.py +5 -12
  16. liger_kernel/ops/cross_entropy.py +102 -51
  17. liger_kernel/ops/experimental/embedding.py +1 -3
  18. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  19. liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
  20. liger_kernel/ops/fused_linear_jsd.py +14 -32
  21. liger_kernel/ops/geglu.py +6 -17
  22. liger_kernel/ops/group_norm.py +11 -28
  23. liger_kernel/ops/jsd.py +5 -9
  24. liger_kernel/ops/kl_div.py +8 -11
  25. liger_kernel/ops/layer_norm.py +23 -12
  26. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  27. liger_kernel/ops/rms_norm.py +14 -32
  28. liger_kernel/ops/rope.py +31 -33
  29. liger_kernel/ops/swiglu.py +4 -8
  30. liger_kernel/ops/tvd.py +207 -0
  31. liger_kernel/ops/utils.py +3 -2
  32. liger_kernel/transformers/__init__.py +19 -24
  33. liger_kernel/transformers/auto_model.py +6 -13
  34. liger_kernel/transformers/cross_entropy.py +7 -9
  35. liger_kernel/transformers/experimental/embedding.py +1 -3
  36. liger_kernel/transformers/functional.py +28 -7
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +15 -10
  38. liger_kernel/transformers/geglu.py +1 -4
  39. liger_kernel/transformers/group_norm.py +9 -15
  40. liger_kernel/transformers/jsd.py +1 -3
  41. liger_kernel/transformers/kl_div.py +1 -3
  42. liger_kernel/transformers/layer_norm.py +3 -9
  43. liger_kernel/transformers/model/gemma.py +18 -40
  44. liger_kernel/transformers/model/gemma2.py +19 -41
  45. liger_kernel/transformers/model/llama.py +22 -48
  46. liger_kernel/transformers/model/mistral.py +14 -26
  47. liger_kernel/transformers/model/mixtral.py +24 -54
  48. liger_kernel/transformers/model/mllama.py +16 -36
  49. liger_kernel/transformers/model/olmo2.py +124 -0
  50. liger_kernel/transformers/model/phi3.py +18 -40
  51. liger_kernel/transformers/model/qwen2.py +18 -40
  52. liger_kernel/transformers/model/qwen2_vl.py +36 -32
  53. liger_kernel/transformers/monkey_patch.py +214 -144
  54. liger_kernel/transformers/rms_norm.py +4 -4
  55. liger_kernel/transformers/rope.py +2 -2
  56. liger_kernel/transformers/swiglu.py +2 -8
  57. liger_kernel/transformers/trainer/__init__.py +1 -3
  58. liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
  59. liger_kernel/transformers/tvd.py +13 -0
  60. liger_kernel/triton/__init__.py +1 -3
  61. liger_kernel/triton/monkey_patch.py +1 -3
  62. liger_kernel/utils.py +49 -0
  63. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +53 -26
  64. liger_kernel-0.5.4.dist-info/RECORD +74 -0
  65. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +1 -1
  66. liger_kernel-0.5.2.dist-info/RECORD +0 -65
  67. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
  68. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
  69. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rope.py CHANGED
@@ -15,6 +15,7 @@ def _triton_rope(
     sin_row_stride,
     sl,
     bs: tl.constexpr,
+    cos_bs: tl.constexpr,
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
     hd: tl.constexpr,
@@ -29,7 +30,7 @@ def _triton_rope(
     # k size: (bsz, seq_len, num_kv_heads, head_dim)
     # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
 
-    # cos size: (1, seq_len, head_dim)
+    # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
     # stride: (seq_len * head_dim, head_dim, 1)
     pid = tl.program_id(0)
 
@@ -48,9 +49,19 @@ def _triton_rope(
     # and pid % sl to get the sequence index.
     # 2. We only need the left half of cos and sin matrix because the right half is just
     # a clone of the left half.
-    cos_row_idx = pid % (sl)
-    cos = cos + cos_row_idx * cos_row_stride
-    sin = sin + cos_row_idx * sin_row_stride
+    batch_idx = pid // sl
+    cos_row_idx = pid % sl
+    cos = cos + tl.where(
+        cos_bs == 1,
+        cos_row_idx * cos_row_stride,
+        batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
+    )
+    sin = sin + tl.where(
+        cos_bs == 1,
+        cos_row_idx * sin_row_stride,
+        batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
+    )
+
     cos_offsets = tl.arange(0, pad_hd // 2)
     cos_mask = cos_offsets < hd // 2
     cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
@@ -61,36 +72,20 @@ def _triton_rope(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
 
     # right half of the head
     second_half_q_offsets = first_half_q_offsets + (hd // 2)
     second_half_k_offsets = first_half_k_offsets + (hd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
 
     if not BACKWARD_PASS:
         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -118,7 +113,6 @@ def _triton_rope(
 
 
 def rope_forward(q, k, cos, sin):
-
     # transpose it back to the physical shape because Triton looks at the physical storage
     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
     q = q.transpose(1, 2)
@@ -138,6 +132,7 @@ def rope_forward(q, k, cos, sin):
     k = k.contiguous()
     cos = cos.contiguous()
     sin = sin.contiguous()
+    cos_batch_size = cos.shape[0]
 
     _triton_rope[(n_row,)](
         q,
@@ -150,6 +145,7 @@ def rope_forward(q, k, cos, sin):
         sin.stride(-2),
         seq_len,
         batch_size,
+        cos_batch_size,
         n_q_head,
         n_kv_head,
         head_dim,
@@ -167,6 +163,7 @@ def rope_backward(dq, dk, cos, sin):
     dk = dk.transpose(1, 2)
 
     batch_size, seq_len, n_q_head, head_dim = dq.shape
+    cos_batch_size = cos.shape[0]
     n_kv_head = dk.shape[2]
     pad_hd = triton.next_power_of_2(head_dim)
     pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -191,6 +188,7 @@ def rope_backward(dq, dk, cos, sin):
         sin.stride(-2),
         seq_len,
         batch_size,
+        cos_batch_size,
         n_q_head,
         n_kv_head,
         head_dim,
@@ -221,8 +219,8 @@ class LigerRopeFunction(torch.autograd.Function):
         """
         q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
-        cos size: (1, seq_len, head_dim)
-        sin size: (1, seq_len, head_dim)
+        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
         """
         q, k, cos, sin = rope_forward(q, k, cos, sin)
         ctx.save_for_backward(cos, sin)
@@ -232,8 +230,8 @@ class LigerRopeFunction(torch.autograd.Function):
         """
         dq size: (bsz, n_q_head, seq_len, head_dim)
         dk size: (bsz, n_kv_head, seq_len, head_dim)
-        cos size: (1, seq_len, head_dim)
-        sin size: (1, seq_len, head_dim)
+        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
         """
 
         cos, sin = ctx.saved_tensors
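The change above lets the RoPE kernel consume cos/sin caches that carry a per-sample batch dimension. A minimal sketch of driving the autograd function directly, assuming a CUDA device; the sizes and random cache values are illustrative only, with the tensor layouts taken from the updated docstrings:

import torch

from liger_kernel.ops.rope import LigerRopeFunction

bsz, seq_len, n_q_head, n_kv_head, head_dim = 2, 128, 8, 2, 64
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda", requires_grad=True)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda", requires_grad=True)

# cos/sin may now be (bsz, seq_len, head_dim) instead of the previously
# required broadcastable (1, seq_len, head_dim) shape.
cos = torch.randn(bsz, seq_len, head_dim, device="cuda")
sin = torch.randn(bsz, seq_len, head_dim, device="cuda")

q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)
(q_rot.sum() + k_rot.sum()).backward()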
liger_kernel/ops/swiglu.py CHANGED
@@ -2,7 +2,8 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
 
 
 @triton.jit
@@ -11,9 +12,7 @@ def silu(x):
 
 
 @triton.jit
-def _swiglu_forward_kernel(
-    a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -32,9 +31,7 @@ def _swiglu_forward_kernel(
 
 
 @triton.jit
-def _swiglu_backward_kernel(
-    dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -84,7 +81,6 @@ def swiglu_forward(a, b):
 
 
 def swiglu_backward(a, b, dc):
-
     ori_shape = dc.shape
     n_cols = ori_shape[-1]
     dc = dc.view(-1, n_cols)
liger_kernel/ops/tvd.py ADDED
@@ -0,0 +1,207 @@
+from typing import Literal
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import ensure_contiguous
+
+MAX_FUSED_SIZE = 65536 // 4
+
+REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
+
+_REDUCTION_MODE_NONE = tl.constexpr(0)
+_REDUCTION_MODE_SUM = tl.constexpr(1)
+_REDUCTION_MODE_MEAN = tl.constexpr(2)
+_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
+
+_str_to_reduction_mode = {
+    "none": _REDUCTION_MODE_NONE.value,
+    "sum": _REDUCTION_MODE_SUM.value,
+    "mean": _REDUCTION_MODE_MEAN.value,
+    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
+}
+
+
+def get_num_warps(BLOCK_SIZE):
+    num_warps = 4
+    if BLOCK_SIZE >= 32768:
+        num_warps = 32
+    elif BLOCK_SIZE >= 8192:
+        num_warps = 16
+    elif BLOCK_SIZE >= 2048:
+        num_warps = 8
+
+    return num_warps
+
+
+@triton.jit
+def _tv_distance_kernel(
+    p_ptr,
+    p_stride,
+    q_ptr,
+    q_stride,
+    loss_ptr,
+    loss_stride,
+    grads_ptr,
+    grads_stride,
+    label_ptr,
+    ignore_index: tl.constexpr,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+    HAS_LABEL: tl.constexpr,
+    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
+):
+    pid = tl.program_id(0).to(tl.int64)
+    p_ptr += pid * p_stride
+    q_ptr += pid * q_stride
+    loss_ptr += pid * loss_stride
+    grads_ptr += pid * grads_stride
+    label_ptr += pid
+
+    base_offsets = tl.arange(0, BLOCK_SIZE)
+
+    if HAS_LABEL:
+        label = tl.load(label_ptr)
+        if label == ignore_index:
+            for i in range(0, n_cols, BLOCK_SIZE):
+                offsets = i + base_offsets
+                mask = offsets < n_cols
+                tl.store(grads_ptr + offsets, 0.0, mask=mask)
+                if reduction == _REDUCTION_MODE_NONE:
+                    tl.store(loss_ptr + offsets, 0.0, mask=mask)
+            return
+
+    loss_sum = 0.0
+    for i in range(0, n_cols, BLOCK_SIZE):
+        offsets = i + base_offsets
+        mask = offsets < n_cols
+
+        p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
+        q = tl.load(q_ptr + offsets, mask=mask, other=0.0)
+
+        # TVD(P || Q) = 0.5 * |P - Q|
+        tv_loss = 0.5 * tl.abs(p - q)
+
+        grad_res = tl.where(p > q, 0.5, -0.5)
+
+        tl.store(grads_ptr + offsets, grad_res, mask=mask)
+
+        if reduction == _REDUCTION_MODE_NONE:
+            tl.store(loss_ptr + offsets, tv_loss, mask=mask)
+        else:
+            loss_sum += tl.sum(tv_loss, axis=0)
+
+    if reduction != _REDUCTION_MODE_NONE:
+        tl.store(loss_ptr, loss_sum)
+
+
+def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
+    BT, V = p.shape
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    num_warps = get_num_warps(BLOCK_SIZE)
+
+    grid = (BT,)
+
+    reduction = _str_to_reduction_mode[reduction]
+
+    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
+    output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
+    grads = torch.empty_like(p)
+
+    n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
+
+    _tv_distance_kernel[grid](
+        p,
+        p.stride(0),
+        q,
+        q.stride(0),
+        output_tensor,
+        output_tensor.stride(0),
+        grads,
+        grads.stride(0),
+        shift_labels if has_label else torch.empty(1, device=p.device),
+        ignore_index,
+        V,
+        BLOCK_SIZE=BLOCK_SIZE,
+        HAS_LABEL=has_label,
+        num_warps=num_warps,
+        reduction=reduction,
+    )
+
+    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
+        return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
+    elif reduction == _REDUCTION_MODE_SUM.value:
+        return output_tensor.sum(dim=0), grads
+    elif reduction == _REDUCTION_MODE_MEAN.value:
+        return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
+    else:
+        return output_tensor, grads
+
+
+def tvd_backward_triton(grad_output, grads):
+    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
+    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        return grads
+
+    return grads * grad_output
+
+
+class LigerTVDLossFunction(torch.autograd.Function):
+    """
+    Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        p: torch.Tensor,
+        q: torch.Tensor,
+        shift_labels: Optional[torch.Tensor] = None,
+        reduction: REDUCTION_LITERAL = "batchmean",
+        ignore_index: int = -100,
+    ) -> torch.Tensor:
+        """A forward pass for the Total Variation Distance Loss.
+
+        Args:
+            ctx: Torch autograd context
+            p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
+            q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
+            shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
+            reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
+            ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
+
+        Returns:
+            torch.Tensor: The computed Total Variation Distance Loss.
+        """
+        has_label = False
+        if shift_labels is not None:
+            assert shift_labels.shape == (p.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
+            shift_labels = shift_labels.contiguous()
+            has_label = True
+
+        loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
+        ctx.save_for_backward(grads)
+        return loss
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        """A backward pass for the Total Variation Distance Loss.
+
+        Args:
+            ctx: Torch autograd context
+            grad_output (torch.Tensor): The gradient of the loss with respect to the output.
+
+        Returns:
+            tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
+        """
+        (grads,) = ctx.saved_tensors
+        grads = tvd_backward_triton(grad_output, grads)
+
+        return grads, None, None, None, None
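LigerTVDLossFunction can be called directly on two sets of row-wise probability distributions. A small sketch, assuming a CUDA device and using the (BT, V) shapes documented in forward() above; the random inputs are illustrative:

import torch

from liger_kernel.ops.tvd import LigerTVDLossFunction

BT, V = 8, 4096  # illustrative sizes: BT rows, each a V-way distribution
p = torch.randn(BT, V, device="cuda").softmax(dim=-1).requires_grad_(True)
q = torch.randn(BT, V, device="cuda").softmax(dim=-1)

# Defaults: reduction="batchmean", no label masking (shift_labels=None).
loss = LigerTVDLossFunction.apply(p, q)
loss.backward()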
liger_kernel/ops/utils.py CHANGED
@@ -13,11 +13,13 @@ Modifications made by Yanning Chen, 2024.
 import functools
 import importlib
 import operator
+
 from typing import Callable
 
 import torch
 import triton
 import triton.language as tl
+
 from packaging.version import Version
 
 from liger_kernel.utils import infer_device
@@ -47,8 +49,7 @@ def calculate_settings(n):
     BLOCK_SIZE = triton.next_power_of_2(n)
     if BLOCK_SIZE > MAX_FUSED_SIZE:
         raise RuntimeError(
-            f"Cannot launch Triton kernel since n = {n} exceeds "
-            f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
+            f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
         )
 
     num_warps = 4
liger_kernel/transformers/__init__.py CHANGED
@@ -1,31 +1,26 @@
-from liger_kernel.transformers.auto_model import (  # noqa: F401
-    AutoLigerKernelForCausalLM,
-)
+from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
-from liger_kernel.transformers.fused_linear_cross_entropy import (  # noqa: F401
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
 from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
-from liger_kernel.transformers.monkey_patch import (  # noqa: F401
-    _apply_liger_kernel,
-    _apply_liger_kernel_to_instance,
-    apply_liger_kernel_to_gemma,
-    apply_liger_kernel_to_gemma2,
-    apply_liger_kernel_to_llama,
-    apply_liger_kernel_to_mistral,
-    apply_liger_kernel_to_mixtral,
-    apply_liger_kernel_to_mllama,
-    apply_liger_kernel_to_phi3,
-    apply_liger_kernel_to_qwen2,
-    apply_liger_kernel_to_qwen2_vl,
-)
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
-from liger_kernel.transformers.swiglu import (  # noqa: F401
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
+from liger_kernel.transformers.tvd import LigerTVDLoss  # noqa: F401
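The package root keeps re-exporting these names, so the switch to one import per line is invisible to callers; the functional change is the new entry points such as apply_liger_kernel_to_olmo2, apply_liger_kernel_to_granite, and LigerTVDLoss. A sketch of patching OLMo-2 before model construction; the checkpoint id is a placeholder and the no-argument call assumes the patch function's defaults:

from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_olmo2

# Patch the Hugging Face OLMo-2 modeling code with Liger kernels before the
# model is instantiated, so the patched modules are picked up at load time.
apply_liger_kernel_to_olmo2()

model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")  # placeholder checkpoint id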
liger_kernel/transformers/auto_model.py CHANGED
@@ -1,11 +1,10 @@
 import inspect
 
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig
+from transformers import AutoModelForCausalLM
 
-from liger_kernel.transformers.monkey_patch import (
-    MODEL_TYPE_TO_APPLY_LIGER_FN,
-    _apply_liger_kernel,
-)
+from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
 
 
 def _get_model_config(model_dir, **model_init_kwargs):
@@ -34,12 +33,6 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
         apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
         apply_fn_signature = inspect.signature(apply_fn)
 
-        applicable_kwargs = {
-            key: value
-            for key, value in kwargs.items()
-            if key not in apply_fn_signature.parameters
-        }
+        applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
 
-        return super().from_pretrained(
-            pretrained_model_name_or_path, *model_args, **applicable_kwargs
-        )
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
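AutoLigerKernelForCausalLM keeps its transformers semantics; the edit only flattens the kwargs filtering onto single lines. A sketch with a placeholder model id; kwargs whose names match the model's apply_liger_kernel_to_* signature are excluded from the from_pretrained call by the comprehension shown above:

from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Loads the checkpoint through transformers and applies the matching
# apply_liger_kernel_to_* patch for its model type.
model = AutoLigerKernelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder model id
    torch_dtype="bfloat16",
)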
liger_kernel/transformers/cross_entropy.py CHANGED
@@ -8,6 +8,7 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 class LigerCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -16,20 +17,16 @@ class LigerCrossEntropyLoss(torch.nn.Module):
         return_z_loss: bool = False,
     ):
         super().__init__()
-        assert (label_smoothing >= 0) and (
-            label_smoothing <= 1
-        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
-        assert (label_smoothing >= 0) and (
-            label_smoothing <= 1
-        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert (label_smoothing >= 0) and (label_smoothing <= 1), (
+            f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        )
         assert reduction in {
             "mean",
             "sum",
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
-        assert (
-            softcap is None or softcap > 0
-        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+        assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.weight = weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
@@ -41,6 +38,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
         loss, z_loss = LigerCrossEntropyFunction.apply(
             _input,
             target,
+            self.weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
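The new weight argument mirrors the per-class rescaling tensor of torch.nn.CrossEntropyLoss and is forwarded straight to LigerCrossEntropyFunction.apply as shown above. A sketch with illustrative shapes, assuming a CUDA device and that the default return_z_loss=False path yields a single loss tensor:

import torch

from liger_kernel.transformers import LigerCrossEntropyLoss

BT, V = 16, 32000  # illustrative: flattened batch*seq rows over a vocabulary
logits = torch.randn(BT, V, device="cuda", requires_grad=True)
targets = torch.randint(0, V, (BT,), device="cuda")

class_weights = torch.rand(V, device="cuda")  # per-class rescaling weights
loss_fn = LigerCrossEntropyLoss(weight=class_weights, ignore_index=-100)

loss = loss_fn(logits, targets)
loss.backward()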
liger_kernel/transformers/experimental/embedding.py CHANGED
@@ -7,9 +7,7 @@ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
 
 
 class LigerEmbedding(nn.Module):
-    def __init__(
-        self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
-    ):
+    def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None):
         super().__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
liger_kernel/transformers/functional.py CHANGED
@@ -1,9 +1,7 @@
 from typing import Optional
 
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
-from liger_kernel.ops.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyFunction,
-)
+from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
 from liger_kernel.ops.group_norm import LigerGroupNormFunction
@@ -14,6 +12,7 @@ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
 from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 from liger_kernel.ops.rope import LigerRopeFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction
+from liger_kernel.ops.tvd import LigerTVDLossFunction
 
 
 # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
@@ -34,6 +33,7 @@ def liger_cross_entropy(
     loss, z_loss = LigerCrossEntropyFunction.apply(
         input,
         target,
+        weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
@@ -51,23 +51,30 @@ def liger_fused_linear_cross_entropy(
     weight,
     target,
     bias=None,
+    ce_weight=None,
     ignore_index: int = -100,
     lse_square_scale: float = 0.0,
     label_smoothing: float = 0.0,
     reduction: str = "mean",
     softcap: Optional[float] = None,
+    return_z_loss: bool = False,
 ):
-    return LigerFusedLinearCrossEntropyFunction.apply(
+    loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
         input,
         weight,
         target,
         bias,
+        ce_weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
         reduction,
         softcap,
+        return_z_loss,
     )
+    if not return_z_loss:
+        return loss
+    return loss, z_loss
 
 
 def liger_fused_linear_jsd(
@@ -151,6 +158,22 @@ def liger_kl_div(
     )
 
 
+def liger_tvd(
+    input,
+    target,
+    shift_labels=None,
+    reduction: str = "mean",
+    ignore_index: int = -100,
+):
+    return LigerTVDLossFunction.apply(
+        input,
+        target,
+        shift_labels,
+        reduction,
+        ignore_index,
+    )
+
+
 def liger_layer_norm(X, W, B, eps):
     return LigerLayerNormFunction.apply(X, W, B, eps)
 
@@ -159,9 +182,7 @@ def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
     return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
 
 
-def liger_rms_norm(
-    X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
-):
+def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
     return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
 
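The updated wrappers can be exercised end to end. A sketch of liger_fused_linear_cross_entropy with the new ce_weight/return_z_loss options and of the new liger_tvd, assuming a CUDA device; the (BT, H) hidden-state and (V, H) projection-weight layout is an assumption based on the fused-linear pattern, not spelled out in this diff:

import torch

from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy
from liger_kernel.transformers.functional import liger_tvd

BT, H, V = 16, 1024, 32000  # illustrative sizes
hidden = torch.randn(BT, H, device="cuda", requires_grad=True)
lm_head_weight = torch.randn(V, H, device="cuda", requires_grad=True)
targets = torch.randint(0, V, (BT,), device="cuda")

# return_z_loss=True surfaces the auxiliary z-loss next to the main loss;
# ce_weight is the per-class weighting threaded through to the CE kernel.
loss, z_loss = liger_fused_linear_cross_entropy(
    hidden,
    lm_head_weight,
    targets,
    ce_weight=torch.ones(V, device="cuda"),
    lse_square_scale=1e-4,
    return_z_loss=True,
)

# Total variation distance between two sets of (BT, V) distributions.
p = torch.randn(BT, V, device="cuda").softmax(dim=-1)
q = torch.randn(BT, V, device="cuda").softmax(dim=-1)
tvd = liger_tvd(p, q, reduction="batchmean")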