liger-kernel-nightly 0.5.3.dev20250220230230__py3-none-any.whl → 0.5.3.dev20250221003838__py3-none-any.whl

liger_kernel/ops/tvd.py (new file)
@@ -0,0 +1,208 @@
+ from typing import Literal, Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import ensure_contiguous
+
+ MAX_FUSED_SIZE = 65536 // 4
+
+ REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
+
+ _REDUCTION_MODE_NONE = tl.constexpr(0)
+ _REDUCTION_MODE_SUM = tl.constexpr(1)
+ _REDUCTION_MODE_MEAN = tl.constexpr(2)
+ _REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
+
+ _str_to_reduction_mode = {
+     "none": _REDUCTION_MODE_NONE.value,
+     "sum": _REDUCTION_MODE_SUM.value,
+     "mean": _REDUCTION_MODE_MEAN.value,
+     "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
+ }
+
+
+ def get_num_warps(BLOCK_SIZE):
+     num_warps = 4
+     if BLOCK_SIZE >= 32768:
+         num_warps = 32
+     elif BLOCK_SIZE >= 8192:
+         num_warps = 16
+     elif BLOCK_SIZE >= 2048:
+         num_warps = 8
+
+     return num_warps
+
+
+ @triton.jit
+ def _tv_distance_kernel(
+     p_ptr,
+     p_stride,
+     q_ptr,
+     q_stride,
+     loss_ptr,
+     loss_stride,
+     grads_ptr,
+     grads_stride,
+     label_ptr,
+     ignore_index: tl.constexpr,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+     HAS_LABEL: tl.constexpr,
+     reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
+ ):
+     pid = tl.program_id(0).to(tl.int64)
+     p_ptr += pid * p_stride
+     q_ptr += pid * q_stride
+     loss_ptr += pid * loss_stride
+     grads_ptr += pid * grads_stride
+     label_ptr += pid
+
+     base_offsets = tl.arange(0, BLOCK_SIZE)
+
+     if HAS_LABEL:
+         label = tl.load(label_ptr)
+         if label == ignore_index:
+             for i in range(0, n_cols, BLOCK_SIZE):
+                 offsets = i + base_offsets
+                 mask = offsets < n_cols
+                 tl.store(grads_ptr + offsets, 0.0, mask=mask)
+                 if reduction == _REDUCTION_MODE_NONE:
+                     tl.store(loss_ptr + offsets, 0.0, mask=mask)
+             return
+
+     loss_sum = 0.0
+     for i in range(0, n_cols, BLOCK_SIZE):
+         offsets = i + base_offsets
+         mask = offsets < n_cols
+
+         p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
+         q = tl.load(q_ptr + offsets, mask=mask, other=0.0)
+
+         # TVD(P || Q) = 0.5 * sum(|P - Q|); this block holds the element-wise terms.
+         tv_loss = 0.5 * tl.abs(p - q)
+
+         # Subgradient w.r.t. p: d/dp 0.5 * |p - q| = 0.5 * sign(p - q); ties take the -0.5 branch.
+         grad_res = tl.where(p > q, 0.5, -0.5)
+
+         tl.store(grads_ptr + offsets, grad_res, mask=mask)
+
+         if reduction == _REDUCTION_MODE_NONE:
+             tl.store(loss_ptr + offsets, tv_loss, mask=mask)
+         else:
+             loss_sum += tl.sum(tv_loss, axis=0)
+
+     if reduction != _REDUCTION_MODE_NONE:
+         tl.store(loss_ptr, loss_sum)
+
+
+ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
+     BT, V = p.shape
+
+     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+     num_warps = get_num_warps(BLOCK_SIZE)
+
+     grid = (BT,)
+
+     reduction = _str_to_reduction_mode[reduction]
+
+     out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
+     output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
+     grads = torch.empty_like(p)
+
+     n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
+
+     _tv_distance_kernel[grid](
+         p,
+         p.stride(0),
+         q,
+         q.stride(0),
+         output_tensor,
+         output_tensor.stride(0),
+         grads,
+         grads.stride(0),
+         shift_labels if has_label else torch.empty(1, device=p.device),
+         ignore_index,
+         V,
+         BLOCK_SIZE=BLOCK_SIZE,
+         HAS_LABEL=has_label,
+         num_warps=num_warps,
+         reduction=reduction,
+     )
+
+     if reduction == _REDUCTION_MODE_BATCHMEAN.value:
+         return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
+     elif reduction == _REDUCTION_MODE_SUM.value:
+         return output_tensor.sum(dim=0), grads
+     elif reduction == _REDUCTION_MODE_MEAN.value:
+         return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
+     else:
+         return output_tensor, grads
+
+
+ def tvd_backward_triton(grad_output, grads):
+     # If the TVD loss is the last layer, grad_output is 1.0. Skip the mul then.
+     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+         return grads
+
+     return grads * grad_output
+
+
+ class LigerTVDLossFunction(torch.autograd.Function):
+     """
+     Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(
+         ctx,
+         p: torch.Tensor,
+         q: torch.Tensor,
+         shift_labels: Optional[torch.Tensor] = None,
+         reduction: REDUCTION_LITERAL = "batchmean",
+         ignore_index: int = -100,
+     ) -> torch.Tensor:
+         """A forward pass for the Total Variation Distance Loss.
+
+         Args:
+             ctx: Torch autograd context
+             p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
+             q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
+             shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
+             reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
+             ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
+
+         Returns:
+             torch.Tensor: The computed Total Variation Distance Loss.
+         """
+         has_label = False
+         if shift_labels is not None:
+             assert shift_labels.shape == (
+                 p.shape[0],
+             ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+             shift_labels = shift_labels.contiguous()
+             has_label = True
+
+         loss, grads = tv_distance_forward_triton(
+             p, q, shift_labels, reduction, ignore_index, has_label
+         )
+         ctx.save_for_backward(grads)
+         return loss
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output: torch.Tensor):
+         """A backward pass for the Total Variation Distance Loss.
+
+         Args:
+             ctx: Torch autograd context
+             grad_output (torch.Tensor): The gradient of the loss with respect to the output.
+
+         Returns:
+             tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
+         """
+         (grads,) = ctx.saved_tensors
+         grads = tvd_backward_triton(grad_output, grads)
+
+         return grads, None, None, None, None
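
For orientation (not part of the diff), a minimal sketch of what the new kernel computes; the shapes and the CUDA device below are illustrative assumptions. With reduction="batchmean", the fused op should agree with a naive PyTorch total variation distance:

import torch
from liger_kernel.ops.tvd import LigerTVDLossFunction

BT, V = 8, 4096  # hypothetical batch*seq-len rows and vocab-size columns
p = torch.softmax(torch.randn(BT, V, device="cuda"), dim=-1).requires_grad_(True)
q = torch.softmax(torch.randn(BT, V, device="cuda"), dim=-1)

# Triton-backed TVD with the default "batchmean" reduction and no label masking.
loss = LigerTVDLossFunction.apply(p, q, None, "batchmean", -100)

# Naive PyTorch reference: TVD(P || Q) = 0.5 * sum_i |p_i - q_i|, averaged over rows.
ref = (0.5 * (p - q).abs().sum(dim=-1)).mean()
torch.testing.assert_close(loss, ref)

loss.backward()  # d(loss)/dp is the stored 0.5 * sign(p - q), scaled by 1 / BT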
liger_kernel/transformers/__init__.py
@@ -18,6 +18,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
+ from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
liger_kernel/transformers/functional.py
@@ -12,7 +12,7 @@ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
  from liger_kernel.ops.rope import LigerRopeFunction
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
-
+ from liger_kernel.ops.tvd import LigerTVDLossFunction
 
  # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
  # `weight` and `size_average` are placeholders and not implemented yet
@@ -156,6 +156,20 @@ def liger_kl_div(
          eps,
      )
 
+ def liger_tvd(
+     input,
+     target,
+     shift_labels=None,
+     reduction: str = "mean",
+     ignore_index: int = -100,
+ ):
+     return LigerTVDLossFunction.apply(
+         input,
+         target,
+         shift_labels,
+         reduction,
+         ignore_index,
+     )
 
  def liger_layer_norm(X, W, B, eps):
      return LigerLayerNormFunction.apply(X, W, B, eps)
liger_kernel/transformers/tvd.py (new file)
@@ -0,0 +1,15 @@
+ import torch.nn as nn
+
+ from liger_kernel.ops.tvd import LigerTVDLossFunction
+
+
+ class LigerTVDLoss(nn.Module):
+     def __init__(self, reduction="batchmean", ignore_index: int = -100):
+         super(LigerTVDLoss, self).__init__()
+         self.reduction = reduction
+         self.ignore_index = ignore_index
+
+     def forward(self, p, q, shift_labels=None):
+         return LigerTVDLossFunction.apply(
+             p, q, shift_labels, self.reduction, self.ignore_index
+         )
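
A short, hypothetical usage sketch for the new module (not from the diff); the shapes are assumptions, and rows whose shift_labels entry equals ignore_index are excluded from both the loss and the gradient:

import torch
from liger_kernel.transformers.tvd import LigerTVDLoss

tvd = LigerTVDLoss(reduction="batchmean", ignore_index=-100)

BT, V = 8, 4096  # hypothetical batch*seq-len and vocab size
student = torch.softmax(torch.randn(BT, V, device="cuda"), dim=-1).requires_grad_(True)
teacher = torch.softmax(torch.randn(BT, V, device="cuda"), dim=-1)

# Mark the last two rows as padding; they contribute neither loss nor gradient.
shift_labels = torch.zeros(BT, dtype=torch.long, device="cuda")
shift_labels[-2:] = -100

loss = tvd(student, teacher, shift_labels=shift_labels)  # averaged over the non-ignored rows
loss.backward()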
liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.5.3.dev20250220230230
+ Version: 0.5.3.dev20250221003838
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -126,7 +126,7 @@ Requires-Dist: mkdocs-material; extra == "dev"
 
  **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
 
- We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
+ We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
 
  ## Supercharge Your Model with Liger Kernel
 
@@ -341,6 +341,7 @@ loss.backward()
  | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
  | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
  | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
+ | Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |
 
  ### Distillation Kernels
 
@@ -349,6 +350,7 @@ loss.backward()
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
  | JSD | `liger_kernel.transformers.LigerJSD` |
  | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
+ | TVD | `liger_kernel.transformers.LigerTVDLoss` |
 
  ### Experimental Kernels
 
liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/RECORD
@@ -28,13 +28,14 @@ liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML
  liger_kernel/ops/rms_norm.py,sha256=PWLJcdIKU5e-8BuYFHd9Cqlq6wmr6fUXKi9zQD4LetU,11727
  liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
  liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
+ liger_kernel/ops/tvd.py,sha256=9wVCijj2vBtgiLeUHhl7hy_LAiJ3liPIYOGMSU3P1ro,6407
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
- liger_kernel/transformers/__init__.py,sha256=QPmYkL6hosBPpPqCUGqvIvAtD9XzLgvZqZxUyYMZeVk,2008
+ liger_kernel/transformers/__init__.py,sha256=VZI9hiCvvA371jsfkJmSt1CNXlBztIvlVGDExyKeqBM,2077
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
- liger_kernel/transformers/functional.py,sha256=lDOjch622dJIc78K3ePFK_H1DX00GC5kKjodjcbEgbM,4624
+ liger_kernel/transformers/functional.py,sha256=zahXVCjA2NxcVFpAgajILIRN0GO6mrbfLPgONUkTrY8,4940
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=09Rt7FZzLH42VOcIbQ4dlQd0o3Rlb4vk6fqiOQ7WTD8,1778
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
@@ -48,6 +49,7 @@ liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUS
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
  liger_kernel/transformers/swiglu.py,sha256=i9WTqcNRqReU4XJs391IPbl-I5X0wG4T72D4pqGFfJg,2422
  liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
+ liger_kernel/transformers/tvd.py,sha256=lt730XDR4IYEkn-HeWS7WU6AGssg90ubg8pqPAX2lbE,472
  liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  liger_kernel/transformers/model/gemma.py,sha256=ky89b3aWPaeTGRMC-745KgixtQIRXzNAiCORAMLn9yo,9654
@@ -63,9 +65,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/METADATA,sha256=xtathj_pY7bV0Pkw0qNpzJ-cDVUXRy3AsSemRtaTRYY,21766
- liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
- liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/RECORD,,
+ liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/METADATA,sha256=zJp1YMQDbbOzeNRKFf7AN7hYqePdT49OEMQkN_buKl8,21963
+ liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+ liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/RECORD,,