liger-kernel-nightly 0.5.3.dev20250220230230__py3-none-any.whl → 0.5.3.dev20250221003838__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,208 @@
+ from typing import Literal, Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import ensure_contiguous
+
+ MAX_FUSED_SIZE = 65536 // 4
+
+ REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
+
+ _REDUCTION_MODE_NONE = tl.constexpr(0)
+ _REDUCTION_MODE_SUM = tl.constexpr(1)
+ _REDUCTION_MODE_MEAN = tl.constexpr(2)
+ _REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
+
+ _str_to_reduction_mode = {
+     "none": _REDUCTION_MODE_NONE.value,
+     "sum": _REDUCTION_MODE_SUM.value,
+     "mean": _REDUCTION_MODE_MEAN.value,
+     "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
+ }
+
+
+ def get_num_warps(BLOCK_SIZE):
+     num_warps = 4
+     if BLOCK_SIZE >= 32768:
+         num_warps = 32
+     elif BLOCK_SIZE >= 8192:
+         num_warps = 16
+     elif BLOCK_SIZE >= 2048:
+         num_warps = 8
+
+     return num_warps
+
+
+ @triton.jit
+ def _tv_distance_kernel(
+     p_ptr,
+     p_stride,
+     q_ptr,
+     q_stride,
+     loss_ptr,
+     loss_stride,
+     grads_ptr,
+     grads_stride,
+     label_ptr,
+     ignore_index: tl.constexpr,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+     HAS_LABEL: tl.constexpr,
+     reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
+ ):
+     pid = tl.program_id(0).to(tl.int64)
+     p_ptr += pid * p_stride
+     q_ptr += pid * q_stride
+     loss_ptr += pid * loss_stride
+     grads_ptr += pid * grads_stride
+     label_ptr += pid
+
+     base_offsets = tl.arange(0, BLOCK_SIZE)
+
+     if HAS_LABEL:
+         label = tl.load(label_ptr)
+         if label == ignore_index:
+             for i in range(0, n_cols, BLOCK_SIZE):
+                 offsets = i + base_offsets
+                 mask = offsets < n_cols
+                 tl.store(grads_ptr + offsets, 0.0, mask=mask)
+                 if reduction == _REDUCTION_MODE_NONE:
+                     tl.store(loss_ptr + offsets, 0.0, mask=mask)
+             return
+
+     loss_sum = 0.0
+     for i in range(0, n_cols, BLOCK_SIZE):
+         offsets = i + base_offsets
+         mask = offsets < n_cols
+
+         p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
+         q = tl.load(q_ptr + offsets, mask=mask, other=0.0)
+
+         # TVD(P || Q) = 0.5 * |P - Q|
+         tv_loss = 0.5 * tl.abs(p - q)
+
+         grad_res = tl.where(p > q, 0.5, -0.5)
+
+         tl.store(grads_ptr + offsets, grad_res, mask=mask)
+
+         if reduction == _REDUCTION_MODE_NONE:
+             tl.store(loss_ptr + offsets, tv_loss, mask=mask)
+         else:
+             loss_sum += tl.sum(tv_loss, axis=0)
+
+     if reduction != _REDUCTION_MODE_NONE:
+         tl.store(loss_ptr, loss_sum)
+
+
+ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
+     BT, V = p.shape
+
+     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+     num_warps = get_num_warps(BLOCK_SIZE)
+
+     grid = (BT,)
+
+     reduction = _str_to_reduction_mode[reduction]
+
+     out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
+     output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
+     grads = torch.empty_like(p)
+
+     n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
+
+     _tv_distance_kernel[grid](
+         p,
+         p.stride(0),
+         q,
+         q.stride(0),
+         output_tensor,
+         output_tensor.stride(0),
+         grads,
+         grads.stride(0),
+         shift_labels if has_label else torch.empty(1, device=p.device),
+         ignore_index,
+         V,
+         BLOCK_SIZE=BLOCK_SIZE,
+         HAS_LABEL=has_label,
+         num_warps=num_warps,
+         reduction=reduction,
+     )
+
+     if reduction == _REDUCTION_MODE_BATCHMEAN.value:
+         return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
+     elif reduction == _REDUCTION_MODE_SUM.value:
+         return output_tensor.sum(dim=0), grads
+     elif reduction == _REDUCTION_MODE_MEAN.value:
+         return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
+     else:
+         return output_tensor, grads
+
+
+ def tvd_backward_triton(grad_output, grads):
+     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
+     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+         return grads
+
+     return grads * grad_output
+
+
+ class LigerTVDLossFunction(torch.autograd.Function):
+     """
+     Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(
+         ctx,
+         p: torch.Tensor,
+         q: torch.Tensor,
+         shift_labels: Optional[torch.Tensor] = None,
+         reduction: REDUCTION_LITERAL = "batchmean",
+         ignore_index: int = -100,
+     ) -> torch.Tensor:
+         """A forward pass for the Total Variation Distance Loss.
+
+         Args:
+             ctx: Torch autograd context
+             p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
+             q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
+             shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
+             reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
+             ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
+
+         Returns:
+             torch.Tensor: The computed Total Variation Distance Loss.
+         """
+         has_label = False
+         if shift_labels is not None:
+             assert shift_labels.shape == (
+                 p.shape[0],
+             ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+             shift_labels = shift_labels.contiguous()
+             has_label = True
+
+         loss, grads = tv_distance_forward_triton(
+             p, q, shift_labels, reduction, ignore_index, has_label
+         )
+         ctx.save_for_backward(grads)
+         return loss
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+         """A backward pass for the Total Variation Distance Loss.
+
+         Args:
+             ctx: Torch autograd context
+             grad_output (torch.Tensor): The gradient of the loss with respect to the output.
+
+         Returns:
+             tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
+         """
+         (grads,) = ctx.saved_tensors
+         grads = tvd_backward_triton(grad_output, grads)
+
+         return grads, None, None, None, None
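The per-row loss computed by this kernel is the elementwise total variation distance, 0.5 * |p - q|, reduced according to `reduction`. Below is a minimal sketch (not part of the package) of how the new autograd function could be exercised and cross-checked against a plain PyTorch reference; it assumes a CUDA device, softmax-normalized inputs, and illustrative shapes and tolerances.

```python
import torch
from liger_kernel.ops.tvd import LigerTVDLossFunction

# Illustrative shapes: BT rows (batch * sequence length), V classes.
BT, V = 8, 128
p = torch.randn(BT, V, device="cuda").softmax(dim=-1)
q = torch.randn(BT, V, device="cuda").softmax(dim=-1)

# Triton-backed TVD with no labels and the default "batchmean" reduction.
loss = LigerTVDLossFunction.apply(p, q, None, "batchmean", -100)

# Plain PyTorch reference: 0.5 * sum(|p - q|) per row, averaged over rows.
ref = (0.5 * (p - q).abs()).sum(dim=-1).mean()

torch.testing.assert_close(loss, ref, rtol=1e-4, atol=1e-5)
```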
@@ -18,6 +18,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
+ from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
@@ -12,7 +12,7 @@ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
  from liger_kernel.ops.rope import LigerRopeFunction
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
-
+ from liger_kernel.ops.tvd import LigerTVDLossFunction
 
  # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
  # `weight` and `size_average` are placeholders and not implemented yet
@@ -156,6 +156,20 @@ def liger_kl_div(
          eps,
      )
 
+ def liger_tvd(
+     input,
+     target,
+     shift_labels=None,
+     reduction: str = "mean",
+     ignore_index: int = -100,
+ ):
+     return LigerTVDLossFunction.apply(
+         input,
+         target,
+         shift_labels,
+         reduction,
+         ignore_index,
+     )
 
  def liger_layer_norm(X, W, B, eps):
      return LigerLayerNormFunction.apply(X, W, B, eps)
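The hunk above adds a functional wrapper alongside the autograd function. A short, hedged usage sketch (assuming a CUDA device and softmax-normalized inputs; shapes are illustrative):

```python
import torch
from liger_kernel.transformers.functional import liger_tvd

p = torch.randn(4, 32, device="cuda").softmax(dim=-1)
q = torch.randn(4, 32, device="cuda").softmax(dim=-1)

# Note: the functional wrapper defaults to reduction="mean", while the
# LigerTVDLoss module added later in this diff defaults to "batchmean".
loss = liger_tvd(p, q)
```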
@@ -0,0 +1,15 @@
+ import torch.nn as nn
+
+ from liger_kernel.ops.tvd import LigerTVDLossFunction
+
+
+ class LigerTVDLoss(nn.Module):
+     def __init__(self, reduction="batchmean", ignore_index: int = -100):
+         super(LigerTVDLoss, self).__init__()
+         self.reduction = reduction
+         self.ignore_index = ignore_index
+
+     def forward(self, p, q, shift_labels=None):
+         return LigerTVDLossFunction.apply(
+             p, q, shift_labels, self.reduction, self.ignore_index
+         )
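For the module-style API above, a minimal usage sketch (assumes a CUDA device; shapes, labels, and values are illustrative only, and the import path relies on the `__init__.py` export added earlier in this diff):

```python
import torch
from liger_kernel.transformers import LigerTVDLoss

tvd = LigerTVDLoss(reduction="batchmean", ignore_index=-100)

student = torch.randn(8, 128, device="cuda", requires_grad=True).softmax(dim=-1)
teacher = torch.randn(8, 128, device="cuda").softmax(dim=-1)

# Rows whose label equals ignore_index contribute zero loss and zero gradient.
labels = torch.randint(0, 128, (8,), device="cuda")
labels[0] = -100

loss = tvd(student, teacher, shift_labels=labels)
loss.backward()
```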
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.5.3.dev20250220230230
+ Version: 0.5.3.dev20250221003838
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -126,7 +126,7 @@ Requires-Dist: mkdocs-material; extra == "dev"
 
  **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
 
- We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
+ We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
 
  ## Supercharge Your Model with Liger Kernel
 
@@ -341,6 +341,7 @@ loss.backward()
  | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
  | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
  | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
+ | Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |
 
  ### Distillation Kernels
 
@@ -349,6 +350,7 @@ loss.backward()
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
  | JSD | `liger_kernel.transformers.LigerJSD` |
  | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
+ | TVD | `liger_kernel.transformers.LigerTVDLoss` |
 
  ### Experimental Kernels
 
@@ -28,13 +28,14 @@ liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML
  liger_kernel/ops/rms_norm.py,sha256=PWLJcdIKU5e-8BuYFHd9Cqlq6wmr6fUXKi9zQD4LetU,11727
  liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
  liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
+ liger_kernel/ops/tvd.py,sha256=9wVCijj2vBtgiLeUHhl7hy_LAiJ3liPIYOGMSU3P1ro,6407
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
- liger_kernel/transformers/__init__.py,sha256=QPmYkL6hosBPpPqCUGqvIvAtD9XzLgvZqZxUyYMZeVk,2008
+ liger_kernel/transformers/__init__.py,sha256=VZI9hiCvvA371jsfkJmSt1CNXlBztIvlVGDExyKeqBM,2077
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
- liger_kernel/transformers/functional.py,sha256=lDOjch622dJIc78K3ePFK_H1DX00GC5kKjodjcbEgbM,4624
+ liger_kernel/transformers/functional.py,sha256=zahXVCjA2NxcVFpAgajILIRN0GO6mrbfLPgONUkTrY8,4940
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=09Rt7FZzLH42VOcIbQ4dlQd0o3Rlb4vk6fqiOQ7WTD8,1778
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
@@ -48,6 +49,7 @@ liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUS
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
  liger_kernel/transformers/swiglu.py,sha256=i9WTqcNRqReU4XJs391IPbl-I5X0wG4T72D4pqGFfJg,2422
  liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
+ liger_kernel/transformers/tvd.py,sha256=lt730XDR4IYEkn-HeWS7WU6AGssg90ubg8pqPAX2lbE,472
  liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  liger_kernel/transformers/model/gemma.py,sha256=ky89b3aWPaeTGRMC-745KgixtQIRXzNAiCORAMLn9yo,9654
@@ -63,9 +65,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/METADATA,sha256=xtathj_pY7bV0Pkw0qNpzJ-cDVUXRy3AsSemRtaTRYY,21766
- liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
- liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/RECORD,,
+ liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/METADATA,sha256=zJp1YMQDbbOzeNRKFf7AN7hYqePdT49OEMQkN_buKl8,21963
+ liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+ liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel_nightly-0.5.3.dev20250221003838.dist-info/RECORD,,