liger-kernel 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/dpo_loss.py +1 -1
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/jsd_loss.py +2 -2
- liger_kernel/ops/dyt.py +113 -179
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/sparsemax.py +167 -0
- liger_kernel/transformers/__init__.py +5 -0
- liger_kernel/transformers/dyt.py +5 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +8 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/model/gemma.py +0 -8
- liger_kernel/transformers/model/gemma2.py +0 -6
- liger_kernel/transformers/model/gemma3.py +0 -8
- liger_kernel/transformers/model/glm4.py +0 -6
- liger_kernel/transformers/model/llama.py +56 -11
- liger_kernel/transformers/model/llava.py +0 -8
- liger_kernel/transformers/model/mistral.py +0 -6
- liger_kernel/transformers/model/mixtral.py +0 -8
- liger_kernel/transformers/model/mllama.py +0 -7
- liger_kernel/transformers/model/olmo2.py +0 -6
- liger_kernel/transformers/model/paligemma.py +0 -8
- liger_kernel/transformers/model/phi3.py +0 -8
- liger_kernel/transformers/model/qwen2.py +0 -8
- liger_kernel/transformers/model/qwen2_5_vl.py +0 -6
- liger_kernel/transformers/model/qwen2_vl.py +0 -6
- liger_kernel/transformers/model/qwen3.py +0 -6
- liger_kernel/transformers/model/qwen3_moe.py +128 -0
- liger_kernel/transformers/monkey_patch.py +122 -13
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +21 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/utils.py +11 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +34 -20
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +39 -33
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
from liger_kernel.ops.utils import calculate_settings
|
|
6
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@triton.jit
def _sparsemax_forward_kernel(
    x_ptr,
    x_stride_row,
    sorted_x_ptr,
    sorted_x_stride_row,
    o_ptr,
    o_stride_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,
):
    """
    Sparsemax forward pass, one row per program (row index = tl.program_id(0)).

    The row is processed as a single block of BLOCK_SIZE lanes with no loop, so this
    assumes BLOCK_SIZE >= n_cols; lanes with index >= n_cols are masked out.

    Args:
        x_ptr, x_stride_row: input rows in their original order.
        sorted_x_ptr, sorted_x_stride_row: the same rows sorted in descending order
            (values are converted to float32 on load).
        o_ptr, o_stride_row: output rows; results are cast to o_ptr's element dtype.
        n_cols: number of valid columns per row.
        BLOCK_SIZE: number of lanes handled by one program.
        num_warps: accepted as a constexpr; not referenced in the body.

    Output: y_i = max(x_i - tau, 0), where tau = (sum of the k largest values - 1) / k
    and k is the sparsemax support size.
    """
    pid_row = tl.program_id(0)
    ptr_x_data_row = x_ptr + pid_row * x_stride_row
    ptr_sorted_x_data_row = sorted_x_ptr + pid_row * sorted_x_stride_row
    ptr_output_row = o_ptr + pid_row * o_stride_row

    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols

    # Sorted (descending) row; padding lanes read as -inf.
    z_sorted_block = tl.load(
        ptr_sorted_x_data_row + offs,
        mask=mask,
        other=-float("inf"),
        cache_modifier=".ca",
    ).to(tl.float32)

    # Prefix sums of the sorted values; padding lanes contribute 0 instead of -inf.
    z_valid = tl.where(mask, z_sorted_block, 0.0)
    cssv = tl.cumsum(z_valid, 0)

    # 1-based rank of each sorted element; padding lanes get rank 1 (they are
    # excluded from the support below through `mask`).
    r = (offs + 1).to(tl.float32)
    safe_r = tl.where(mask, r, 1.0)

    # Candidate threshold for each prefix length k: (sum of top-k - 1) / k.
    t_vec = (cssv - 1.0) / safe_r

    # Element k is in the support when z_(k) > (cumsum_k - 1) / k.
    support = (z_sorted_block > t_vec) & mask

    # Support size, clamped to >= 1 so the division below is always defined.
    k_int = tl.sum(support.to(tl.int32), 0)
    k_clamped_int = tl.maximum(k_int, 1)
    k = k_clamped_int.to(tl.float32)

    # Sum of the supported (largest) values.
    s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0)

    # Sparsemax threshold.
    tau = (s - 1.0) / k

    # Unsorted input row, so the output keeps the original column order.
    x_block = tl.load(
        ptr_x_data_row + offs,
        mask=mask,
        other=0.0,
        cache_modifier=".ca",
    ).to(tl.float32)

    y = tl.maximum(x_block - tau, 0.0)

    tl.store(
        ptr_output_row + offs,
        y.to(ptr_output_row.dtype.element_ty),
        mask=mask,
        cache_modifier=".cs",
    )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@triton.jit
def _sparsemax_backward_kernel(
    o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
):
    """
    Sparsemax backward pass, one row per program (row index = tl.program_id(0)).

    With S = {i : o_i > 0} (the support of the forward output o):
        grad_in_i = grad_out_i - mean_{j in S}(grad_out_j)   for i in S
        grad_in_i = 0                                         otherwise

    Args:
        o_ptr: forward output rows.
        go_ptr: upstream gradient rows.
        gi_ptr: destination for the input gradient; written in gi_ptr's element dtype.
        stride: a single row stride shared by all three tensors.
        n_cols: number of valid columns per row.
        BLOCK_SIZE: chunk width used to walk each row.
        num_warps: accepted as a constexpr; not referenced in the body.
    """
    row = tl.program_id(0)
    o_row = o_ptr + row * stride
    go_row = go_ptr + row * stride
    gi_row = gi_ptr + row * stride

    offs = tl.arange(0, BLOCK_SIZE)

    supp_cnt = tl.zeros((), tl.float32)
    go_sum = tl.zeros((), tl.float32)

    # Pass 1: accumulate the support size and the sum of grad_out over the support.
    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        offs_iter = i * BLOCK_SIZE + offs
        mask_iter = offs_iter < n_cols
        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
        supp = o_val > 0.0
        go_sum += tl.sum(tl.where(supp, go_val, 0.0))
        supp_cnt += tl.sum(supp.to(tl.float32))

    # Pass 2: subtract the support mean from supported entries and zero the rest.
    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        offs_iter = i * BLOCK_SIZE + offs
        mask_iter = offs_iter < n_cols
        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
        supp = o_val > 0.0
        # The mean is rounded to gi's element dtype before being widened back to float32
        # for the subtraction (presumably to match a reference computed in that dtype —
        # TODO confirm). tl.maximum(..., 1e-6) guards the division when supp_cnt is 0.
        gi_val = tl.where(
            supp,
            go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32),
            0.0,
        )
        tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class LigerSparsemaxFunction(torch.autograd.Function):
    """Autograd function computing sparsemax along an arbitrary dimension.

    The reduction dimension is moved to the last axis, the tensor is flattened to
    2-D ``(rows, cols)``, and one Triton program is launched per row.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x: torch.Tensor, dim: int):
        """Return sparsemax(x) along ``dim``; the flattened output is saved for backward."""
        # Normalise a negative dim so backward can reuse it as-is.
        dim = dim + x.dim() if dim < 0 else dim
        ctx.dim = dim

        # Bring `dim` to the last axis and collapse every other axis into rows.
        moved = x.transpose(dim, -1).contiguous()
        num_cols = moved.size(-1)
        num_rows = moved.numel() // num_cols
        rows_in = moved.view(num_rows, num_cols)

        block_size, warps = calculate_settings(num_cols)
        rows_out = torch.empty_like(rows_in)

        # The kernel needs every row pre-sorted in descending order (float32 copy).
        rows_sorted, _ = torch.sort(rows_in.float(), dim=-1, descending=True)

        _sparsemax_forward_kernel[(num_rows,)](
            rows_in,
            rows_in.stride(0),
            rows_sorted,
            rows_sorted.stride(0),
            rows_out,
            rows_out.stride(0),
            num_cols,
            BLOCK_SIZE=block_size,
            num_warps=warps,
        )

        ctx.save_for_backward(rows_out)
        return rows_out.view_as(moved).transpose(dim, -1)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out: torch.Tensor):
        """Return ``(grad_input, None)``; ``dim`` receives no gradient."""
        (rows_out,) = ctx.saved_tensors
        dim = ctx.dim

        # Same layout transformation as in forward, applied to the upstream gradient.
        moved_grad = grad_out.transpose(dim, -1).contiguous()
        num_cols = moved_grad.size(-1)
        num_rows = moved_grad.numel() // num_cols
        grad_rows = moved_grad.view(num_rows, num_cols)

        block_size, warps = calculate_settings(num_cols)
        grad_in_rows = torch.empty_like(grad_rows)

        _sparsemax_backward_kernel[(num_rows,)](
            rows_out,
            grad_rows,
            grad_in_rows,
            rows_out.stride(0),
            num_cols,
            BLOCK_SIZE=block_size,
            num_warps=warps,
        )

        return grad_in_rows.view_as(moved_grad).transpose(dim, -1), None
|
|
@@ -14,6 +14,7 @@ from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
|
|
|
14
14
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
|
|
15
15
|
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
|
|
16
16
|
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
|
|
17
|
+
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401
|
|
17
18
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
|
|
18
19
|
from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
|
|
19
20
|
|
|
@@ -40,6 +41,7 @@ if TYPE_CHECKING:
|
|
|
40
41
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl # noqa: F401
|
|
41
42
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
|
|
42
43
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
|
|
44
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
|
|
43
45
|
|
|
44
46
|
|
|
45
47
|
# Check if 'transformers' is installed
|
|
@@ -95,6 +97,7 @@ def __getattr__(name: str):
|
|
|
95
97
|
"apply_liger_kernel_to_qwen2_5_vl",
|
|
96
98
|
"apply_liger_kernel_to_qwen2_vl",
|
|
97
99
|
"apply_liger_kernel_to_qwen3",
|
|
100
|
+
"apply_liger_kernel_to_qwen3_moe",
|
|
98
101
|
}
|
|
99
102
|
|
|
100
103
|
if name in monkey_patch_symbols:
|
|
@@ -118,6 +121,7 @@ __all__ = [
|
|
|
118
121
|
"liger_rotary_pos_emb",
|
|
119
122
|
"LigerBlockSparseTop2MLP",
|
|
120
123
|
"LigerPhi3SwiGLUMLP",
|
|
124
|
+
"LigerQwen3MoeSwiGLUMLP",
|
|
121
125
|
"LigerSwiGLUMLP",
|
|
122
126
|
"LigerTVDLoss",
|
|
123
127
|
]
|
|
@@ -147,5 +151,6 @@ if _TRANSFORMERS_AVAILABLE:
|
|
|
147
151
|
"apply_liger_kernel_to_qwen2_5_vl",
|
|
148
152
|
"apply_liger_kernel_to_qwen2_vl",
|
|
149
153
|
"apply_liger_kernel_to_qwen3",
|
|
154
|
+
"apply_liger_kernel_to_qwen3_moe",
|
|
150
155
|
]
|
|
151
156
|
)
|
liger_kernel/transformers/dyt.py
CHANGED
|
@@ -5,16 +5,18 @@ from liger_kernel.ops.dyt import LigerDyTFunction
|
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
class LigerDyT(nn.Module):
    """DyT layer (presumably "Dynamic Tanh") backed by the fused ``LigerDyTFunction``.

    Parameters created:
        alpha: learnable scalar, initialised to ``init_alpha``.
        gamma: learnable vector of length ``hidden_size``, initialised to ones.
        beta: learnable vector of length ``hidden_size``, initialised to zeros, only when
            ``beta=True``; otherwise the attribute is ``None`` and is passed to the
            kernel as ``None``.

    Args:
        hidden_size: length of the ``gamma`` (and ``beta``) vectors.
        beta: whether to create the learnable ``beta`` parameter.
        init_alpha: initial value of ``alpha``.
    """

    def __init__(self, hidden_size, beta=True, init_alpha=0.5):
        super().__init__()
        self.hidden_size = hidden_size
        self.init_alpha = init_alpha
        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = None
        if beta:
            self.beta = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)

    def extra_repr(self):
        # Report whether beta is enabled (mirroring the constructor flag) instead of
        # interpolating the Parameter itself, which would dump the whole tensor into repr().
        return f"{self.hidden_size}, init_alpha={self.init_alpha}, beta={self.beta is not None}"
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
from torch.distributed.fsdp import FullyShardedDataParallel
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class _FSDPForwardRedirection:
    """
    Modified based on
    https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
    Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
    post-forward can be properly executed around the method call.
    This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
    the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
    GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
    will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
    the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
    its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
    the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
    """

    def __call__(
        self,
        wrapper_module: FullyShardedDataParallel,
        method: Callable,
        *args: Any,
        **kwargs: Any,
    ):
        """Reroute a call to `method` through `wrapper_module`'s `forward`.

        The wrapped module's `forward` is temporarily replaced by a shim. When FSDP's forward
        reaches the wrapped module, the shim restores the real `forward` and calls
        `method(*args, **kwargs)` instead, so `method` runs between FSDP's pre- and post-forward.

        Args:
            wrapper_module: The FSDP module whose `_fsdp_wrapped_module` is temporarily patched.
            method: The callable to execute inside the FSDP forward pass.
            *args: Positional arguments for `method`. They are passed to `wrapper_module(...)`
                and reach `method` through the patched `forward`.
            **kwargs: Keyword arguments for `method`, forwarded the same way.

        Returns:
            The value returned by `wrapper_module(*args, **kwargs)`, i.e. what `method` produced
            after passing back out through FSDP's forward.

        Raises:
            AssertionError: If `wrapper_module` is not a `FullyShardedDataParallel` instance.
        """
        assert isinstance(wrapper_module, FullyShardedDataParallel)
        original_module = wrapper_module._fsdp_wrapped_module
        original_forward = original_module.forward

        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling `method`,
            # because it may itself want to call the real `forward`.
            original_module.forward = original_forward  # type: ignore[method-assign]
            return method(*_args, **_kwargs)

        # Patch the wrapped module's forward so FSDP's forward redirects the arguments to `method`.
        original_module.forward = wrapped_forward  # type: ignore[method-assign]
        try:
            return wrapper_module(*args, **kwargs)
        finally:
            # If FSDP raised (or otherwise never reached the wrapped module), the shim is still
            # installed; restore the real forward so later calls are not silently hijacked.
            if original_module.forward is wrapped_forward:
                original_module.forward = original_forward  # type: ignore[method-assign]
|
|
@@ -12,6 +12,7 @@ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
|
|
|
12
12
|
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
|
|
13
13
|
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
|
|
14
14
|
from liger_kernel.ops.rope import LigerRopeFunction
|
|
15
|
+
from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
|
|
15
16
|
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
|
|
16
17
|
from liger_kernel.ops.tvd import LigerTVDLossFunction
|
|
17
18
|
|
|
@@ -159,6 +160,13 @@ def liger_kl_div(
|
|
|
159
160
|
)
|
|
160
161
|
|
|
161
162
|
|
|
163
|
+
def liger_sparsemax(
    input,
    dim: int = -1,
):
    """Apply sparsemax to ``input`` along ``dim`` (default: the last dimension).

    Functional wrapper around ``LigerSparsemaxFunction``.
    """
    sparse_probs = LigerSparsemaxFunction.apply(input, dim)
    return sparse_probs
|
|
168
|
+
|
|
169
|
+
|
|
162
170
|
def liger_tvd(
|
|
163
171
|
input,
|
|
164
172
|
target,
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from liger_kernel.ops.grpo_loss import GrpoLossFunction
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def triton_grpo_loss(
    logits,
    old_logp,
    ref_logp,
    completion_ids,
    advantages,
    completion_mask=None,
    temperature=0.9,
    beta=0.04,
    eps_low=0.2,
    eps_high=0.4,
    inplace=True,
):
    """Compute the GRPO loss with the fused Triton implementation.

    Thin wrapper around ``GrpoLossFunction.apply``: every argument is forwarded
    positionally, in signature order. ``logits``, ``completion_ids`` and
    ``advantages`` must not be ``None`` (checked with an ``assert``).

    The meaning of the scalar hyper-parameters and of the returned value is defined
    by ``GrpoLossFunction``; the usage demo in this module unpacks the result as
    ``(per_token_loss, per_token_kl, is_clipped)``.
    """
    required = (logits, completion_ids, advantages)
    assert all(tensor is not None for tensor in required), "must provide logits、completion_ids and advantages"

    forwarded = (
        logits,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low,
        eps_high,
        inplace,
    )
    return GrpoLossFunction.apply(*forwarded)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# This is a demo of how to use grpo_loss in GRPOTrainer. The TRL version must be 0.16.
|
|
37
|
+
"""
|
|
38
|
+
import torch
|
|
39
|
+
import trl
|
|
40
|
+
assert trl.__version__.startswith("0.16"), "please pip install trl==0.16"
|
|
41
|
+
from trl.extras.profiling import profiling_decorator
|
|
42
|
+
|
|
43
|
+
@profiling_decorator
|
|
44
|
+
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
|
|
45
|
+
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
|
46
|
+
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
|
|
47
|
+
return fused_selective_log_softmax(logits, input_ids, self.temperature, mask=attention_mask)
|
|
48
|
+
|
|
49
|
+
@profiling_decorator
|
|
50
|
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
|
51
|
+
if return_outputs:
|
|
52
|
+
raise ValueError("The GRPOTrainer does not support returning outputs")
|
|
53
|
+
# Compute the per-token log probabilities for the model
|
|
54
|
+
|
|
55
|
+
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
|
56
|
+
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
|
|
57
|
+
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
|
58
|
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
|
59
|
+
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
|
60
|
+
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
|
|
61
|
+
|
|
62
|
+
ref_per_token_logps = inputs["ref_per_token_logps"]
|
|
63
|
+
advantages = inputs["advantages"]
|
|
64
|
+
old_per_token_logps = inputs["old_per_token_logps"]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(logits,
|
|
68
|
+
old_per_token_logps,
|
|
69
|
+
ref_per_token_logps,
|
|
70
|
+
completion_ids,
|
|
71
|
+
advantages,
|
|
72
|
+
completion_mask,
|
|
73
|
+
self.temperature,
|
|
74
|
+
self.beta,
|
|
75
|
+
self.epsilon_low,
|
|
76
|
+
self.epsilon_high,)
|
|
77
|
+
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
|
|
78
|
+
|
|
79
|
+
# Log the metrics
|
|
80
|
+
mode = "eval" if self.control.should_evaluate else "train"
|
|
81
|
+
|
|
82
|
+
if self.beta != 0.0:
|
|
83
|
+
mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
|
|
84
|
+
self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
|
85
|
+
|
|
86
|
+
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
|
|
87
|
+
self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
|
|
88
|
+
return loss
|
|
89
|
+
|
|
90
|
+
trl.GRPOTrainer._get_per_token_logps = _get_per_token_logps
|
|
91
|
+
trl.GRPOTrainer.compute_loss = compute_loss
|
|
92
|
+
trigger = None
|
|
93
|
+
"""
|
|
94
|
+
|
|
95
|
+
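# Add this line as the first line of grpo.py in open-r1. Note that `trigger` (and the GRPOTrainer patches) are defined only inside the demo string above, so that code must be made live in this module for the import to succeed.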
|
|
96
|
+
"""
|
|
97
|
+
from liger_kernel.transformers.grpo_loss import trigger
|
|
98
|
+
"""
|
|
@@ -8,18 +8,12 @@ import torch
|
|
|
8
8
|
from torch.nn import CrossEntropyLoss
|
|
9
9
|
from transformers.cache_utils import Cache
|
|
10
10
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
11
|
-
from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
|
|
12
|
-
from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
|
|
13
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
14
|
-
from transformers.utils import replace_return_docstrings
|
|
15
11
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
16
12
|
|
|
17
13
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
18
14
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
19
15
|
|
|
20
16
|
|
|
21
|
-
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
|
22
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
23
17
|
def lce_forward_deprecated(
|
|
24
18
|
self,
|
|
25
19
|
input_ids: torch.LongTensor = None,
|
|
@@ -129,8 +123,6 @@ def lce_forward_deprecated(
|
|
|
129
123
|
|
|
130
124
|
|
|
131
125
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
132
|
-
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
|
133
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
134
126
|
def lce_forward(
|
|
135
127
|
self,
|
|
136
128
|
input_ids: torch.LongTensor = None,
|
|
@@ -9,10 +9,6 @@ import torch
|
|
|
9
9
|
from torch.nn import CrossEntropyLoss
|
|
10
10
|
from transformers.cache_utils import HybridCache
|
|
11
11
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
12
|
-
from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
|
|
13
|
-
from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
|
|
14
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
15
|
-
from transformers.utils import replace_return_docstrings
|
|
16
12
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
17
13
|
|
|
18
14
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
@@ -136,8 +132,6 @@ def lce_forward_deprecated(
|
|
|
136
132
|
|
|
137
133
|
|
|
138
134
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
139
|
-
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
|
|
140
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
141
135
|
def lce_forward(
|
|
142
136
|
self,
|
|
143
137
|
input_ids: torch.LongTensor = None,
|
|
@@ -9,13 +9,9 @@ import torch.nn as nn
|
|
|
9
9
|
from transformers.cache_utils import Cache
|
|
10
10
|
from transformers.cache_utils import HybridCache
|
|
11
11
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
12
|
-
from transformers.models.gemma3.modeling_gemma3 import _CONFIG_FOR_DOC
|
|
13
|
-
from transformers.models.gemma3.modeling_gemma3 import GEMMA3_INPUTS_DOCSTRING
|
|
14
12
|
from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
|
|
15
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
16
13
|
from transformers.utils import is_torchdynamo_compiling
|
|
17
14
|
from transformers.utils import logging
|
|
18
|
-
from transformers.utils import replace_return_docstrings
|
|
19
15
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
20
16
|
|
|
21
17
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
@@ -25,8 +21,6 @@ logger = logging.get_logger(__name__)
|
|
|
25
21
|
|
|
26
22
|
|
|
27
23
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
28
|
-
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
|
29
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
30
24
|
def causal_forward(
|
|
31
25
|
self,
|
|
32
26
|
input_ids: torch.LongTensor = None,
|
|
@@ -141,8 +135,6 @@ def causal_forward(
|
|
|
141
135
|
|
|
142
136
|
|
|
143
137
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
144
|
-
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
|
145
|
-
@replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
146
138
|
def multimodal_forward(
|
|
147
139
|
self,
|
|
148
140
|
input_ids: torch.LongTensor = None,
|
|
@@ -6,18 +6,12 @@ from typing import Union
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
8
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
9
|
-
from transformers.models.glm4.modeling_glm4 import _CONFIG_FOR_DOC
|
|
10
|
-
from transformers.models.glm4.modeling_glm4 import GLM4_INPUTS_DOCSTRING
|
|
11
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
12
|
-
from transformers.utils import replace_return_docstrings
|
|
13
9
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
14
10
|
|
|
15
11
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
16
12
|
|
|
17
13
|
|
|
18
14
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
19
|
-
@add_start_docstrings_to_model_forward(GLM4_INPUTS_DOCSTRING)
|
|
20
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
21
15
|
def lce_forward(
|
|
22
16
|
self,
|
|
23
17
|
input_ids: torch.LongTensor = None,
|
|
@@ -7,23 +7,23 @@ from typing import Union
|
|
|
7
7
|
import torch
|
|
8
8
|
import torch.nn.functional as F
|
|
9
9
|
|
|
10
|
+
from torch.distributed.fsdp import FullyShardedDataParallel
|
|
10
11
|
from torch.nn import CrossEntropyLoss
|
|
11
12
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
12
|
-
from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
|
|
13
|
-
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
|
|
14
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
15
|
-
from transformers.utils import replace_return_docstrings
|
|
16
13
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
17
14
|
|
|
15
|
+
from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
|
|
18
16
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
19
17
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
18
|
+
from liger_kernel.utils import PEFT_AVAILABLE
|
|
20
19
|
|
|
21
20
|
if TYPE_CHECKING:
|
|
22
21
|
from transformers.cache_utils import Cache
|
|
23
22
|
|
|
23
|
+
if PEFT_AVAILABLE:
|
|
24
|
+
from peft.utils.other import ModulesToSaveWrapper
|
|
25
|
+
|
|
24
26
|
|
|
25
|
-
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
|
26
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
27
27
|
def lce_forward_deprecated(
|
|
28
28
|
self,
|
|
29
29
|
input_ids: torch.LongTensor = None,
|
|
@@ -137,8 +137,6 @@ def lce_forward_deprecated(
|
|
|
137
137
|
|
|
138
138
|
|
|
139
139
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
140
|
-
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
|
141
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
142
140
|
def lce_forward(
|
|
143
141
|
self,
|
|
144
142
|
input_ids: torch.LongTensor = None,
|
|
@@ -221,12 +219,12 @@ def lce_forward(
|
|
|
221
219
|
loss = None
|
|
222
220
|
# if in training mode, don't materialize logits
|
|
223
221
|
if self.training and (labels is not None or shift_labels is not None):
|
|
224
|
-
loss =
|
|
222
|
+
loss = lce_maybe_trainable_lm_head(
|
|
223
|
+
self,
|
|
225
224
|
hidden_states=kept_hidden_states,
|
|
226
|
-
|
|
225
|
+
hidden_size=self.config.hidden_size,
|
|
227
226
|
labels=labels,
|
|
228
227
|
shift_labels=shift_labels,
|
|
229
|
-
hidden_size=self.config.hidden_size,
|
|
230
228
|
**loss_kwargs,
|
|
231
229
|
)
|
|
232
230
|
|
|
@@ -251,3 +249,50 @@ def lce_forward(
|
|
|
251
249
|
hidden_states=outputs.hidden_states,
|
|
252
250
|
attentions=outputs.attentions,
|
|
253
251
|
)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    """Compute the fused linear + cross-entropy loss against this model's ``lm_head``.

    Handles two wrappers around ``self.lm_head``:
      * PEFT ``ModulesToSaveWrapper``: the trainable copy in ``modules_to_save.default`` is used.
      * FSDP ``FullyShardedDataParallel``: the loss is computed inside the FSDP forward pass
        (via ``_FSDPForwardRedirection``) so the full weight is gathered while the kernel runs.

    Args:
        self: the causal-LM model owning ``lm_head``.
        hidden_states: hidden states projected by ``lm_head`` inside the fused loss.
        hidden_size: model hidden size, forwarded to the loss.
        labels: target token ids, forwarded as-is (may be ``None`` if ``shift_labels`` is given).
        shift_labels: pre-shifted targets, forwarded as-is.
        **loss_kwargs: extra keyword arguments forwarded to ``LigerForCausalLMLoss``.

    Returns:
        The loss computed by ``LigerForCausalLMLoss``.
    """
    lm_head = self.lm_head

    # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
    # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
    # from the unwrapped module.
    # See https://huggingface.co/docs/peft/package_reference/lora for reference.
    if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
        lm_head = lm_head.modules_to_save.default

    # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
    # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass
    # so the module entire parameters are summoned and kept in memory during the kernel execution.
    if isinstance(lm_head, FullyShardedDataParallel):
        return _FSDPForwardRedirection()(
            lm_head,
            _liger_for_causal_lm_loss,
            lm_head.module,
            hidden_states,
            hidden_size,
            labels,
            shift_labels,
            **loss_kwargs,
        )

    # FSDP is not used, so read the lm_head weights and call the kernel directly.
    # Use the (possibly PEFT-unwrapped) local `lm_head`, not `self.lm_head`: for a
    # ModulesToSaveWrapper the trainable weights live in `modules_to_save.default`.
    return _liger_for_causal_lm_loss(
        lm_head=lm_head,
        hidden_states=hidden_states,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    """Run ``LigerForCausalLMLoss`` with ``lm_head.weight`` as the output projection.

    A standalone function so it can be called directly or passed as the ``method``
    of ``_FSDPForwardRedirection`` (as ``lce_maybe_trainable_lm_head`` does).
    """
    projection_weight = lm_head.weight
    return LigerForCausalLMLoss(
        hidden_states=hidden_states,
        lm_head_weight=projection_weight,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
|
|
@@ -5,19 +5,13 @@ from typing import Union
|
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
|
-
from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC
|
|
9
|
-
from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING
|
|
10
8
|
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
|
|
11
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
12
9
|
from transformers.utils import is_torchdynamo_compiling
|
|
13
|
-
from transformers.utils import replace_return_docstrings
|
|
14
10
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
15
11
|
|
|
16
12
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
17
13
|
|
|
18
14
|
|
|
19
|
-
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
|
|
20
|
-
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
21
15
|
def lce_forward_deprecated(
|
|
22
16
|
self,
|
|
23
17
|
input_ids: torch.LongTensor = None,
|
|
@@ -210,9 +204,7 @@ def lce_forward_deprecated(
|
|
|
210
204
|
)
|
|
211
205
|
|
|
212
206
|
|
|
213
|
-
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
|
|
214
207
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
215
|
-
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
216
208
|
def lce_forward(
|
|
217
209
|
self,
|
|
218
210
|
input_ids: torch.LongTensor = None,
|
|
@@ -7,18 +7,12 @@ import torch
|
|
|
7
7
|
|
|
8
8
|
from transformers.cache_utils import Cache
|
|
9
9
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
10
|
-
from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
|
|
11
|
-
from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
|
|
12
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
13
|
-
from transformers.utils import replace_return_docstrings
|
|
14
10
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
15
11
|
|
|
16
12
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
17
13
|
|
|
18
14
|
|
|
19
15
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
20
|
-
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
|
21
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
22
16
|
def lce_forward(
|
|
23
17
|
self,
|
|
24
18
|
input_ids: torch.LongTensor = None,
|