liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,8 @@ import torch
2
2
  import triton
3
3
  import triton.language as tl
4
4
 
5
- from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
5
+ from liger_kernel.ops.utils import calculate_settings
6
+ from liger_kernel.ops.utils import ensure_contiguous
6
7
 
7
8
 
8
9
  @triton.jit
@@ -11,44 +12,40 @@ def silu(x):
11
12
 
12
13
 
13
14
  @triton.jit
14
- def _swiglu_forward_kernel(
15
- a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
16
- ):
17
- program_id = tl.program_id(0)
15
+ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
16
+ program_id = tl.program_id(0).to(tl.int64)
18
17
 
19
18
  # locate start index
20
- a += program_id * stride
21
- b += program_id * stride
22
- c += program_id * stride
19
+ a_ptr += program_id * stride
20
+ b_ptr += program_id * stride
21
+ c_ptr += program_id * stride
23
22
 
24
23
  col_offsets = tl.arange(0, BLOCK_SIZE)
25
24
  mask = col_offsets < n_cols
26
25
 
27
26
  # sigmoid requires type float32
28
- a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
29
- b_row = tl.load(b + col_offsets, mask=mask, other=0)
30
- c_row = silu(a_row) * b_row
31
- tl.store(c + col_offsets, c_row, mask=mask)
27
+ a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
28
+ b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
29
+ c_row = silu(a_row).cast(b_row.dtype) * b_row
30
+ tl.store(c_ptr + col_offsets, c_row, mask=mask)
32
31
 
33
32
 
34
33
  @triton.jit
35
- def _swiglu_backward_kernel(
36
- dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
37
- ):
38
- program_id = tl.program_id(0)
34
+ def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
35
+ program_id = tl.program_id(0).to(tl.int64)
39
36
 
40
37
  # locate start index
41
- dc += program_id * stride
42
- a += program_id * stride
43
- b += program_id * stride
38
+ dc_ptr += program_id * stride
39
+ a_ptr += program_id * stride
40
+ b_ptr += program_id * stride
44
41
 
45
42
  col_offsets = tl.arange(0, BLOCK_SIZE)
46
43
  mask = col_offsets < n_cols
47
44
 
48
- dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
45
+ dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
49
46
  # sigmoid requires type float32
50
- a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
51
- b_row = tl.load(b + col_offsets, mask=mask, other=0)
47
+ a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
48
+ b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
52
49
 
53
50
  # recomputation to save memory
54
51
  sig_a = tl.sigmoid(a_row)
