liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +307 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +63 -0
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +383 -114
- liger_kernel/ops/dyt.py +160 -0
- liger_kernel/ops/experimental/embedding.py +141 -0
- liger_kernel/ops/experimental/mm_int8int2.py +349 -0
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
- liger_kernel/ops/fused_linear_jsd.py +228 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +66 -64
- liger_kernel/ops/group_norm.py +306 -0
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +201 -0
- liger_kernel/ops/kl_div.py +262 -0
- liger_kernel/ops/layer_norm.py +320 -0
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +484 -88
- liger_kernel/ops/rope.py +122 -117
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +68 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +82 -3
- liger_kernel/transformers/__init__.py +218 -6
- liger_kernel/transformers/auto_model.py +38 -0
- liger_kernel/transformers/cross_entropy.py +52 -7
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +26 -0
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +301 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
- liger_kernel/transformers/fused_linear_jsd.py +95 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +6 -7
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +70 -0
- liger_kernel/transformers/kl_div.py +12 -0
- liger_kernel/transformers/layer_norm.py +24 -0
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +261 -0
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +332 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +221 -41
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +145 -0
- liger_kernel/transformers/model/mixtral.py +293 -0
- liger_kernel/transformers/model/mllama.py +269 -0
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +433 -0
- liger_kernel/transformers/model/phi3.py +120 -0
- liger_kernel/transformers/model/qwen2.py +259 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +159 -0
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2816 -21
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +75 -5
- liger_kernel/transformers/rope.py +47 -3
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +62 -6
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -5
- liger_kernel/utils.py +96 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/swiglu.py
CHANGED
|
@@ -2,7 +2,8 @@ import torch
|
|
|
2
2
|
import triton
|
|
3
3
|
import triton.language as tl
|
|
4
4
|
|
|
5
|
-
from liger_kernel.ops.utils import calculate_settings
|
|
5
|
+
from liger_kernel.ops.utils import calculate_settings
|
|
6
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
@triton.jit
|
|
@@ -11,44 +12,40 @@ def silu(x):
|
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
@triton.jit
def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """Row-wise SwiGLU forward: ``c = silu(a) * b``, one program per row.

    ``a``, ``b`` and ``c`` share the same row ``stride``. ``BLOCK_SIZE`` may exceed
    ``n_cols``; the extra lanes are masked out of every load and store.
    """
    # Widen the row id to int64 before multiplying by the stride.
    row = tl.program_id(0).to(tl.int64)
    row_start = row * stride

    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols

    # sigmoid requires type float32, so the gate input is upcast before silu.
    gate = tl.load(a_ptr + row_start + cols, mask=in_bounds, other=0).to(tl.float32)
    up = tl.load(b_ptr + row_start + cols, mask=in_bounds, other=0)

    # Cast the activation back to b's dtype so the product is formed in b's precision.
    out = silu(gate).cast(up.dtype) * up
    tl.store(c_ptr + row_start + cols, out, mask=in_bounds)
