liger-kernel 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. liger_kernel/ops/cross_entropy.py +5 -39
  2. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  3. liger_kernel/ops/fused_linear_cross_entropy.py +13 -10
  4. liger_kernel/ops/fused_linear_jsd.py +245 -0
  5. liger_kernel/ops/geglu.py +2 -2
  6. liger_kernel/ops/jsd.py +176 -0
  7. liger_kernel/ops/kl_div.py +45 -34
  8. liger_kernel/ops/rms_norm.py +67 -42
  9. liger_kernel/ops/swiglu.py +2 -2
  10. liger_kernel/ops/utils.py +62 -1
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/auto_model.py +18 -6
  13. liger_kernel/transformers/functional.py +4 -0
  14. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  15. liger_kernel/transformers/jsd.py +75 -0
  16. liger_kernel/transformers/kl_div.py +3 -2
  17. liger_kernel/transformers/model/gemma.py +124 -1
  18. liger_kernel/transformers/model/llama.py +135 -4
  19. liger_kernel/transformers/model/mistral.py +3 -0
  20. liger_kernel/transformers/model/mixtral.py +153 -2
  21. liger_kernel/transformers/model/mllama.py +274 -0
  22. liger_kernel/transformers/model/phi3.py +140 -2
  23. liger_kernel/transformers/model/qwen2.py +123 -2
  24. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  25. liger_kernel/transformers/monkey_patch.py +254 -129
  26. {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/METADATA +74 -35
  27. liger_kernel-0.4.0.dist-info/NOTICE +58 -0
  28. liger_kernel-0.4.0.dist-info/RECORD +48 -0
  29. {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/WHEEL +1 -1
  30. liger_kernel-0.3.0.dist-info/NOTICE +0 -4
  31. liger_kernel-0.3.0.dist-info/RECORD +0 -42
  32. {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/LICENSE +0 -0
  33. {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py CHANGED
@@ -10,6 +10,7 @@ https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddec
  Modifications made by Yanning Chen, 2024.
  """

+ import math
  import operator

  import torch
@@ -20,6 +21,7 @@ from liger_kernel.ops.utils import (
  calculate_settings,
  compare_version,
  ensure_contiguous,
+ torch_to_triton_dtype,
  )

  if compare_version("triton", operator.ge, "3.0.0"):
@@ -84,6 +86,10 @@ def _rms_norm_forward_kernel(
  W_row = W_row.to(tl.float32)
  X_row = X_row.to(tl.float32)

+ if casting_mode == _CASTING_MODE_NONE:
+ eps = eps.to(X_row_dtype)
+ offset = offset.to(X_row_dtype)
+
  mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
  rstd = rsqrt(mean_square + eps)

@@ -100,6 +106,9 @@ def _rms_norm_forward_kernel(

  Y_row = X_row * (offset + W_row)

+ if casting_mode == _CASTING_MODE_GEMMA:
+ Y_row = Y_row.to(X_row_dtype)
+
  tl.store(Y_ptr + col_offsets, Y_row, mask=mask)


@@ -109,14 +118,17 @@ def _rms_norm_backward_kernel(
  dY_row_stride,
  X_ptr,
  X_row_stride,
+ X_dtype: tl.constexpr,
  W_ptr,
  W_row_stride,
  RSTD_ptr,
  RSTD_row_stride,
  dW_ptr,
  dW_row_stride,
+ n_rows,
  n_cols,
  offset,
+ rows_per_program: tl.constexpr,
  casting_mode: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
  ):
@@ -125,54 +137,60 @@ def _rms_norm_backward_kernel(
  dw = sum(dy * (x / RMS)). summation over BxT dimension
  """

- row_idx = tl.program_id(0)
+ row_block_id = tl.program_id(0)
+ row_start = row_block_id * rows_per_program
+ row_end = min((row_block_id + 1) * rows_per_program, n_rows)
  col_offsets = tl.arange(0, BLOCK_SIZE)
  mask = col_offsets < n_cols

- dY_ptr += row_idx * dY_row_stride
- X_ptr += row_idx * X_row_stride
- RSTD_ptr += row_idx * RSTD_row_stride
- dW_ptr += row_idx * dW_row_stride
+ dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

- dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
- W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
- original_x_dtype = X_row.dtype
-
- # Get cached rms
- rstd_row = tl.load(RSTD_ptr)
+ dY_ptr += row_start * dY_row_stride
+ X_ptr += row_start * X_row_stride
+ RSTD_ptr += row_start

+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
  W_row = W_row + offset

- X_row = X_row.to(tl.float32)
+ for _ in range(row_start, row_end):
+ dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
+ X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)

- # Different bacward graphs for different casting modes
- if casting_mode == _CASTING_MODE_LLAMA:
- m = (dY_row * W_row).to(tl.float32)
+ # Get cached rms
+ rstd_row = tl.load(RSTD_ptr)

- elif casting_mode == _CASTING_MODE_GEMMA:
- dY_row, W_row = (
- dY_row.to(tl.float32),
- W_row.to(tl.float32),
- )
+ X_row = X_row.to(tl.float32)

- m = dY_row * W_row
+ # Different bacward graphs for different casting modes
+ if casting_mode == _CASTING_MODE_LLAMA:
+ m = (dY_row * W_row).to(tl.float32)

- dX_row = rstd_row * m
+ elif casting_mode == _CASTING_MODE_GEMMA:
+ dY_row = dY_row.to(tl.float32)
+ m = dY_row * W_row
+ else:
+ m = dY_row * W_row

- dX_row += (rstd_row) * (
- -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
- )
+ dX_row = rstd_row * m

- # calculate the gradient of W
- if casting_mode == _CASTING_MODE_LLAMA:
- dW_row = dY_row * (X_row * rstd_row).to(original_x_dtype)
- else:
- # here X_row is already in fp32 (see previous if block)
- dW_row = dY_row * (X_row * rstd_row)
+ dX_row += (rstd_row) * (
+ -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
+ )
+
+ # calculate the gradient of W
+ if casting_mode == _CASTING_MODE_LLAMA:
+ dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+ else:
+ # here X_row is already in fp32 (see previous if block)
+ dW_row += dY_row * (X_row * rstd_row)

- tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
- tl.store(dW_ptr + col_offsets, dW_row, mask=mask)
+ tl.store(dY_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
+
+ dY_ptr += dY_row_stride
+ X_ptr += X_row_stride
+ RSTD_ptr += RSTD_row_stride
+
+ tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)


  _str_to_casting_mode = {
@@ -238,31 +256,38 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
  dim = shape[-1]
  dY = dY.view(-1, dim)
  n_rows, n_cols = dY.shape
- dW = torch.empty_like(
- X,
- dtype=(torch.float32 if casting_mode == _CASTING_MODE_GEMMA.value else W.dtype),
- )

+ sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+ # fp32 for numerical stability especially.
+ _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+
+ if n_cols > BLOCK_SIZE:
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+ rows_per_program = math.ceil(n_rows / sm_count)
+ grid = (sm_count,)
  # Here we use dY to store the value of dX to save memory
- _rms_norm_backward_kernel[(n_rows,)](
+ _rms_norm_backward_kernel[grid](
  dY,
  dY.stride(0),
  X,
  X.stride(0),
+ torch_to_triton_dtype[X.dtype],
  W,
  W.stride(0),
  RSTD,
  RSTD.stride(0),
- dW,
- dW.stride(0),
+ _dW,
+ _dW.stride(0),
+ n_rows,
  n_cols,
  offset,
+ rows_per_program,
  casting_mode,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=num_warps,
  )
  dX = dY.view(*shape)
- dW = torch.sum(dW, dim=0).to(W.dtype)
+ dW = _dW.sum(dim=0).to(W.dtype)
  return dX, dW

liger_kernel/ops/swiglu.py CHANGED
@@ -14,7 +14,7 @@ def silu(x):
  def _swiglu_forward_kernel(
  a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
- program_id = tl.program_id(0)
+ program_id = tl.program_id(0).to(tl.int64)

  # locate start index
  a_ptr += program_id * stride
@@ -35,7 +35,7 @@ def _swiglu_forward_kernel(
  def _swiglu_backward_kernel(
  dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
- program_id = tl.program_id(0)
+ program_id = tl.program_id(0).to(tl.int64)

  # locate start index
  dc_ptr += program_id * stride
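Both swiglu kernels now cast `tl.program_id(0)` to `tl.int64` before computing the row offset. A quick back-of-the-envelope check of why 32-bit indexing can overflow here (the numbers below are illustrative, not taken from the diff):

```python
# The pointer offset program_id * stride is computed in the index dtype.
# For long flattened (B*T) batches with a wide hidden dimension it can
# exceed the int32 range, so 64-bit indexing is needed.
INT32_MAX = 2**31 - 1
stride = 131_072          # hypothetical row stride (e.g. a wide MLP intermediate)
program_id = 20_000       # a late row in a long flattened batch
offset = program_id * stride
print(offset, offset > INT32_MAX)   # 2621440000 True
```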
liger_kernel/ops/utils.py CHANGED
@@ -12,13 +12,19 @@ Modifications made by Yanning Chen, 2024.

  import functools
  import importlib
+ import operator
  from typing import Callable

  import torch
  import triton
+ import triton.language as tl
  from packaging.version import Version


+ def is_hip() -> bool:
+ return torch.version.hip is not None
+
+
  def ensure_contiguous(fn):
  @functools.wraps(fn)
  def wrapper(ctx, *args, **kwargs):
@@ -45,7 +51,7 @@ def calculate_settings(n):

  num_warps = 4
  if BLOCK_SIZE >= 32768:
- num_warps = 32
+ num_warps = 32 if not is_hip() else 16
  elif BLOCK_SIZE >= 8192:
  num_warps = 16
  elif BLOCK_SIZE >= 2048:
@@ -60,3 +66,58 @@ def compare_version(package: str, operator: Callable, target: str):
  return False
  pkg_version = Version(pkg.__version__)
  return operator(pkg_version, Version(target))
+
+
+ def get_amp_custom_fwd_bwd() -> Callable:
+ if compare_version("torch", operator.ge, "2.4.0"):
+ return (
+ functools.partial(torch.amp.custom_fwd, device_type="cuda"),
+ functools.partial(torch.amp.custom_bwd, device_type="cuda"),
+ )
+ return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
+
+
+ amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
+
+
+ torch_to_triton_dtype = {
+ torch.float32: tl.float32,
+ torch.float16: tl.float16,
+ torch.bfloat16: tl.bfloat16,
+ }
+
+
+ @triton.jit
+ def element_mul_kernel(
+ X_ptr,
+ X_stride,
+ grad_output_ptr,
+ n_cols,
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ """
+ This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
+ The multiplication is performed in-place on the tensor pointed by X_ptr.
+
+ Parameters:
+ X_ptr: Pointer to the input tensor.
+ X_stride (int): The stride of the input tensor.
+ grad_output_ptr: Pointer to the gradient output value.
+ n_cols (int): The number of columns in the input tensor.
+ BLOCK_SIZE (int): The block size for Triton operations.
+ """
+
+ # Get the program ID and convert it to int64 to avoid overflow
+ program_id = tl.program_id(0).to(tl.int64)
+
+ # Locate the start index
+ X_ptr += program_id * X_stride
+
+ # Load the gradient output value
+ grad_output = tl.load(grad_output_ptr)
+
+ # Perform the element-wise multiplication
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
+ tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
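`get_amp_custom_fwd_bwd()` papers over the rename of `torch.cuda.amp.custom_fwd/custom_bwd` to `torch.amp.custom_fwd/custom_bwd(device_type="cuda")` in PyTorch 2.4. A hedged sketch of how the returned decorators would typically be applied to an `autograd.Function`; the `ScaleByTwo` class below is hypothetical and not part of the package:

```python
import torch

from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd


class ScaleByTwo(torch.autograd.Function):
    # The decorators keep autocast behavior consistent across torch versions.
    @staticmethod
    @amp_custom_fwd
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_output):
        return grad_output * 2.0


# Requires a CUDA device, since the decorators are bound to device_type="cuda".
x = torch.randn(4, device="cuda", requires_grad=True)
ScaleByTwo.apply(x).sum().backward()
```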
liger_kernel/transformers/__init__.py CHANGED
@@ -5,7 +5,9 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noq
  from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: F401
  LigerFusedLinearCrossEntropyLoss,
  )
+ from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
  from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
+ from liger_kernel.transformers.jsd import LigerJSD # noqa: F401
  from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
  from liger_kernel.transformers.monkey_patch import ( # noqa: F401
  _apply_liger_kernel,
@@ -15,6 +17,7 @@ from liger_kernel.transformers.monkey_patch import ( # noqa: F401
  apply_liger_kernel_to_llama,
  apply_liger_kernel_to_mistral,
  apply_liger_kernel_to_mixtral,
+ apply_liger_kernel_to_mllama,
  apply_liger_kernel_to_phi3,
  apply_liger_kernel_to_qwen2,
  apply_liger_kernel_to_qwen2_vl,
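With these re-exports, the new losses and the mllama patch are importable from the top-level `liger_kernel.transformers` namespace. A small sketch, assuming `apply_liger_kernel_to_mllama` follows the same call-with-defaults convention as the other `apply_liger_kernel_to_*` helpers (its exact keyword arguments are not shown in this diff):

```python
from liger_kernel.transformers import (
    LigerFusedLinearJSD,
    LigerJSD,
    apply_liger_kernel_to_mllama,
)

# Patch HuggingFace's mllama implementation in place (assumed default kwargs).
apply_liger_kernel_to_mllama()

jsd = LigerJSD(beta=0.5)
fused_jsd = LigerFusedLinearJSD(jsd_beta=0.5, temperature=1.0)
```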
liger_kernel/transformers/auto_model.py CHANGED
@@ -1,6 +1,11 @@
+ import inspect
+
  from transformers import AutoConfig, AutoModelForCausalLM

- from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
+ from liger_kernel.transformers.monkey_patch import (
+ MODEL_TYPE_TO_APPLY_LIGER_FN,
+ _apply_liger_kernel,
+ )


  def _get_model_config(model_dir, **model_init_kwargs):
@@ -21,13 +26,20 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
  # Determine the model type and apply the Liger Kernel if applicable
  # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
  model_type = model_config.model_type
+
  _apply_liger_kernel(model_type, **kwargs)

- # Retain only the keyword args present in the model configuration
- for k in list(kwargs.keys()):
- if k not in model_config.__dict__:
- del kwargs[k]
+ # Filter out kwargs that were passed to the apply_liger_* function, which will cause
+ # model initialization errors otherwise
+ apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+ apply_fn_signature = inspect.signature(apply_fn)
+
+ applicable_kwargs = {
+ key: value
+ for key, value in kwargs.items()
+ if key not in apply_fn_signature.parameters
+ }

  return super().from_pretrained(
- pretrained_model_name_or_path, *model_args, **kwargs
+ pretrained_model_name_or_path, *model_args, **applicable_kwargs
  )
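The kwargs filtering now keys off the apply function's signature rather than the model config. A self-contained illustration of the same `inspect.signature` filter; the `apply_liger_kernel_to_demo` function below is hypothetical:

```python
import inspect


# Stand-in for an apply_liger_kernel_to_* function: any kwarg naming one of its
# parameters is consumed by the patch and must not reach from_pretrained.
def apply_liger_kernel_to_demo(rope=True, rms_norm=True, swiglu=True):
    pass


kwargs = {"rms_norm": False, "torch_dtype": "bfloat16", "device_map": "auto"}
sig = inspect.signature(apply_liger_kernel_to_demo)
applicable_kwargs = {k: v for k, v in kwargs.items() if k not in sig.parameters}
print(applicable_kwargs)  # {'torch_dtype': 'bfloat16', 'device_map': 'auto'}
```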
liger_kernel/transformers/functional.py CHANGED
@@ -2,7 +2,9 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
  from liger_kernel.ops.fused_linear_cross_entropy import (
  LigerFusedLinearCrossEntropyFunction,
  )
+ from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
  from liger_kernel.ops.geglu import LigerGELUMulFunction
+ from liger_kernel.ops.jsd import LigerJSDFunction
  from liger_kernel.ops.kl_div import LigerKLDivLossFunction
  from liger_kernel.ops.layer_norm import LigerLayerNormFunction
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
@@ -17,3 +19,5 @@ liger_rms_norm = LigerRMSNormFunction.apply
  liger_rope = LigerRopeFunction.apply
  liger_layer_norm = LigerLayerNormFunction.apply
  liger_kl_div = LigerKLDivLossFunction.apply
+ liger_jsd = LigerJSDFunction.apply
+ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
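The two new functional aliases mirror the module APIs. The call below follows the argument order used by `LigerJSD.forward` (`log_q, log_p, shift_labels, beta, ignore_index`); whether trailing arguments have defaults is not visible in this diff, so everything is passed explicitly, and a CUDA device is assumed for the Triton kernel:

```python
import torch

from liger_kernel.transformers.functional import liger_jsd

# Student predictions and teacher targets, both in log space (see LigerJSD docs).
B_T, V = 4, 10
log_q = torch.randn(B_T, V, device="cuda", requires_grad=True).log_softmax(dim=-1)
log_p = torch.randn(B_T, V, device="cuda").log_softmax(dim=-1)

# shift_labels=None, beta=0.5, ignore_index=-100, matching LigerJSD's defaults.
loss = liger_jsd(log_q, log_p, None, 0.5, -100)
loss.backward()
```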
liger_kernel/transformers/fused_linear_jsd.py ADDED
@@ -0,0 +1,98 @@
+ from typing import Optional
+
+ import torch
+
+ from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
+
+
+ class LigerFusedLinearJSD(torch.nn.Module):
+ r"""Fusing the last linear layer with generalized JSD
+
+ Handle the forward and backward pass of the final linear layer via JSD by avoiding
+ the materialization of the large logits tensor.
+
+ Args:
+ jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+ ignore_index (int): The index to ignore in the target. Default: `-100`
+ temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
+
+ Shape:
+ - student_input: :math:`(BT, H)`, where B is batch size, T is sequence length, H is hidden dimension.
+ - student_weight: :math:`(V, H)`, where V is vocab size.
+ - teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model.
+ - teacher_weight: :math:`(V, H')`, where hidden size H and H' can be different.
+ - shift_labels: :math:`(BT,)`
+ - Output: a scalar.
+
+ Examples:
+ ```python
+ >>> (B, T, H_s, H_t, V) = (2, 2, 3, 5, 10)
+ >>> fused_jsd = LigerFusedLinearJSD(jsd_beta=0.1, temperature=2.0)
+ >>> # generate inputs and weights
+ >>> student_input = torch.rand(B * T, H_s, device="cuda", requires_grad=True)
+ >>> student_lin = torch.nn.Linear(H_s, V, bias=False, device="cuda")
+ >>> # teacher input doesn't require grad, hidden_dim can be different from student's
+ >>> teacher_input = torch.rand(B * T, H_t, device="cuda")
+ >>> teacher_lin = torch.nn.Linear(H_t, V, bias=False, device="cuda")
+ >>> output = fused_jsd(student_input, student_lin.weight, teacher_input, teacher_lin.weight)
+ >>> output.backward()
+ >>>
+ >>> # Example with labels for supervised fine-tuning (SFT) context:
+ >>>
+ >>> # Assume hidden_states, lm_heads and corresponding labels are given
+ >>> student_lm_head = torch.nn.Linear(H_s, V, bias=False)
+ >>> student_hidden_states = torch.randn(B * T, H_s, requires_grad=True).log_softmax(dim=-1)
+ >>> teacher_lm_head = torch.nn.Linear(H_t, V, bias=False)
+ >>> teacher_hidden_states = torch.randn(B * T, H_t).log_softmax(dim=-1)
+ >>> labels = torch.randint(0, V, (B * T,), torch.long)
+ >>>
+ >>> # Shift so that tokens < n predict n
+ >>> shift_student_hidden_states = student_hidden_states[..., :-1, :].contiguous()
+ >>> shift_teacher_hidden_states = teacher_hidden_states[..., :-1, :].contiguous()
+ >>> shift_labels = labels[..., 1:].contiguous()
+ >>>
+ >>> # Flatten tokens
+ >>> shift_student_hidden_states = shift_student_hidden_states.view(-1, V)
+ >>> shift_teacher_hidden_states = shift_teacher_hidden_states.view(-1, V)
+ >>> shift_labels = shift_labels.view(-1)
+ >>>
+ >>> # Calculate loss
+ >>> loss_fct = LigerJSD(beta=0.1)
+ >>> loss = loss_fct(
+ >>> shift_studetn_hidden_states,
+ >>> student_lm_head.weight,
+ >>> shift_teacher_hidden_states,
+ >>> teacher_lm_head.weight,
+ >>> shift_labels
+ >>> )
+ ```
+ """
+
+ def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0):
+ super().__init__()
+ assert (
+ jsd_beta > 0 and jsd_beta < 1
+ ), f"beta must be greater than 0 and less than 1. Got: {jsd_beta}"
+ assert temperature != 0, "temperature cannot be 0."
+ self.jsd_beta = jsd_beta
+ self.temperature = temperature
+ self.ignore_index = ignore_index
+
+ def forward(
+ self,
+ student_input: torch.Tensor,
+ student_weight: torch.Tensor,
+ teacher_input: torch.Tensor,
+ teacher_weight: torch.Tensor,
+ shift_labels: Optional[torch.LongTensor],
+ ):
+ return LigerFusedLinearJSDFunction.apply(
+ student_input,
+ student_weight,
+ teacher_input,
+ teacher_weight,
+ shift_labels,
+ self.jsd_beta,
+ self.ignore_index,
+ self.temperature,
+ )
liger_kernel/transformers/jsd.py ADDED
@@ -0,0 +1,75 @@
+ from typing import Optional
+
+ import torch
+
+ from liger_kernel.ops.jsd import LigerJSDFunction
+
+
+ class LigerJSD(torch.nn.Module):
+ r"""The generalized Jensen-Shannon Divergence.
+ .. math::
+ JSD(\beta)(P || Q)
+ = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
+ .. note::
+ As all the other losses in PyTorch, this function expects the first argument,
+ :attr:`log_q`, to be the predictions, the output of the student model in log-space,
+ and the second, :attr:`log_p`, to be the observations, the output of the teacher model in log-space.
+ This differs from the standard mathematical notation :math:`JSD(P || Q)` where
+ :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
+
+ Args:
+ beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+ ignore_index (int): The index to ignore in the target. Default: `-100`
+
+ Shape:
+ - Input: :math:`(BT, V)`, where B is batch size, T is sequence length, V is vocab size.
+ - Target: :math:`(BT, V)`, same shape as the input.
+ - shift_labels (Optional): :math:`(BT,)`
+ - Output: a scalar.
+
+ Examples:
+ ```python
+ >>> (B, T, V) = (2, 2, 5)
+ >>> jsd = LigerJSD(beta=0.1)
+ >>> # input should be a distribution in the log space
+ >>> input = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
+ >>> target = torch.randn(B * T, V).log_softmax(dim=-1)
+ >>> output = jsd(input, target)
+ >>>
+ >>> # Example with labels for supervised fine-tuning (SFT) context
+ >>> # Assume logits and corresponding labels are given
+ >>> student_logits = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
+ >>> teacher_logits = torch.randn(B * T, V).log_softmax(dim=-1)
+ >>> labels = torch.randint(0, V, (B * T,), torch.long)
+ >>> # Shift so that tokens < n predict n
+ >>> shift_student_logits = student_logits[..., :-1, :].contiguous()
+ >>> shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()
+ >>> shift_labels = labels[..., 1:].contiguous()
+ >>> # Flatten tokens
+ >>> shift_student_logits = shift_student_logits.view(-1, V)
+ >>> shift_teacher_logits = shift_teacher_logits.view(-1, V)
+ >>> shift_labels = shift_labels.view(-1)
+ >>> # Calculate loss
+ >>> loss_fct = LigerJSD(beta=0.1)
+ >>> loss = loss_fct(shift_studetn_logits, shift_teacher_logits, shift_labels)
+
+ ```
+ """
+
+ def __init__(self, beta: float = 0.5, ignore_index: int = -100):
+ super().__init__()
+ assert (
+ beta > 0 and beta < 1
+ ), f"beta must be greater than 0 and less than 1. Got: {beta}"
+ self.beta = beta
+ self.ignore_index = ignore_index
+
+ def forward(
+ self,
+ log_q: torch.Tensor,
+ log_p: torch.Tensor,
+ shift_labels: Optional[torch.LongTensor] = None,
+ ):
+ return LigerJSDFunction.apply(
+ log_q, log_p, shift_labels, self.beta, self.ignore_index
+ )
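For reference, a minimal end-to-end sketch of the SFT-style usage described in the docstring above, with the shift-and-flatten steps spelled out. Shapes are illustrative and a CUDA device is assumed for the Triton kernel:

```python
import torch

from liger_kernel.transformers import LigerJSD

B, T, V = 2, 4, 16
student_logits = torch.randn(B, T, V, device="cuda", requires_grad=True)
teacher_logits = torch.randn(B, T, V, device="cuda")
labels = torch.randint(0, V, (B, T), dtype=torch.long, device="cuda")

# Shift so that tokens < n predict n, then flatten to (B*(T-1), V) and (B*(T-1),).
shift_student = student_logits[..., :-1, :].contiguous().view(-1, V).log_softmax(dim=-1)
shift_teacher = teacher_logits[..., :-1, :].contiguous().view(-1, V).log_softmax(dim=-1)
shift_labels = labels[..., 1:].contiguous().view(-1)

# Student log-probs first (predictions), teacher log-probs second (observations).
loss_fct = LigerJSD(beta=0.1)
loss = loss_fct(shift_student, shift_teacher, shift_labels)
loss.backward()
```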
liger_kernel/transformers/kl_div.py CHANGED
@@ -4,10 +4,11 @@ from liger_kernel.ops.kl_div import LigerKLDivLossFunction


  class LigerKLDIVLoss(nn.KLDivLoss):
- def __init__(self, *args, **kwargs):
+ def __init__(self, eps: float = 1e-10, *args, **kwargs):
  super(LigerKLDIVLoss, self).__init__(*args, **kwargs)
+ self.eps = eps

  def forward(self, y_pred, y_true):
  return LigerKLDivLossFunction.apply(
- y_pred, y_true, self.reduction, self.log_target
+ y_pred, y_true, self.reduction, self.log_target, self.eps
  )
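A small usage sketch of the new `eps` argument, shown at its default value from the diff. Shapes are illustrative and a CUDA device is assumed; with `log_target=False` (the `nn.KLDivLoss` default) the target is passed as probabilities:

```python
import torch

from liger_kernel.transformers.kl_div import LigerKLDIVLoss

# eps is forwarded to the kernel for numerical stability; other kwargs go to nn.KLDivLoss.
loss_fn = LigerKLDIVLoss(eps=1e-10, reduction="batchmean")

y_pred = torch.randn(8, 32, device="cuda", requires_grad=True).log_softmax(dim=-1)
y_true = torch.randn(8, 32, device="cuda").softmax(dim=-1)

loss = loss_fn(y_pred, y_true)
loss.backward()
```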