@@ -56,58 +53,64 @@ def _swiglu_backward_kernel(
56
53
  db_row = dc_row * silu_a
57
54
  da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row
58
55
 
59
- tl.store(a + col_offsets, da_row, mask=mask)
60
- tl.store(b + col_offsets, db_row, mask=mask)
56
+ tl.store(a_ptr + col_offsets, da_row, mask=mask)
57
+ tl.store(b_ptr + col_offsets, db_row, mask=mask)
58
+
59
+
60
def swiglu_forward(a, b):
    """Run the SwiGLU forward kernel and return ``(a_2d, b_2d, c)``.

    ``a`` and ``b`` are viewed as 2-D ``(-1, a.shape[-1])`` tensors (``view``
    needs a compatible layout, e.g. contiguous input). One Triton program is
    launched per row. The flattened ``a``/``b`` views are returned so the caller
    can save them for backward. ``c = silu(a) * b`` is reshaped back to ``a``'s
    original shape.
    """
    original_shape = a.shape
    width = original_shape[-1]
    a_2d = a.view(-1, width)
    b_2d = b.view(-1, width)
    out = torch.empty_like(a_2d)

    block_size, warps = calculate_settings(width)

    grid = (a_2d.shape[0],)
    _swiglu_forward_kernel[grid](
        a_2d,
        b_2d,
        out,
        out.stride(-2),
        n_cols=width,
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    return a_2d, b_2d, out.view(*original_shape)
81
+
82
+
83
def swiglu_backward(a, b, dc):
    """Run the SwiGLU backward kernel and return ``(da, db)`` shaped like ``dc``.

    The kernel overwrites ``a`` with ``da`` and ``b`` with ``db`` in place, so
    the returned tensors are reshaped views of the input buffers. ``a`` and
    ``b`` are indexed with ``dc``'s row stride, which assumes they share
    ``dc``'s 2-D layout (they are the views saved by ``swiglu_forward``).
    """
    dc_shape = dc.shape
    width = dc_shape[-1]
    dc_2d = dc.view(-1, width)

    block_size, warps = calculate_settings(width)

    grid = (dc_2d.shape[0],)
    _swiglu_backward_kernel[grid](
        dc_2d,
        a,
        b,
        dc_2d.stride(-2),
        n_cols=width,
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    # a now holds da and b holds db
    return a.view(*dc_shape), b.view(*dc_shape)
61
101
 
62
102
 
63
103
  class LigerSiLUMulFunction(torch.autograd.Function):
64
104
  @staticmethod
65
105
  @ensure_contiguous
66
106
  def forward(ctx, a, b):
67
- ori_shape = a.shape
68
-
69
- n_cols = ori_shape[-1]
70
- a = a.view(-1, n_cols)
71
- b = b.view(-1, n_cols)
72
- c = torch.zeros_like(a)
73
- n_rows = a.shape[0]
74
-
75
- BLOCK_SIZE, num_warps = calculate_settings(n_cols)
76
-
77
- _swiglu_forward_kernel[(n_rows,)](
78
- a,
79
- b,
80
- c,
81
- c.stride(-2),
82
- n_cols=n_cols,
83
- BLOCK_SIZE=BLOCK_SIZE,
84
- num_warps=num_warps,
85
- )
86
-
107
+ a, b, c = swiglu_forward(a, b)
87
108
  ctx.save_for_backward(a, b)
88
-
89
- return c.view(*ori_shape)
109
+ return c
90
110
 
91
111
  @staticmethod
92
112
  @ensure_contiguous
93
113
  def backward(ctx, dc):
94
-
95
- ori_shape = dc.shape
96
- n_cols = ori_shape[-1]
97
- dc = dc.view(-1, n_cols)
98
114
  a, b = ctx.saved_tensors
99
- n_rows = dc.shape[0]
100
-
101
- BLOCK_SIZE, num_warps = calculate_settings(n_cols)
102
-
103
- _swiglu_backward_kernel[(n_rows,)](
104
- dc,
105
- a,
106
- b,
107
- dc.stride(-2),
108
- n_cols=n_cols,
109
- BLOCK_SIZE=BLOCK_SIZE,
110
- num_warps=num_warps,
111
- )
112
-
113
- return a.view(*ori_shape), b.view(*ori_shape)
115
+ a, b = swiglu_backward(a, b, dc)
116
+ return a, b
@@ -0,0 +1,136 @@
1
+ import math
2
+
3
+ from typing import Callable
4
+ from typing import List
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+
11
+
12
class LigerTiledMLPFunction(torch.autograd.Function):
    """
    Based on DeepSpeed's TiledMLP:
    https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838

    Perform a tiled MLP computation to greatly reduce the memory needed to compute an MLP
    over very long sequences.

    This function re-computes `forward` in the `backward`, so the MLP forward runs twice per
    iteration, and three times if activation checkpointing is also used.

    Args:
        fn: the function to call on sharded inputs, invoked as ``fn(mlp_module, x_shard)``
        mlp_module: the MLP nn.Module object
        x: the input to MLP.forward (hidden_states)
        shards: how many shards to use (passed as ``chunks`` to ``torch.chunk``)
        compute_params: a list of weights engaged in the compute; accepted for signature
            parity with DeepSpeed's TiledMLP but not used by this implementation

    Returns:
        the computed hidden_states
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        fn: Callable,
        mlp_module: torch.nn.Module,
        x: torch.Tensor,
        shards: int,
        compute_params: Optional[List[torch.nn.Parameter]] = None,
    ) -> torch.Tensor:
        # Only x is saved as a tensor; fn/module/shards go on ctx so backward can recompute.
        ctx.fn = fn
        ctx.mlp_module = mlp_module
        ctx.shards = shards
        ctx.save_for_backward(x)

        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
        # Run under no_grad so no per-shard activations are kept; backward recomputes them.
        with torch.no_grad():
            output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
        output_unsharded = torch.cat(output_shards, dim=-2)

        return output_unsharded

    @staticmethod
    @ensure_contiguous
    def backward(ctx, *grads) -> tuple:
        fn = ctx.fn
        (x,) = ctx.saved_tensors
        mlp_module = ctx.mlp_module
        shards = ctx.shards

        x_requires_grad = x.requires_grad
        x = x.detach()
        # detach() unsets x.requires_grad, so restore it
        x.requires_grad_(x_requires_grad)

        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        hidden_size = x.shape[-1]
        x_shape_orig = x.shape

        # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
        x = x.view(-1, hidden_size)
        # Viewing the output grad with the input's hidden_size assumes the MLP preserves
        # the last dimension.
        incoming_grad = grads[0].view(-1, hidden_size)
        # Gradient buffer for x; each shard's gradient is routed into its slice below.
        x_grad = torch.zeros_like(x)

        x_shards = list(torch.chunk(x, chunks=shards, dim=0))

        for i, x_shard in enumerate(x_shards):
            x_shard.requires_grad_(x_requires_grad)

            # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
            shard_step = x_shards[i].shape[0]
            # every chunk except possibly the last has x_shards[0]'s length
            shard_offset = i * x_shards[0].shape[0]

            # Pre-set the shard's .grad to a view of its slice of x_grad. x_grad is what gets
            # returned, so this relies on backward() accumulating into that view in place.
            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)

            # Recompute this shard's forward with autograd enabled, then backprop it
            # immediately; parameter grads of mlp_module accumulate across shards.
            with torch.enable_grad():
                output = fn(mlp_module, x_shard)
            torch.autograd.backward(output, incoming_grad_shard)

        # unflatten
        x_grad = x_grad.view(x_shape_orig)

        # Gradients for (fn, mlp_module, x, shards, compute_params)
        return (None, None, x_grad, None, None)
99
+
100
+
101
def apply_tiled_mlp(
    fn: Callable,
    mlp_module: torch.nn.Module,
    x: torch.Tensor,
    num_shards: Optional[int] = None,
    compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
    """
    Apply tiled MLP computation for memory efficiency.

    The sequence dimension (``x.shape[-2]``) is split into ``num_shards`` chunks that
    are passed through ``fn`` one at a time via ``LigerTiledMLPFunction``. The
    backward pass recomputes each chunk's forward instead of storing activations.

    Args:
        fn: the function to call on sharded inputs, invoked as ``fn(mlp_module, x_shard)``
            (e.g., ``lambda module, x: module(x)``)
        mlp_module: the MLP nn.Module object
        x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
        num_shards: number of shards to use. If None, it is calculated as
            ceil(seqlen / hidden_size). Values below 1 are clamped to 1.
        compute_params: forwarded to ``LigerTiledMLPFunction`` for signature parity with
            DeepSpeed's TiledMLP. It is currently not used by the computation, so it has
            no effect on DeepSpeed ZeRO gradient bucketing.

    Returns:
        output tensor with the same shape as input. The backward pass assumes the MLP
        preserves the last (hidden) dimension.
    """
    if num_shards is None:
        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size]
        hidden_size = x.shape[-1]
        seqlen = x.shape[-2]
        num_shards = math.ceil(seqlen / hidden_size)

    # Ensure num_shards is at least 1 (also covers an empty sequence or a 0/negative argument)
    num_shards = max(1, num_shards)

    return LigerTiledMLPFunction.apply(
        fn,
        mlp_module,
        x,
        num_shards,
        compute_params,
    )
@@ -0,0 +1,207 @@
1
+ from typing import Literal
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from liger_kernel.ops.utils import ensure_contiguous
9
+
10
# Upper bound on the per-iteration Triton block size used by the TVD kernel.
MAX_FUSED_SIZE = 65536 // 4

REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

# Reduction modes as tl.constexpr values, so the kernel can branch on them at compile time.
_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)

# Host-side lookup from the user-facing reduction string to its plain int mode value.
_str_to_reduction_mode = {
    name: mode.value
    for name, mode in (
        ("none", _REDUCTION_MODE_NONE),
        ("sum", _REDUCTION_MODE_SUM),
        ("mean", _REDUCTION_MODE_MEAN),
        ("batchmean", _REDUCTION_MODE_BATCHMEAN),
    )
}
25
+
26
+
27
def get_num_warps(BLOCK_SIZE):
    """Choose a Triton ``num_warps`` for the given block size.

    ``>= 32768`` -> 32, ``>= 8192`` -> 16, ``>= 2048`` -> 8, otherwise 4.
    """
    for threshold, warps in ((32768, 32), (8192, 16), (2048, 8)):
        if BLOCK_SIZE >= threshold:
            return warps
    return 4
37
+
38
+
39
@triton.jit
def _tv_distance_kernel(
    p_ptr,
    p_stride,
    q_ptr,
    q_stride,
    loss_ptr,
    loss_stride,
    grads_ptr,
    grads_stride,
    label_ptr,
    ignore_index: tl.constexpr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
    """Per-row Total Variation Distance and its gradient w.r.t. ``p``.

    Row ``tl.program_id(0)`` is processed in chunks of ``BLOCK_SIZE`` columns.
    Each chunk writes ``0.5 * |p - q|`` element-wise into ``loss_ptr`` when
    ``reduction`` is NONE; otherwise the row total is written to
    ``loss_ptr[row]``. The gradient ``+/-0.5`` is written to ``grads_ptr``
    unscaled, and any batch-level normalisation is applied by the caller.
    Rows whose label equals ``ignore_index`` get zero gradient (and zero loss in
    NONE mode). In other modes their loss slot is left untouched, so the caller
    must pre-zero ``loss_ptr``.
    """
    # int64 row index so row * stride cannot overflow
    pid = tl.program_id(0).to(tl.int64)
    p_ptr += pid * p_stride
    q_ptr += pid * q_stride
    loss_ptr += pid * loss_stride
    grads_ptr += pid * grads_stride
    label_ptr += pid

    base_offsets = tl.arange(0, BLOCK_SIZE)

    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            # Ignored row: zero the gradient (and the per-element loss in NONE mode), then exit.
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + base_offsets
                mask = offsets < n_cols
                tl.store(grads_ptr + offsets, 0.0, mask=mask)
                if reduction == _REDUCTION_MODE_NONE:
                    tl.store(loss_ptr + offsets, 0.0, mask=mask)
            return

    loss_sum = 0.0
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + base_offsets
        mask = offsets < n_cols

        # Out-of-range lanes load 0 for both p and q, so they add 0 to the loss.
        p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
        q = tl.load(q_ptr + offsets, mask=mask, other=0.0)

        # TVD(P || Q) = 0.5 * |P - Q|
        tv_loss = 0.5 * tl.abs(p - q)

        # d/dp of 0.5*|p - q|; at p == q this picks the subgradient -0.5
        grad_res = tl.where(p > q, 0.5, -0.5)

        tl.store(grads_ptr + offsets, grad_res, mask=mask)

        if reduction == _REDUCTION_MODE_NONE:
            tl.store(loss_ptr + offsets, tv_loss, mask=mask)
        else:
            loss_sum += tl.sum(tv_loss, axis=0)

    if reduction != _REDUCTION_MODE_NONE:
        # One scalar per row; the host sums/normalises across rows.
        tl.store(loss_ptr, loss_sum)
98
+
99
+
100
def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
    """Launch the TVD kernel on ``(BT, V)`` inputs and apply the host-side reduction.

    Returns ``(loss, grads)``:
      * "none": per-element loss ``(BT, V)`` and unscaled grads.
      * "sum": summed loss and unscaled grads.
      * "batchmean": loss and grads divided by the number of non-ignored rows.
      * "mean": loss and grads divided by (non-ignored rows * V).

    NOTE(review): when every label equals ``ignore_index``, the non-ignored
    count is 0 and the "batchmean"/"mean" divisions are by zero.
    """
    n_rows, vocab = p.shape

    block = min(MAX_FUSED_SIZE, triton.next_power_of_2(vocab))
    mode = _str_to_reduction_mode[reduction]
    per_element = mode == _REDUCTION_MODE_NONE.value

    # Reduced modes rely on ignored rows keeping this zero initialisation.
    loss_buf = torch.zeros(
        (n_rows, vocab) if per_element else (n_rows,),
        device=p.device,
        dtype=torch.float32,
    )
    grads = torch.empty_like(p)
    labels_arg = shift_labels if has_label else torch.empty(1, device=p.device)

    if has_label:
        n_non_ignore = (shift_labels != ignore_index).sum().item()
    else:
        n_non_ignore = n_rows

    _tv_distance_kernel[(n_rows,)](
        p,
        p.stride(0),
        q,
        q.stride(0),
        loss_buf,
        loss_buf.stride(0),
        grads,
        grads.stride(0),
        labels_arg,
        ignore_index,
        vocab,
        BLOCK_SIZE=block,
        HAS_LABEL=has_label,
        num_warps=get_num_warps(block),
        reduction=mode,
    )

    if per_element:
        return loss_buf, grads
    if mode == _REDUCTION_MODE_SUM.value:
        return loss_buf.sum(dim=0), grads
    # batchmean normalises per row; mean normalises per element
    denom = n_non_ignore if mode == _REDUCTION_MODE_BATCHMEAN.value else n_non_ignore * vocab
    return loss_buf.sum() / denom, grads / denom
142
+
143
+
144
def tvd_backward_triton(grad_output, grads):
    """Scale the gradients computed in the TVD forward pass by ``grad_output``.

    Args:
        grad_output: upstream gradient of the loss. It is a scalar for reduced
            losses, or shaped like ``grads`` for ``reduction="none"``.
        grads: gradient of the loss w.r.t. ``p`` from the forward pass.

    Returns:
        ``grads`` itself when ``grad_output`` is exactly 1.0, otherwise ``grads * grad_output``.
    """
    # If the TVD loss is the final output, grad_output is 1.0; skip the multiply then.
    # The 1.0 reference uses grad_output's dtype so the fast path does not depend on
    # grad_output being float32.
    one = torch.ones((), device=grad_output.device, dtype=grad_output.dtype)
    if torch.equal(grad_output, one):
        return grads

    return grads * grad_output
150
+
151
+
152
class LigerTVDLossFunction(torch.autograd.Function):
    """
    Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.

    The gradient w.r.t. ``p`` is computed eagerly in the forward kernel and saved.
    ``q``, ``shift_labels``, ``reduction`` and ``ignore_index`` receive no gradient.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        p: torch.Tensor,
        q: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        reduction: REDUCTION_LITERAL = "batchmean",
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """A forward pass for the Total Variation Distance Loss.

        Args:
            ctx: Torch autograd context
            p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
            q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
            shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
                Rows whose label equals ``ignore_index`` contribute no loss or gradient.
            reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
            ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.

        Returns:
            torch.Tensor: The computed Total Variation Distance Loss.
        """
        has_label = False
        if shift_labels is not None:
            assert shift_labels.shape == (p.shape[0],), (
                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            )
            shift_labels = shift_labels.contiguous()
            has_label = True

        loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
        # Already-reduced gradient w.r.t. p; backward only rescales it by grad_output.
        ctx.save_for_backward(grads)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> tuple:
        """A backward pass for the Total Variation Distance Loss.

        Args:
            ctx: Torch autograd context
            grad_output (torch.Tensor): The gradient of the loss with respect to the output.

        Returns:
            tuple[torch.Tensor, None, None, None, None]: The gradient w.r.t. ``p``, followed by
            ``None`` for ``q``, ``shift_labels``, ``reduction`` and ``ignore_index``.
        """
        (grads,) = ctx.saved_tensors
        grads = tvd_backward_triton(grad_output, grads)

        return grads, None, None, None, None
liger_kernel/ops/utils.py CHANGED
@@ -1,11 +1,33 @@
1
+ """
2
+ This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
3
+ See the original Unsloth repository at https://github.com/unslothai/unsloth.
4
+
5
+ The following line
6
+ https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
7
+ is based on code from Unsloth, located at:
8
+ https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
9
+
10
+ Modifications made by Yanning Chen, 2024.
11
+ """
12
+
1
13
  import functools
2
14
  import importlib
15
+ import operator
16
+
3
17
  from typing import Callable
4
18
 
5
19
  import torch
6
20
  import triton
21
+ import triton.language as tl
22
+
7
23
  from packaging.version import Version
8
24
 
25
+ from liger_kernel.utils import infer_device
26
+
27
+
28
def is_hip() -> bool:
    """Return True when ``torch.version.hip`` is set, i.e. the torch build targets HIP/ROCm."""
    hip_version = torch.version.hip
    return hip_version is not None
30
+
9
31
 
10
32
  def ensure_contiguous(fn):
11
33
  @functools.wraps(fn)
@@ -27,13 +49,12 @@ def calculate_settings(n):
27
49
  BLOCK_SIZE = triton.next_power_of_2(n)
28
50
  if BLOCK_SIZE > MAX_FUSED_SIZE:
29
51
  raise RuntimeError(
30
- f"Cannot launch Triton kernel since n = {n} exceeds "
31
- f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
52
+ f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
32
53
  )
33
54
 
34
55
  num_warps = 4
35
56
  if BLOCK_SIZE >= 32768:
36
- num_warps = 32
57
+ num_warps = 32 if not is_hip() else 16
37
58
  elif BLOCK_SIZE >= 8192:
38
59
  num_warps = 16
39
60
  elif BLOCK_SIZE >= 2048:
@@ -48,3 +69,61 @@ def compare_version(package: str, operator: Callable, target: str):
48
69
  return False
49
70
  pkg_version = Version(pkg.__version__)
50
71
  return operator(pkg_version, Version(target))
72
+
73
+
74
def get_amp_custom_fwd_bwd() -> "tuple[Callable, Callable]":
    """Return the ``(custom_fwd, custom_bwd)`` autocast decorators to use.

    Resolution order:
      1. torch >= 2.4: ``torch.amp.custom_fwd`` / ``torch.amp.custom_bwd`` with
         ``device_type`` bound to ``infer_device()``.
      2. Otherwise, if ``torch.npu.amp`` is available, its decorators.
      3. Otherwise the CUDA-specific ``torch.cuda.amp`` decorators.

    Returns:
        A 2-tuple ``(custom_fwd, custom_bwd)``. The return annotation is a string
        so it is never evaluated at runtime on older Python versions.
    """
    device = infer_device()
    if compare_version("torch", operator.ge, "2.4.0"):
        return (
            functools.partial(torch.amp.custom_fwd, device_type=device),
            functools.partial(torch.amp.custom_bwd, device_type=device),
        )
    if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
        return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
    return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
84
+
85
+
86
# Autocast decorators, resolved once at import time from the installed torch
# version and the device returned by infer_device().
amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()


# torch -> Triton dtype mapping for the three floating-point dtypes listed here;
# any other dtype is absent, so looking it up raises KeyError.
torch_to_triton_dtype = {
    torch.float32: tl.float32,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
}
94
+
95
+
96
@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Multiply one row of X in place by the scalar stored at grad_output_ptr.

    The row handled is tl.program_id(0). Its n_cols columns are processed in
    chunks of BLOCK_SIZE, and the scaled values overwrite the originals.

    Parameters:
        X_ptr: Pointer to the input tensor (modified in place).
        X_stride (int): Offset, in elements, between the starts of consecutive rows.
        grad_output_ptr: Pointer to the single gradient value used as the multiplier.
        n_cols (int): The number of columns in the input tensor.
        BLOCK_SIZE (int): Number of columns handled per loop iteration.
    """
    # int64 row index so row * X_stride does not overflow
    row_ptr = X_ptr + tl.program_id(0).to(tl.int64) * X_stride

    scale = tl.load(grad_output_ptr)

    for col_start in range(0, n_cols, BLOCK_SIZE):
        cols = col_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols
        vals = tl.load(row_ptr + cols, mask=in_bounds)
        tl.store(row_ptr + cols, vals * scale, mask=in_bounds)