|
|
32
31
|
|
|
33
32
|
|
|
34
33
|
@triton.jit
|
|
35
|
-
def _swiglu_backward_kernel(
|
|
36
|
-
|
|
37
|
-
):
|
|
38
|
-
program_id = tl.program_id(0)
|
|
34
|
+
def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
|
|
35
|
+
program_id = tl.program_id(0).to(tl.int64)
|
|
39
36
|
|
|
40
37
|
# locate start index
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
38
|
+
dc_ptr += program_id * stride
|
|
39
|
+
a_ptr += program_id * stride
|
|
40
|
+
b_ptr += program_id * stride
|
|
44
41
|
|
|
45
42
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
46
43
|
mask = col_offsets < n_cols
|
|
47
44
|
|
|
48
|
-
dc_row = tl.load(
|
|
45
|
+
dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
|
|
49
46
|
# sigmoid requires type float32
|
|
50
|
-
a_row = tl.load(
|
|
51
|
-
b_row = tl.load(
|
|
47
|
+
a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
|
48
|
+
b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
|
|
52
49
|
|
|
53
50
|
# recomputation to save memory
|
|
54
51
|
sig_a = tl.sigmoid(a_row)
|
|
@@ -56,58 +53,64 @@ def _swiglu_backward_kernel(
|
|
|
56
53
|
db_row = dc_row * silu_a
|
|
57
54
|
da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row
|
|
58
55
|
|
|
59
|
-
tl.store(
|
|
60
|
-
tl.store(
|
|
56
|
+
tl.store(a_ptr + col_offsets, da_row, mask=mask)
|
|
57
|
+
tl.store(b_ptr + col_offsets, db_row, mask=mask)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def swiglu_forward(a, b):
    """Launch the SwiGLU forward kernel computing ``silu(a) * b``.

    Args:
        a: gate input; its last dimension is the row width ``n_cols``. Must be viewable
            as ``(-1, n_cols)`` (callers pass contiguous tensors).
        b: second input, reshaped the same way as ``a`` — assumes it has ``a``'s shape;
            TODO confirm for callers other than ``LigerSiLUMulFunction``.

    Returns:
        tuple: ``(a_2d, b_2d, c)`` where ``a_2d``/``b_2d`` are the inputs viewed as
        ``(-1, n_cols)`` and ``c`` is the result reshaped to ``a``'s original shape.
    """
    orig_shape = a.shape
    width = orig_shape[-1]

    a_2d = a.view(-1, width)
    b_2d = b.view(-1, width)
    out = torch.empty_like(a_2d)
    num_rows = a_2d.shape[0]

    block_size, warps = calculate_settings(width)

    # One program per row of the flattened inputs.
    _swiglu_forward_kernel[(num_rows,)](
        a_2d,
        b_2d,
        out,
        out.stride(-2),
        n_cols=width,
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    return a_2d, b_2d, out.view(*orig_shape)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def swiglu_backward(a, b, dc):
    """Launch the SwiGLU backward kernel.

    The kernel recomputes ``silu(a)`` and stores the gradients back into the input
    buffers: ``da`` is written into ``a`` and ``db`` into ``b``, overwriting their
    original values.

    Args:
        a: saved gate input, addressed with ``dc``'s row stride — assumes it is the
            flattened ``(-1, n_cols)`` view returned by ``swiglu_forward``; TODO confirm
            for other callers.
        b: saved second input, same layout as ``a``.
        dc: upstream gradient of the output; its last dimension is ``n_cols``.

    Returns:
        tuple: ``(da, db)``, i.e. ``a`` and ``b`` (now holding gradients) viewed with
        ``dc``'s shape.
    """
    grad_shape = dc.shape
    width = grad_shape[-1]
    dc_2d = dc.view(-1, width)
    num_rows = dc_2d.shape[0]

    block_size, warps = calculate_settings(width)

    # One program per row; a and b are indexed with dc's row stride.
    _swiglu_backward_kernel[(num_rows,)](
        dc_2d,
        a,
        b,
        dc_2d.stride(-2),
        n_cols=width,
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    return a.view(*grad_shape), b.view(*grad_shape)
|
|
61
101
|
|
|
62
102
|
|
|
63
103
|
class LigerSiLUMulFunction(torch.autograd.Function):
    """Autograd function for the SwiGLU gating product ``silu(a) * b``.

    ``forward`` saves the flattened inputs. ``backward`` recomputes the activation in
    the kernel and writes the gradients into those saved buffers instead of
    allocating new ones.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        a_flat, b_flat, out = swiglu_forward(a, b)
        ctx.save_for_backward(a_flat, b_flat)
        return out

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        saved_a, saved_b = ctx.saved_tensors
        # Returns (grad_a, grad_b), one per forward input.
        return swiglu_backward(saved_a, saved_b, dc)
|
|
liger_kernel/ops/tiled_mlp.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
from typing import Callable
|
|
4
|
+
from typing import List
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LigerTiledMLPFunction(torch.autograd.Function):
    """
    Based on DeepSpeed's TiledMLP:
    https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838

    Run an MLP shard-by-shard along the sequence dimension to greatly reduce the memory
    needed for very long sequence lengths.

    The forward pass runs under ``torch.no_grad()`` and saves only the input ``x``.
    ``backward`` re-runs ``fn`` on each shard with grad enabled. The MLP forward therefore
    executes twice per iteration, or three times with activation checkpointing.

    Args:
        fn: callable invoked as ``fn(mlp_module, x_shard)`` (e.g. ``mlp.forward``)
        mlp_module: the MLP ``nn.Module`` passed to ``fn``
        x: the MLP input (hidden_states), ``[bs, seqlen, hidden_size]`` or ``[seqlen, hidden_size]``
        shards: number of shards requested from ``torch.chunk``
        compute_params: accepted for API compatibility with DeepSpeed's TiledMLP; not used here

    Returns:
        the MLP output of every shard, concatenated along dim -2
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        fn: Callable,
        mlp_module: torch.nn.Module,
        x: torch.Tensor,
        shards: int,
        compute_params: Optional[List[torch.nn.Parameter]] = None,
    ) -> torch.Tensor:
        # Only x is saved as a tensor; backward rebuilds the graph by calling fn again.
        ctx.fn = fn
        ctx.mlp_module = mlp_module
        ctx.shards = shards
        ctx.save_for_backward(x)

        # Split along dim -2 (seqlen) so both [bs, seqlen, hidden_size] and
        # [seqlen, hidden_size] (moe experts) inputs are handled.
        pieces = list(torch.chunk(x, chunks=shards, dim=-2))
        with torch.no_grad():
            outputs = [fn(mlp_module, piece) for piece in pieces]
        return torch.cat(outputs, dim=-2)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, *grads) -> tuple:
        (saved_x,) = ctx.saved_tensors
        fn = ctx.fn
        mlp_module = ctx.mlp_module
        shards = ctx.shards

        needs_grad = saved_x.requires_grad
        x = saved_x.detach()
        # detach() unsets requires_grad, so restore it on the fresh tensor.
        x.requires_grad_(needs_grad)

        orig_shape = x.shape
        hidden_size = orig_shape[-1]

        # Flatten bs and seqlen into one row axis so each shard is a plain row range
        # (avoids stride issues when narrowing into seqlen with bs > 1).
        x_rows = x.view(-1, hidden_size)
        grad_rows = grads[0].view(-1, hidden_size)
        x_grad = torch.zeros_like(x_rows)

        x_shards = list(torch.chunk(x_rows, chunks=shards, dim=0))
        # Every shard has the first shard's length except possibly the last (shorter) one.
        rows_per_shard = x_shards[0].shape[0]

        for idx, x_shard in enumerate(x_shards):
            x_shard.requires_grad_(needs_grad)

            start = idx * rows_per_shard
            length = x_shard.shape[0]

            # Pre-seed .grad with this shard's slice of x_grad so the recomputed backward
            # deposits the input gradient into x_grad.
            x_shard.grad = x_grad.narrow(0, start, length).view_as(x_shard)
            grad_shard = grad_rows.narrow(0, start, length).view_as(x_shard)

            # Recompute this shard's forward with autograd on, then backprop through it
            # (this also reaches any parameters used by fn that require grad).
            with torch.enable_grad():
                out = fn(mlp_module, x_shard)
            torch.autograd.backward(out, grad_shard)

        # Restore the caller's layout; only x receives a gradient.
        return (None, None, x_grad.view(orig_shape), None, None)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def apply_tiled_mlp(
    fn: Callable,
    mlp_module: torch.nn.Module,
    x: torch.Tensor,
    num_shards: Optional[int] = None,
    compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
    """
    Run ``fn(mlp_module, ...)`` over ``x`` with tiled (sharded) computation for memory efficiency.

    Args:
        fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
        mlp_module: the MLP nn.Module object
        x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
        num_shards: number of shards to use. If None, defaults to ceil(seqlen / hidden_size).
            Values below 1 are clamped to 1.
        compute_params: list of parameters for DeepSpeed ZeRO optimization (forwarded as-is)

    Returns:
        output tensor with the same shape as input
    """
    if num_shards is None:
        # The last two axes are (seqlen, hidden_size) for both supported layouts.
        seqlen, hidden_size = x.shape[-2], x.shape[-1]
        num_shards = math.ceil(seqlen / hidden_size)

    # Never use fewer than one shard (e.g. seqlen == 0 or a non-positive user value).
    shards = max(1, num_shards)

    return LigerTiledMLPFunction.apply(fn, mlp_module, x, shards, compute_params)
|
liger_kernel/ops/tvd.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import triton
|
|
6
|
+
import triton.language as tl
|
|
7
|
+
|
|
8
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
9
|
+
|
|
10
|
+
# Upper bound on BLOCK_SIZE (elements per tile); wider rows are processed by the
# kernel in BLOCK_SIZE-wide chunks.
MAX_FUSED_SIZE = 65536 // 4

# Reduction names accepted by the public API.
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

# Reduction modes as Triton compile-time constants, compared against the kernel's
# ``reduction: tl.constexpr`` argument.
_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)

# Map each reduction name to the plain integer code passed at kernel launch.
_str_to_reduction_mode = {
    name: mode.value
    for name, mode in (
        ("none", _REDUCTION_MODE_NONE),
        ("sum", _REDUCTION_MODE_SUM),
        ("mean", _REDUCTION_MODE_MEAN),
        ("batchmean", _REDUCTION_MODE_BATCHMEAN),
    )
}
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_num_warps(BLOCK_SIZE):
    """Choose the Triton ``num_warps`` launch setting for a given ``BLOCK_SIZE``.

    Uses the same thresholds as ``calculate_settings`` in ``liger_kernel.ops.utils``.
    Like that function, it uses 16 instead of 32 warps for the largest blocks on
    ROCm/HIP builds of PyTorch.

    Args:
        BLOCK_SIZE: number of elements processed per program iteration.

    Returns:
        int: 4, 8, 16 or 32.
    """
    if BLOCK_SIZE >= 32768:
        # Mirror calculate_settings: HIP builds are limited to 16 warps here.
        return 16 if torch.version.hip is not None else 32
    if BLOCK_SIZE >= 8192:
        return 16
    if BLOCK_SIZE >= 2048:
        return 8
    return 4
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@triton.jit
def _tv_distance_kernel(
    p_ptr,
    p_stride,
    q_ptr,
    q_stride,
    loss_ptr,
    loss_stride,
    grads_ptr,
    grads_stride,
    label_ptr,
    ignore_index: tl.constexpr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
    """Per-row Total Variation Distance loss and its gradient w.r.t. ``p``.

    One program handles one row and walks its ``n_cols`` columns in BLOCK_SIZE chunks.
    - The element-wise loss is ``0.5 * |p - q|``.
    - The gradient is ``0.5 * sign(p - q)``, with 0 at ties to match ``torch.abs``.
    - With ``reduction == NONE`` the element-wise loss is stored. Otherwise the row's
      loss sum is written to ``loss_ptr[row]``.
    - Rows whose label equals ``ignore_index`` get zero gradient and leave the loss untouched.
    """
    pid = tl.program_id(0).to(tl.int64)
    p_ptr += pid * p_stride
    q_ptr += pid * q_stride
    loss_ptr += pid * loss_stride
    grads_ptr += pid * grads_stride
    label_ptr += pid

    base_offsets = tl.arange(0, BLOCK_SIZE)

    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            # Ignored row: zero gradient (and zero element-wise loss for "none").
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + base_offsets
                mask = offsets < n_cols
                tl.store(grads_ptr + offsets, 0.0, mask=mask)
                if reduction == _REDUCTION_MODE_NONE:
                    tl.store(loss_ptr + offsets, 0.0, mask=mask)
            return

    loss_sum = 0.0
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + base_offsets
        mask = offsets < n_cols

        # Upcast so the loss (written to a float32 buffer) is computed in float32.
        p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        q = tl.load(q_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

        # TVD(P || Q) = 0.5 * |P - Q|
        tv_loss = 0.5 * tl.abs(p - q)

        # d/dp 0.5*|p - q| = 0.5*sign(p - q); sign(0) = 0 as in torch.abs's gradient.
        grad_res = tl.where(p > q, 0.5, tl.where(p < q, -0.5, 0.0))

        tl.store(grads_ptr + offsets, grad_res, mask=mask)

        if reduction == _REDUCTION_MODE_NONE:
            tl.store(loss_ptr + offsets, tv_loss, mask=mask)
        else:
            loss_sum += tl.sum(tv_loss, axis=0)

    if reduction != _REDUCTION_MODE_NONE:
        tl.store(loss_ptr, loss_sum)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
    """Run the TVD kernel and apply the requested reduction.

    Args:
        p: ``(BT, V)`` tensor, first distribution (gradients are computed for it).
        q: ``(BT, V)`` tensor, second distribution.
        shift_labels: ``(BT,)`` labels, used only to skip rows equal to ``ignore_index``;
            not read when ``has_label`` is False.
        reduction: one of ``"none"``, ``"sum"``, ``"mean"``, ``"batchmean"``.
        ignore_index: label value whose rows contribute zero loss and zero gradient.
        has_label: whether ``shift_labels`` is a real tensor.

    Returns:
        tuple: ``(loss, grads)``.
        - ``grads`` is the gradient of ``loss`` w.r.t. ``p``, already scaled for the
          reduction.
        - ``loss`` is ``(BT, V)`` for ``"none"`` and a scalar otherwise.
        - If every row is ignored, the mean-style reductions return a zero loss and
          zero gradients rather than NaN.

    Raises:
        KeyError: if ``reduction`` is not one of the supported names.
    """
    BT, V = p.shape

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    num_warps = get_num_warps(BLOCK_SIZE)

    grid = (BT,)

    reduction = _str_to_reduction_mode[reduction]

    # "none" keeps the element-wise loss; other modes need one partial sum per row.
    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
    output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
    grads = torch.empty_like(p)

    n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
    # If every row is ignored (or BT == 0), all losses and grads are zero; divide by 1
    # instead of 0 so the result is a zero loss rather than NaN.
    n_non_ignore = max(n_non_ignore, 1)

    _tv_distance_kernel[grid](
        p,
        p.stride(0),
        q,
        q.stride(0),
        output_tensor,
        output_tensor.stride(0),
        grads,
        grads.stride(0),
        shift_labels if has_label else torch.empty(1, device=p.device),
        ignore_index,
        V,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_LABEL=has_label,
        num_warps=num_warps,
        reduction=reduction,
    )

    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
        return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
    elif reduction == _REDUCTION_MODE_SUM.value:
        return output_tensor.sum(dim=0), grads
    elif reduction == _REDUCTION_MODE_MEAN.value:
        return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
    else:
        return output_tensor, grads
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def tvd_backward_triton(grad_output, grads):
    """Scale the gradients cached during the TVD forward pass by ``grad_output``.

    Args:
        grad_output: upstream gradient of the TVD loss.
        grads: gradient w.r.t. ``p`` produced by the forward pass.

    Returns:
        ``grads`` unchanged (the same object) when ``grad_output`` is exactly the scalar
        1.0, otherwise ``grads * grad_output``.
    """
    unit = torch.tensor(1.0, device=grad_output.device)
    # When TVD is the final loss the upstream gradient is 1.0, so skip the multiply.
    if not torch.equal(grad_output, unit):
        return grads * grad_output
    return grads
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class LigerTVDLossFunction(torch.autograd.Function):
    """
    Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.

    The forward pass computes the gradient with respect to ``p`` and caches it on ``ctx``.
    The backward pass only rescales it by the upstream gradient. ``q``, ``shift_labels``,
    ``reduction`` and ``ignore_index`` receive no gradient.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        p: torch.Tensor,
        q: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        reduction: REDUCTION_LITERAL = "batchmean",
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """Compute the Total Variation Distance loss.

        Args:
            ctx: Torch autograd context
            p (torch.Tensor): ``(BT, V)`` tensor holding the first distribution.
            q (torch.Tensor): ``(BT, V)`` tensor holding the second distribution.
            shift_labels (Optional[torch.Tensor]): ``(BT,)`` labels; rows equal to
                ``ignore_index`` are skipped.
            reduction (REDUCTION_LITERAL, optional): reduction to apply. Defaults to "batchmean".
            ignore_index (int, optional): label value to ignore. Defaults to -100.

        Returns:
            torch.Tensor: The computed Total Variation Distance Loss.
        """
        has_label = shift_labels is not None
        if has_label:
            assert shift_labels.shape == (p.shape[0],), (
                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            )
            # The kernel reads one label per row with unit stride.
            shift_labels = shift_labels.contiguous()

        loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
        ctx.save_for_backward(grads)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """Propagate the upstream gradient to ``p``.

        Args:
            ctx: Torch autograd context
            grad_output (torch.Tensor): gradient of the final objective w.r.t. the loss.

        Returns:
            tuple[torch.Tensor, None, None, None, None]: gradient for ``p`` followed by
            ``None`` for ``q``, ``shift_labels``, ``reduction`` and ``ignore_index``.
        """
        (cached_grads,) = ctx.saved_tensors
        grad_p = tvd_backward_triton(grad_output, cached_grads)
        return grad_p, None, None, None, None
|
liger_kernel/ops/utils.py
CHANGED
|
@@ -1,11 +1,33 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
|
|
3
|
+
See the original Unsloth repository at https://github.com/unslothai/unsloth.
|
|
4
|
+
|
|
5
|
+
The following line
|
|
6
|
+
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
|
|
7
|
+
is based on code from Unsloth, located at:
|
|
8
|
+
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
|
|
9
|
+
|
|
10
|
+
Modifications made by Yanning Chen, 2024.
|
|
11
|
+
"""
|
|
12
|
+
|
|
1
13
|
import functools
|
|
2
14
|
import importlib
|
|
15
|
+
import operator
|
|
16
|
+
|
|
3
17
|
from typing import Callable
|
|
4
18
|
|
|
5
19
|
import torch
|
|
6
20
|
import triton
|
|
21
|
+
import triton.language as tl
|
|
22
|
+
|
|
7
23
|
from packaging.version import Version
|
|
8
24
|
|
|
25
|
+
from liger_kernel.utils import infer_device
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def is_hip() -> bool:
    """Return True when this PyTorch build targets ROCm/HIP (``torch.version.hip`` is set)."""
    hip_version = torch.version.hip
    return hip_version is not None
|
|
30
|
+
|
|
9
31
|
|
|
10
32
|
def ensure_contiguous(fn):
|
|
11
33
|
@functools.wraps(fn)
|
|
@@ -27,13 +49,12 @@ def calculate_settings(n):
|
|
|
27
49
|
BLOCK_SIZE = triton.next_power_of_2(n)
|
|
28
50
|
if BLOCK_SIZE > MAX_FUSED_SIZE:
|
|
29
51
|
raise RuntimeError(
|
|
30
|
-
f"Cannot launch Triton kernel since n = {n} exceeds "
|
|
31
|
-
f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
|
|
52
|
+
f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
|
|
32
53
|
)
|
|
33
54
|
|
|
34
55
|
num_warps = 4
|
|
35
56
|
if BLOCK_SIZE >= 32768:
|
|
36
|
-
num_warps = 32
|
|
57
|
+
num_warps = 32 if not is_hip() else 16
|
|
37
58
|
elif BLOCK_SIZE >= 8192:
|
|
38
59
|
num_warps = 16
|
|
39
60
|
elif BLOCK_SIZE >= 2048:
|
|
@@ -48,3 +69,61 @@ def compare_version(package: str, operator: Callable, target: str):
|
|
|
48
69
|
return False
|
|
49
70
|
pkg_version = Version(pkg.__version__)
|
|
50
71
|
return operator(pkg_version, Version(target))
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def get_amp_custom_fwd_bwd() -> tuple:
    """Return the ``(custom_fwd, custom_bwd)`` autocast decorators for this environment.

    Resolution order:
    - torch >= 2.4: ``torch.amp.custom_fwd`` / ``torch.amp.custom_bwd`` with
      ``device_type`` bound to ``infer_device()``.
    - otherwise, if ``torch.npu.amp`` is available: the NPU decorators.
    - otherwise: the legacy ``torch.cuda.amp`` decorators.

    Returns:
        tuple: ``(custom_fwd, custom_bwd)``, a pair of callables (the previous
        ``-> Callable`` annotation was wrong: a 2-tuple is returned).
    """
    device = infer_device()
    if compare_version("torch", operator.ge, "2.4.0"):
        return (
            functools.partial(torch.amp.custom_fwd, device_type=device),
            functools.partial(torch.amp.custom_bwd, device_type=device),
        )
    if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
        return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
    return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Autocast-aware decorators for custom autograd functions, resolved once at import
# time for the detected device and torch version.
(amp_custom_fwd, amp_custom_bwd) = get_amp_custom_fwd_bwd()


# Map from the supported torch floating-point dtypes to their Triton equivalents.
torch_to_triton_dtype = dict(
    [
        (torch.float32, tl.float32),
        (torch.float16, tl.float16),
        (torch.bfloat16, tl.bfloat16),
    ]
)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Multiply, in place, every element of one row of X by the scalar at grad_output_ptr.

    Each program handles the row selected by its program id and walks the row in
    BLOCK_SIZE-wide chunks.

    Parameters:
    X_ptr: Pointer to the input tensor (modified in place).
    X_stride (int): The stride between consecutive rows of the input tensor.
    grad_output_ptr: Pointer to the gradient output value.
    n_cols (int): The number of columns in the input tensor.
    BLOCK_SIZE (int): The block size for Triton operations.
    """
    # int64 row id so row * X_stride cannot overflow 32 bits.
    row = tl.program_id(0).to(tl.int64)
    row_ptr = X_ptr + row * X_stride

    scale = tl.load(grad_output_ptr)
    lane = tl.arange(0, BLOCK_SIZE)

    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + lane
        in_bounds = cols < n_cols
        vals = tl.load(row_ptr + cols, mask=in_bounds)
        tl.store(row_ptr + cols, vals * scale, mask=in_bounds)
|