liger-kernel 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +3 -0
- liger_kernel/chunked_loss/cpo_loss.py +18 -8
- liger_kernel/chunked_loss/dpo_loss.py +20 -10
- liger_kernel/chunked_loss/functional.py +4 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
- liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
- liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
- liger_kernel/chunked_loss/grpo_loss.py +160 -0
- liger_kernel/chunked_loss/jsd_loss.py +154 -0
- liger_kernel/chunked_loss/kto_loss.py +172 -0
- liger_kernel/chunked_loss/orpo_loss.py +8 -9
- liger_kernel/chunked_loss/simpo_loss.py +22 -8
- liger_kernel/env_report.py +5 -12
- liger_kernel/ops/cross_entropy.py +102 -51
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
- liger_kernel/ops/fused_linear_jsd.py +14 -32
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +5 -9
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +23 -12
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +14 -32
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +3 -2
- liger_kernel/transformers/__init__.py +19 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +7 -9
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +28 -7
- liger_kernel/transformers/fused_linear_cross_entropy.py +15 -10
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +9 -15
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +24 -54
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/olmo2.py +124 -0
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +36 -32
- liger_kernel/transformers/monkey_patch.py +214 -144
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +49 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +53 -26
- liger_kernel-0.5.4.dist-info/RECORD +74 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +1 -1
- liger_kernel-0.5.2.dist-info/RECORD +0 -65
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rope.py
CHANGED
|
@@ -15,6 +15,7 @@ def _triton_rope(
|
|
|
15
15
|
sin_row_stride,
|
|
16
16
|
sl,
|
|
17
17
|
bs: tl.constexpr,
|
|
18
|
+
cos_bs: tl.constexpr,
|
|
18
19
|
n_qh: tl.constexpr,
|
|
19
20
|
n_kh: tl.constexpr,
|
|
20
21
|
hd: tl.constexpr,
|
|
@@ -29,7 +30,7 @@ def _triton_rope(
|
|
|
29
30
|
# k size: (bsz, seq_len, num_kv_heads, head_dim)
|
|
30
31
|
# k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
|
|
31
32
|
|
|
32
|
-
# cos size: (1, seq_len, head_dim)
|
|
33
|
+
# cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
|
33
34
|
# stride: (seq_len * head_dim, head_dim, 1)
|
|
34
35
|
pid = tl.program_id(0)
|
|
35
36
|
|
|
@@ -48,9 +49,19 @@ def _triton_rope(
|
|
|
48
49
|
# and pid % sl to get the sequence index.
|
|
49
50
|
# 2. We only need the left half of cos and sin matrix because the right half is just
|
|
50
51
|
# a clone of the left half.
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
52
|
+
batch_idx = pid // sl
|
|
53
|
+
cos_row_idx = pid % sl
|
|
54
|
+
cos = cos + tl.where(
|
|
55
|
+
cos_bs == 1,
|
|
56
|
+
cos_row_idx * cos_row_stride,
|
|
57
|
+
batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
|
|
58
|
+
)
|
|
59
|
+
sin = sin + tl.where(
|
|
60
|
+
cos_bs == 1,
|
|
61
|
+
cos_row_idx * sin_row_stride,
|
|
62
|
+
batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
|
|
63
|
+
)
|
|
64
|
+
|
|
54
65
|
cos_offsets = tl.arange(0, pad_hd // 2)
|
|
55
66
|
cos_mask = cos_offsets < hd // 2
|
|
56
67
|
cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
|
|
@@ -61,36 +72,20 @@ def _triton_rope(
|
|
|
61
72
|
# program instance (i.e. for the current token) separately
|
|
62
73
|
# ####################################################################
|
|
63
74
|
# left half of the head
|
|
64
|
-
first_half_q_offsets = (
|
|
65
|
-
|
|
66
|
-
)
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
)
|
|
70
|
-
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
|
|
71
|
-
tl.arange(0, pad_hd // 2)[None, :] < hd // 2
|
|
72
|
-
)
|
|
73
|
-
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
|
|
74
|
-
tl.arange(0, pad_hd // 2)[None, :] < hd // 2
|
|
75
|
-
)
|
|
76
|
-
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
|
|
77
|
-
sin_row.dtype
|
|
78
|
-
)
|
|
79
|
-
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
|
|
80
|
-
sin_row.dtype
|
|
81
|
-
)
|
|
75
|
+
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
|
|
76
|
+
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
|
|
77
|
+
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
|
|
78
|
+
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
|
|
79
|
+
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
|
|
80
|
+
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
|
|
82
81
|
|
|
83
82
|
# right half of the head
|
|
84
83
|
second_half_q_offsets = first_half_q_offsets + (hd // 2)
|
|
85
84
|
second_half_k_offsets = first_half_k_offsets + (hd // 2)
|
|
86
85
|
second_q_mask = first_q_mask
|
|
87
86
|
second_k_mask = first_k_mask
|
|
88
|
-
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
|
|
89
|
-
|
|
90
|
-
)
|
|
91
|
-
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
|
|
92
|
-
sin_row.dtype
|
|
93
|
-
)
|
|
87
|
+
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
|
|
88
|
+
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
|
|
94
89
|
|
|
95
90
|
if not BACKWARD_PASS:
|
|
96
91
|
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
|
|
@@ -118,7 +113,6 @@ def _triton_rope(
|
|
|
118
113
|
|
|
119
114
|
|
|
120
115
|
def rope_forward(q, k, cos, sin):
|
|
121
|
-
|
|
122
116
|
# transpose it back to the physical shape because Triton looks at the physical storage
|
|
123
117
|
# note: q and k are incontiguous before the transformation and will become contiguous after transpose
|
|
124
118
|
q = q.transpose(1, 2)
|
|
@@ -138,6 +132,7 @@ def rope_forward(q, k, cos, sin):
|
|
|
138
132
|
k = k.contiguous()
|
|
139
133
|
cos = cos.contiguous()
|
|
140
134
|
sin = sin.contiguous()
|
|
135
|
+
cos_batch_size = cos.shape[0]
|
|
141
136
|
|
|
142
137
|
_triton_rope[(n_row,)](
|
|
143
138
|
q,
|
|
@@ -150,6 +145,7 @@ def rope_forward(q, k, cos, sin):
|
|
|
150
145
|
sin.stride(-2),
|
|
151
146
|
seq_len,
|
|
152
147
|
batch_size,
|
|
148
|
+
cos_batch_size,
|
|
153
149
|
n_q_head,
|
|
154
150
|
n_kv_head,
|
|
155
151
|
head_dim,
|
|
@@ -167,6 +163,7 @@ def rope_backward(dq, dk, cos, sin):
|
|
|
167
163
|
dk = dk.transpose(1, 2)
|
|
168
164
|
|
|
169
165
|
batch_size, seq_len, n_q_head, head_dim = dq.shape
|
|
166
|
+
cos_batch_size = cos.shape[0]
|
|
170
167
|
n_kv_head = dk.shape[2]
|
|
171
168
|
pad_hd = triton.next_power_of_2(head_dim)
|
|
172
169
|
pad_n_q_head = triton.next_power_of_2(n_q_head)
|
|
@@ -191,6 +188,7 @@ def rope_backward(dq, dk, cos, sin):
|
|
|
191
188
|
sin.stride(-2),
|
|
192
189
|
seq_len,
|
|
193
190
|
batch_size,
|
|
191
|
+
cos_batch_size,
|
|
194
192
|
n_q_head,
|
|
195
193
|
n_kv_head,
|
|
196
194
|
head_dim,
|
|
@@ -221,8 +219,8 @@ class LigerRopeFunction(torch.autograd.Function):
|
|
|
221
219
|
"""
|
|
222
220
|
q size: (bsz, n_q_head, seq_len, head_dim)
|
|
223
221
|
k size: (bsz, n_kv_head, seq_len, head_dim)
|
|
224
|
-
cos size: (1, seq_len, head_dim)
|
|
225
|
-
sin size: (1, seq_len, head_dim)
|
|
222
|
+
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
|
223
|
+
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
|
226
224
|
"""
|
|
227
225
|
q, k, cos, sin = rope_forward(q, k, cos, sin)
|
|
228
226
|
ctx.save_for_backward(cos, sin)
|
|
@@ -232,8 +230,8 @@ class LigerRopeFunction(torch.autograd.Function):
|
|
|
232
230
|
"""
|
|
233
231
|
dq size: (bsz, n_q_head, seq_len, head_dim)
|
|
234
232
|
dk size: (bsz, n_kv_head, seq_len, head_dim)
|
|
235
|
-
cos size: (1, seq_len, head_dim)
|
|
236
|
-
sin size: (1, seq_len, head_dim)
|
|
233
|
+
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
|
234
|
+
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
|
237
235
|
"""
|
|
238
236
|
|
|
239
237
|
cos, sin = ctx.saved_tensors
|
liger_kernel/ops/swiglu.py
CHANGED
|
@@ -2,7 +2,8 @@ import torch
|
|
|
2
2
|
import triton
|
|
3
3
|
import triton.language as tl
|
|
4
4
|
|
|
5
|
-
from liger_kernel.ops.utils import calculate_settings
|
|
5
|
+
from liger_kernel.ops.utils import calculate_settings
|
|
6
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
@triton.jit
|
|
@@ -11,9 +12,7 @@ def silu(x):
|
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
@triton.jit
|
|
14
|
-
def _swiglu_forward_kernel(
|
|
15
|
-
a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
|
|
16
|
-
):
|
|
15
|
+
def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
|
|
17
16
|
program_id = tl.program_id(0).to(tl.int64)
|
|
18
17
|
|
|
19
18
|
# locate start index
|
|
@@ -32,9 +31,7 @@ def _swiglu_forward_kernel(
|
|
|
32
31
|
|
|
33
32
|
|
|
34
33
|
@triton.jit
|
|
35
|
-
def _swiglu_backward_kernel(
|
|
36
|
-
dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
|
|
37
|
-
):
|
|
34
|
+
def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
|
|
38
35
|
program_id = tl.program_id(0).to(tl.int64)
|
|
39
36
|
|
|
40
37
|
# locate start index
|
|
@@ -84,7 +81,6 @@ def swiglu_forward(a, b):
|
|
|
84
81
|
|
|
85
82
|
|
|
86
83
|
def swiglu_backward(a, b, dc):
|
|
87
|
-
|
|
88
84
|
ori_shape = dc.shape
|
|
89
85
|
n_cols = ori_shape[-1]
|
|
90
86
|
dc = dc.view(-1, n_cols)
|
liger_kernel/ops/tvd.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import triton
|
|
6
|
+
import triton.language as tl
|
|
7
|
+
|
|
8
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
9
|
+
|
|
10
|
+
MAX_FUSED_SIZE = 65536 // 4
|
|
11
|
+
|
|
12
|
+
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
|
|
13
|
+
|
|
14
|
+
_REDUCTION_MODE_NONE = tl.constexpr(0)
|
|
15
|
+
_REDUCTION_MODE_SUM = tl.constexpr(1)
|
|
16
|
+
_REDUCTION_MODE_MEAN = tl.constexpr(2)
|
|
17
|
+
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
|
|
18
|
+
|
|
19
|
+
_str_to_reduction_mode = {
|
|
20
|
+
"none": _REDUCTION_MODE_NONE.value,
|
|
21
|
+
"sum": _REDUCTION_MODE_SUM.value,
|
|
22
|
+
"mean": _REDUCTION_MODE_MEAN.value,
|
|
23
|
+
"batchmean": _REDUCTION_MODE_BATCHMEAN.value,
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_num_warps(BLOCK_SIZE):
|
|
28
|
+
num_warps = 4
|
|
29
|
+
if BLOCK_SIZE >= 32768:
|
|
30
|
+
num_warps = 32
|
|
31
|
+
elif BLOCK_SIZE >= 8192:
|
|
32
|
+
num_warps = 16
|
|
33
|
+
elif BLOCK_SIZE >= 2048:
|
|
34
|
+
num_warps = 8
|
|
35
|
+
|
|
36
|
+
return num_warps
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@triton.jit
|
|
40
|
+
def _tv_distance_kernel(
|
|
41
|
+
p_ptr,
|
|
42
|
+
p_stride,
|
|
43
|
+
q_ptr,
|
|
44
|
+
q_stride,
|
|
45
|
+
loss_ptr,
|
|
46
|
+
loss_stride,
|
|
47
|
+
grads_ptr,
|
|
48
|
+
grads_stride,
|
|
49
|
+
label_ptr,
|
|
50
|
+
ignore_index: tl.constexpr,
|
|
51
|
+
n_cols,
|
|
52
|
+
BLOCK_SIZE: tl.constexpr,
|
|
53
|
+
HAS_LABEL: tl.constexpr,
|
|
54
|
+
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
|
|
55
|
+
):
|
|
56
|
+
pid = tl.program_id(0).to(tl.int64)
|
|
57
|
+
p_ptr += pid * p_stride
|
|
58
|
+
q_ptr += pid * q_stride
|
|
59
|
+
loss_ptr += pid * loss_stride
|
|
60
|
+
grads_ptr += pid * grads_stride
|
|
61
|
+
label_ptr += pid
|
|
62
|
+
|
|
63
|
+
base_offsets = tl.arange(0, BLOCK_SIZE)
|
|
64
|
+
|
|
65
|
+
if HAS_LABEL:
|
|
66
|
+
label = tl.load(label_ptr)
|
|
67
|
+
if label == ignore_index:
|
|
68
|
+
for i in range(0, n_cols, BLOCK_SIZE):
|
|
69
|
+
offsets = i + base_offsets
|
|
70
|
+
mask = offsets < n_cols
|
|
71
|
+
tl.store(grads_ptr + offsets, 0.0, mask=mask)
|
|
72
|
+
if reduction == _REDUCTION_MODE_NONE:
|
|
73
|
+
tl.store(loss_ptr + offsets, 0.0, mask=mask)
|
|
74
|
+
return
|
|
75
|
+
|
|
76
|
+
loss_sum = 0.0
|
|
77
|
+
for i in range(0, n_cols, BLOCK_SIZE):
|
|
78
|
+
offsets = i + base_offsets
|
|
79
|
+
mask = offsets < n_cols
|
|
80
|
+
|
|
81
|
+
p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
|
|
82
|
+
q = tl.load(q_ptr + offsets, mask=mask, other=0.0)
|
|
83
|
+
|
|
84
|
+
# TVD(P || Q) = 0.5 * |P - Q|
|
|
85
|
+
tv_loss = 0.5 * tl.abs(p - q)
|
|
86
|
+
|
|
87
|
+
grad_res = tl.where(p > q, 0.5, -0.5)
|
|
88
|
+
|
|
89
|
+
tl.store(grads_ptr + offsets, grad_res, mask=mask)
|
|
90
|
+
|
|
91
|
+
if reduction == _REDUCTION_MODE_NONE:
|
|
92
|
+
tl.store(loss_ptr + offsets, tv_loss, mask=mask)
|
|
93
|
+
else:
|
|
94
|
+
loss_sum += tl.sum(tv_loss, axis=0)
|
|
95
|
+
|
|
96
|
+
if reduction != _REDUCTION_MODE_NONE:
|
|
97
|
+
tl.store(loss_ptr, loss_sum)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
|
|
101
|
+
BT, V = p.shape
|
|
102
|
+
|
|
103
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
104
|
+
num_warps = get_num_warps(BLOCK_SIZE)
|
|
105
|
+
|
|
106
|
+
grid = (BT,)
|
|
107
|
+
|
|
108
|
+
reduction = _str_to_reduction_mode[reduction]
|
|
109
|
+
|
|
110
|
+
out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
|
|
111
|
+
output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
|
|
112
|
+
grads = torch.empty_like(p)
|
|
113
|
+
|
|
114
|
+
n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
|
|
115
|
+
|
|
116
|
+
_tv_distance_kernel[grid](
|
|
117
|
+
p,
|
|
118
|
+
p.stride(0),
|
|
119
|
+
q,
|
|
120
|
+
q.stride(0),
|
|
121
|
+
output_tensor,
|
|
122
|
+
output_tensor.stride(0),
|
|
123
|
+
grads,
|
|
124
|
+
grads.stride(0),
|
|
125
|
+
shift_labels if has_label else torch.empty(1, device=p.device),
|
|
126
|
+
ignore_index,
|
|
127
|
+
V,
|
|
128
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
|
129
|
+
HAS_LABEL=has_label,
|
|
130
|
+
num_warps=num_warps,
|
|
131
|
+
reduction=reduction,
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
if reduction == _REDUCTION_MODE_BATCHMEAN.value:
|
|
135
|
+
return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
|
|
136
|
+
elif reduction == _REDUCTION_MODE_SUM.value:
|
|
137
|
+
return output_tensor.sum(dim=0), grads
|
|
138
|
+
elif reduction == _REDUCTION_MODE_MEAN.value:
|
|
139
|
+
return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
|
|
140
|
+
else:
|
|
141
|
+
return output_tensor, grads
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def tvd_backward_triton(grad_output, grads):
|
|
145
|
+
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
|
|
146
|
+
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
|
147
|
+
return grads
|
|
148
|
+
|
|
149
|
+
return grads * grad_output
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class LigerTVDLossFunction(torch.autograd.Function):
|
|
153
|
+
"""
|
|
154
|
+
Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
|
|
155
|
+
"""
|
|
156
|
+
|
|
157
|
+
@staticmethod
|
|
158
|
+
@ensure_contiguous
|
|
159
|
+
def forward(
|
|
160
|
+
ctx,
|
|
161
|
+
p: torch.Tensor,
|
|
162
|
+
q: torch.Tensor,
|
|
163
|
+
shift_labels: Optional[torch.Tensor] = None,
|
|
164
|
+
reduction: REDUCTION_LITERAL = "batchmean",
|
|
165
|
+
ignore_index: int = -100,
|
|
166
|
+
) -> torch.Tensor:
|
|
167
|
+
"""A forward pass for the Total Variation Distance Loss.
|
|
168
|
+
|
|
169
|
+
Args:
|
|
170
|
+
ctx: Torch autograd context
|
|
171
|
+
p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
|
|
172
|
+
q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
|
|
173
|
+
shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
|
|
174
|
+
reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
|
|
175
|
+
ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
|
|
176
|
+
|
|
177
|
+
Returns:
|
|
178
|
+
torch.Tensor: The computed Total Variation Distance Loss.
|
|
179
|
+
"""
|
|
180
|
+
has_label = False
|
|
181
|
+
if shift_labels is not None:
|
|
182
|
+
assert shift_labels.shape == (p.shape[0],), (
|
|
183
|
+
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
|
|
184
|
+
)
|
|
185
|
+
shift_labels = shift_labels.contiguous()
|
|
186
|
+
has_label = True
|
|
187
|
+
|
|
188
|
+
loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
|
|
189
|
+
ctx.save_for_backward(grads)
|
|
190
|
+
return loss
|
|
191
|
+
|
|
192
|
+
@staticmethod
|
|
193
|
+
@ensure_contiguous
|
|
194
|
+
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
|
|
195
|
+
"""A backward pass for the Total Variation Distance Loss.
|
|
196
|
+
|
|
197
|
+
Args:
|
|
198
|
+
ctx: Torch autograd context
|
|
199
|
+
grad_output (torch.Tensor): The gradient of the loss with respect to the output.
|
|
200
|
+
|
|
201
|
+
Returns:
|
|
202
|
+
tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
|
|
203
|
+
"""
|
|
204
|
+
(grads,) = ctx.saved_tensors
|
|
205
|
+
grads = tvd_backward_triton(grad_output, grads)
|
|
206
|
+
|
|
207
|
+
return grads, None, None, None, None
|
liger_kernel/ops/utils.py
CHANGED
|
@@ -13,11 +13,13 @@ Modifications made by Yanning Chen, 2024.
|
|
|
13
13
|
import functools
|
|
14
14
|
import importlib
|
|
15
15
|
import operator
|
|
16
|
+
|
|
16
17
|
from typing import Callable
|
|
17
18
|
|
|
18
19
|
import torch
|
|
19
20
|
import triton
|
|
20
21
|
import triton.language as tl
|
|
22
|
+
|
|
21
23
|
from packaging.version import Version
|
|
22
24
|
|
|
23
25
|
from liger_kernel.utils import infer_device
|
|
@@ -47,8 +49,7 @@ def calculate_settings(n):
|
|
|
47
49
|
BLOCK_SIZE = triton.next_power_of_2(n)
|
|
48
50
|
if BLOCK_SIZE > MAX_FUSED_SIZE:
|
|
49
51
|
raise RuntimeError(
|
|
50
|
-
f"Cannot launch Triton kernel since n = {n} exceeds "
|
|
51
|
-
f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
|
|
52
|
+
f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
|
|
52
53
|
)
|
|
53
54
|
|
|
54
55
|
num_warps = 4
|
|
@@ -1,31 +1,26 @@
|
|
|
1
|
-
from liger_kernel.transformers.auto_model import
|
|
2
|
-
AutoLigerKernelForCausalLM,
|
|
3
|
-
)
|
|
1
|
+
from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
|
|
4
2
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
|
|
5
|
-
from liger_kernel.transformers.fused_linear_cross_entropy import
|
|
6
|
-
LigerFusedLinearCrossEntropyLoss,
|
|
7
|
-
)
|
|
3
|
+
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401
|
|
8
4
|
from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
|
|
9
5
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
|
|
10
6
|
from liger_kernel.transformers.jsd import LigerJSD # noqa: F401
|
|
11
7
|
from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
|
|
12
|
-
from liger_kernel.transformers.monkey_patch import
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
8
|
+
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
|
|
9
|
+
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
|
|
10
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
|
|
11
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
|
|
12
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
|
|
13
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
|
|
14
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
|
|
15
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
|
|
16
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
|
|
17
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
|
|
18
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
|
|
19
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
|
|
20
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
|
|
25
21
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
|
|
26
22
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
|
|
27
|
-
from liger_kernel.transformers.swiglu import
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
)
|
|
23
|
+
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
|
|
24
|
+
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
|
|
25
|
+
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
|
|
26
|
+
from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
|
|
@@ -1,11 +1,10 @@
|
|
|
1
1
|
import inspect
|
|
2
2
|
|
|
3
|
-
from transformers import AutoConfig
|
|
3
|
+
from transformers import AutoConfig
|
|
4
|
+
from transformers import AutoModelForCausalLM
|
|
4
5
|
|
|
5
|
-
from liger_kernel.transformers.monkey_patch import
|
|
6
|
-
|
|
7
|
-
_apply_liger_kernel,
|
|
8
|
-
)
|
|
6
|
+
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
|
7
|
+
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
|
|
9
8
|
|
|
10
9
|
|
|
11
10
|
def _get_model_config(model_dir, **model_init_kwargs):
|
|
@@ -34,12 +33,6 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
|
|
|
34
33
|
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
|
|
35
34
|
apply_fn_signature = inspect.signature(apply_fn)
|
|
36
35
|
|
|
37
|
-
applicable_kwargs = {
|
|
38
|
-
key: value
|
|
39
|
-
for key, value in kwargs.items()
|
|
40
|
-
if key not in apply_fn_signature.parameters
|
|
41
|
-
}
|
|
36
|
+
applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
|
|
42
37
|
|
|
43
|
-
return super().from_pretrained(
|
|
44
|
-
pretrained_model_name_or_path, *model_args, **applicable_kwargs
|
|
45
|
-
)
|
|
38
|
+
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
|
|
@@ -8,6 +8,7 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
|
|
|
8
8
|
class LigerCrossEntropyLoss(torch.nn.Module):
|
|
9
9
|
def __init__(
|
|
10
10
|
self,
|
|
11
|
+
weight: Optional[torch.FloatTensor] = None,
|
|
11
12
|
ignore_index: int = -100,
|
|
12
13
|
lse_square_scale: float = 0.0,
|
|
13
14
|
label_smoothing: float = 0.0,
|
|
@@ -16,20 +17,16 @@ class LigerCrossEntropyLoss(torch.nn.Module):
|
|
|
16
17
|
return_z_loss: bool = False,
|
|
17
18
|
):
|
|
18
19
|
super().__init__()
|
|
19
|
-
assert (label_smoothing >= 0) and (
|
|
20
|
-
label_smoothing
|
|
21
|
-
)
|
|
22
|
-
assert (label_smoothing >= 0) and (
|
|
23
|
-
label_smoothing <= 1
|
|
24
|
-
), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
|
|
20
|
+
assert (label_smoothing >= 0) and (label_smoothing <= 1), (
|
|
21
|
+
f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
|
|
22
|
+
)
|
|
25
23
|
assert reduction in {
|
|
26
24
|
"mean",
|
|
27
25
|
"sum",
|
|
28
26
|
"none",
|
|
29
27
|
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
|
|
30
|
-
assert
|
|
31
|
-
|
|
32
|
-
), f"softcap must greater than 0.0 or None. Got: {softcap}"
|
|
28
|
+
assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
|
|
29
|
+
self.weight = weight
|
|
33
30
|
self.ignore_index = ignore_index
|
|
34
31
|
self.lse_square_scale = lse_square_scale
|
|
35
32
|
self.label_smoothing = label_smoothing
|
|
@@ -41,6 +38,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
|
|
|
41
38
|
loss, z_loss = LigerCrossEntropyFunction.apply(
|
|
42
39
|
_input,
|
|
43
40
|
target,
|
|
41
|
+
self.weight,
|
|
44
42
|
self.ignore_index,
|
|
45
43
|
self.lse_square_scale,
|
|
46
44
|
self.label_smoothing,
|
|
@@ -7,9 +7,7 @@ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
|
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class LigerEmbedding(nn.Module):
|
|
10
|
-
def __init__(
|
|
11
|
-
self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
|
|
12
|
-
):
|
|
10
|
+
def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None):
|
|
13
11
|
super().__init__()
|
|
14
12
|
self.num_embeddings = num_embeddings
|
|
15
13
|
self.embedding_dim = embedding_dim
|
|
@@ -1,9 +1,7 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
|
|
4
|
-
from liger_kernel.ops.fused_linear_cross_entropy import
|
|
5
|
-
LigerFusedLinearCrossEntropyFunction,
|
|
6
|
-
)
|
|
4
|
+
from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
|
|
7
5
|
from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
|
|
8
6
|
from liger_kernel.ops.geglu import LigerGELUMulFunction
|
|
9
7
|
from liger_kernel.ops.group_norm import LigerGroupNormFunction
|
|
@@ -14,6 +12,7 @@ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
|
|
|
14
12
|
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
|
|
15
13
|
from liger_kernel.ops.rope import LigerRopeFunction
|
|
16
14
|
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
|
|
15
|
+
from liger_kernel.ops.tvd import LigerTVDLossFunction
|
|
17
16
|
|
|
18
17
|
|
|
19
18
|
# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
|
|
@@ -34,6 +33,7 @@ def liger_cross_entropy(
|
|
|
34
33
|
loss, z_loss = LigerCrossEntropyFunction.apply(
|
|
35
34
|
input,
|
|
36
35
|
target,
|
|
36
|
+
weight,
|
|
37
37
|
ignore_index,
|
|
38
38
|
lse_square_scale,
|
|
39
39
|
label_smoothing,
|
|
@@ -51,23 +51,30 @@ def liger_fused_linear_cross_entropy(
|
|
|
51
51
|
weight,
|
|
52
52
|
target,
|
|
53
53
|
bias=None,
|
|
54
|
+
ce_weight=None,
|
|
54
55
|
ignore_index: int = -100,
|
|
55
56
|
lse_square_scale: float = 0.0,
|
|
56
57
|
label_smoothing: float = 0.0,
|
|
57
58
|
reduction: str = "mean",
|
|
58
59
|
softcap: Optional[float] = None,
|
|
60
|
+
return_z_loss: bool = False,
|
|
59
61
|
):
|
|
60
|
-
|
|
62
|
+
loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
|
|
61
63
|
input,
|
|
62
64
|
weight,
|
|
63
65
|
target,
|
|
64
66
|
bias,
|
|
67
|
+
ce_weight,
|
|
65
68
|
ignore_index,
|
|
66
69
|
lse_square_scale,
|
|
67
70
|
label_smoothing,
|
|
68
71
|
reduction,
|
|
69
72
|
softcap,
|
|
73
|
+
return_z_loss,
|
|
70
74
|
)
|
|
75
|
+
if not return_z_loss:
|
|
76
|
+
return loss
|
|
77
|
+
return loss, z_loss
|
|
71
78
|
|
|
72
79
|
|
|
73
80
|
def liger_fused_linear_jsd(
|
|
@@ -151,6 +158,22 @@ def liger_kl_div(
|
|
|
151
158
|
)
|
|
152
159
|
|
|
153
160
|
|
|
161
|
+
def liger_tvd(
|
|
162
|
+
input,
|
|
163
|
+
target,
|
|
164
|
+
shift_labels=None,
|
|
165
|
+
reduction: str = "mean",
|
|
166
|
+
ignore_index: int = -100,
|
|
167
|
+
):
|
|
168
|
+
return LigerTVDLossFunction.apply(
|
|
169
|
+
input,
|
|
170
|
+
target,
|
|
171
|
+
shift_labels,
|
|
172
|
+
reduction,
|
|
173
|
+
ignore_index,
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
|
|
154
177
|
def liger_layer_norm(X, W, B, eps):
|
|
155
178
|
return LigerLayerNormFunction.apply(X, W, B, eps)
|
|
156
179
|
|
|
@@ -159,9 +182,7 @@ def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
|
|
159
182
|
return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
|
|
160
183
|
|
|
161
184
|
|
|
162
|
-
def liger_rms_norm(
|
|
163
|
-
X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
|
|
164
|
-
):
|
|
185
|
+
def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
|
|
165
186
|
return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
|
|
166
187
|
|
|
167
188
|
|