liger-kernel 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
Files changed (63)
  1. liger_kernel/chunked_loss/README.md +25 -0
  2. liger_kernel/chunked_loss/__init__.py +2 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +18 -8
  4. liger_kernel/chunked_loss/dpo_loss.py +20 -10
  5. liger_kernel/chunked_loss/functional.py +4 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
  7. liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
  8. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  9. liger_kernel/chunked_loss/jsd_loss.py +154 -0
  10. liger_kernel/chunked_loss/kto_loss.py +172 -0
  11. liger_kernel/chunked_loss/orpo_loss.py +8 -9
  12. liger_kernel/chunked_loss/simpo_loss.py +22 -8
  13. liger_kernel/env_report.py +5 -12
  14. liger_kernel/ops/cross_entropy.py +102 -51
  15. liger_kernel/ops/experimental/embedding.py +1 -3
  16. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  17. liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
  18. liger_kernel/ops/fused_linear_jsd.py +11 -29
  19. liger_kernel/ops/geglu.py +6 -17
  20. liger_kernel/ops/group_norm.py +11 -28
  21. liger_kernel/ops/jsd.py +2 -6
  22. liger_kernel/ops/kl_div.py +8 -11
  23. liger_kernel/ops/layer_norm.py +3 -5
  24. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  25. liger_kernel/ops/rms_norm.py +14 -32
  26. liger_kernel/ops/rope.py +31 -33
  27. liger_kernel/ops/swiglu.py +4 -8
  28. liger_kernel/ops/utils.py +2 -0
  29. liger_kernel/transformers/__init__.py +16 -24
  30. liger_kernel/transformers/auto_model.py +6 -13
  31. liger_kernel/transformers/cross_entropy.py +4 -6
  32. liger_kernel/transformers/experimental/embedding.py +1 -3
  33. liger_kernel/transformers/functional.py +11 -7
  34. liger_kernel/transformers/fused_linear_cross_entropy.py +12 -7
  35. liger_kernel/transformers/geglu.py +1 -4
  36. liger_kernel/transformers/group_norm.py +3 -9
  37. liger_kernel/transformers/jsd.py +1 -3
  38. liger_kernel/transformers/kl_div.py +1 -3
  39. liger_kernel/transformers/layer_norm.py +3 -9
  40. liger_kernel/transformers/model/gemma.py +18 -40
  41. liger_kernel/transformers/model/gemma2.py +19 -41
  42. liger_kernel/transformers/model/llama.py +22 -48
  43. liger_kernel/transformers/model/mistral.py +14 -26
  44. liger_kernel/transformers/model/mixtral.py +24 -54
  45. liger_kernel/transformers/model/mllama.py +16 -36
  46. liger_kernel/transformers/model/phi3.py +18 -40
  47. liger_kernel/transformers/model/qwen2.py +18 -40
  48. liger_kernel/transformers/model/qwen2_vl.py +36 -32
  49. liger_kernel/transformers/monkey_patch.py +43 -117
  50. liger_kernel/transformers/rms_norm.py +4 -4
  51. liger_kernel/transformers/rope.py +2 -2
  52. liger_kernel/transformers/swiglu.py +2 -8
  53. liger_kernel/transformers/trainer/__init__.py +1 -3
  54. liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
  55. liger_kernel/triton/__init__.py +1 -3
  56. liger_kernel/triton/monkey_patch.py +1 -3
  57. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/METADATA +38 -25
  58. liger_kernel-0.5.3.dist-info/RECORD +69 -0
  59. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/WHEEL +1 -1
  60. liger_kernel-0.5.2.dist-info/RECORD +0 -65
  61. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/LICENSE +0 -0
  62. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/NOTICE +0 -0
  63. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/__init__.py

@@ -1,31 +1,23 @@
- from liger_kernel.transformers.auto_model import (  # noqa: F401
-     AutoLigerKernelForCausalLM,
- )
+ from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
- from liger_kernel.transformers.fused_linear_cross_entropy import (  # noqa: F401
-     LigerFusedLinearCrossEntropyLoss,
- )
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
  from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
  from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
  from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
  from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
- from liger_kernel.transformers.monkey_patch import (  # noqa: F401
-     _apply_liger_kernel,
-     _apply_liger_kernel_to_instance,
-     apply_liger_kernel_to_gemma,
-     apply_liger_kernel_to_gemma2,
-     apply_liger_kernel_to_llama,
-     apply_liger_kernel_to_mistral,
-     apply_liger_kernel_to_mixtral,
-     apply_liger_kernel_to_mllama,
-     apply_liger_kernel_to_phi3,
-     apply_liger_kernel_to_qwen2,
-     apply_liger_kernel_to_qwen2_vl,
- )
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
  from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
  from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
- from liger_kernel.transformers.swiglu import (  # noqa: F401
-     LigerBlockSparseTop2MLP,
-     LigerPhi3SwiGLUMLP,
-     LigerSwiGLUMLP,
- )
+ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
+ from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
+ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
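The public surface itself is unchanged here; the imports are only reflowed to one symbol per line. For orientation, a minimal usage sketch of the re-exported patch API (the no-argument call and the model id are illustrative assumptions, not taken from this diff):

# Hypothetical sketch: patch Hugging Face Llama modules with Liger kernels,
# then load a model as usual. Assumes transformers is installed and a GPU is available.
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()  # monkey-patches transformers' Llama modeling code in place
model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder model id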
liger_kernel/transformers/auto_model.py

@@ -1,11 +1,10 @@
  import inspect

- from transformers import AutoConfig, AutoModelForCausalLM
+ from transformers import AutoConfig
+ from transformers import AutoModelForCausalLM

- from liger_kernel.transformers.monkey_patch import (
-     MODEL_TYPE_TO_APPLY_LIGER_FN,
-     _apply_liger_kernel,
- )
+ from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel


  def _get_model_config(model_dir, **model_init_kwargs):
@@ -34,12 +33,6 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
          apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
          apply_fn_signature = inspect.signature(apply_fn)

-         applicable_kwargs = {
-             key: value
-             for key, value in kwargs.items()
-             if key not in apply_fn_signature.parameters
-         }
+         applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}

-         return super().from_pretrained(
-             pretrained_model_name_or_path, *model_args, **applicable_kwargs
-         )
+         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
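As a usage note, a minimal sketch of the class this hunk touches (the checkpoint path is a placeholder, and the description of the patching behavior is inferred from the class's documented purpose rather than from this hunk alone):

# Hypothetical sketch: AutoLigerKernelForCausalLM detects the model type, applies the
# matching apply_liger_kernel_to_* patch, and then defers to the regular
# AutoModelForCausalLM.from_pretrained with the remaining keyword arguments.
from liger_kernel.transformers import AutoLigerKernelForCausalLM

model = AutoLigerKernelForCausalLM.from_pretrained("path/to/llama-checkpoint")  # placeholder path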
liger_kernel/transformers/cross_entropy.py

@@ -8,6 +8,7 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
  class LigerCrossEntropyLoss(torch.nn.Module):
      def __init__(
          self,
+         weight: Optional[torch.FloatTensor] = None,
          ignore_index: int = -100,
          lse_square_scale: float = 0.0,
          label_smoothing: float = 0.0,
@@ -19,17 +20,13 @@ class LigerCrossEntropyLoss(torch.nn.Module):
          assert (label_smoothing >= 0) and (
              label_smoothing <= 1
          ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
-         assert (label_smoothing >= 0) and (
-             label_smoothing <= 1
-         ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
          assert reduction in {
              "mean",
              "sum",
              "none",
          }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
-         assert (
-             softcap is None or softcap > 0
-         ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+         self.weight = weight
          self.ignore_index = ignore_index
          self.lse_square_scale = lse_square_scale
          self.label_smoothing = label_smoothing
@@ -41,6 +38,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
          loss, z_loss = LigerCrossEntropyFunction.apply(
              _input,
              target,
+             self.weight,
              self.ignore_index,
              self.lse_square_scale,
              self.label_smoothing,
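Functionally, the new `weight` argument carries per-class rescaling weights through to the kernel, analogous to `torch.nn.CrossEntropyLoss(weight=...)`. A minimal sketch of the updated constructor (shapes and the CUDA device are illustrative assumptions; the Triton kernels need a GPU):

# Hypothetical sketch: class-weighted cross entropy via the new `weight` kwarg.
# logits: (N, V), targets: (N,), class_weights: one entry per class (V,).
import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

V = 8  # illustrative vocab size
logits = torch.randn(4, V, device="cuda", requires_grad=True)
targets = torch.randint(0, V, (4,), device="cuda")
class_weights = torch.rand(V, device="cuda")

loss_fn = LigerCrossEntropyLoss(weight=class_weights)
loss = loss_fn(logits, targets)
loss.backward()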
liger_kernel/transformers/experimental/embedding.py

@@ -7,9 +7,7 @@ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction


  class LigerEmbedding(nn.Module):
-     def __init__(
-         self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
-     ):
+     def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None):
          super().__init__()
          self.num_embeddings = num_embeddings
          self.embedding_dim = embedding_dim
liger_kernel/transformers/functional.py

@@ -1,9 +1,7 @@
  from typing import Optional

  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
- from liger_kernel.ops.fused_linear_cross_entropy import (
-     LigerFusedLinearCrossEntropyFunction,
- )
+ from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
  from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
  from liger_kernel.ops.geglu import LigerGELUMulFunction
  from liger_kernel.ops.group_norm import LigerGroupNormFunction
@@ -34,6 +32,7 @@ def liger_cross_entropy(
      loss, z_loss = LigerCrossEntropyFunction.apply(
          input,
          target,
+         weight,
          ignore_index,
          lse_square_scale,
          label_smoothing,
@@ -51,23 +50,30 @@ def liger_fused_linear_cross_entropy(
      weight,
      target,
      bias=None,
+     ce_weight=None,
      ignore_index: int = -100,
      lse_square_scale: float = 0.0,
      label_smoothing: float = 0.0,
      reduction: str = "mean",
      softcap: Optional[float] = None,
+     return_z_loss: bool = False,
  ):
-     return LigerFusedLinearCrossEntropyFunction.apply(
+     loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
          input,
          weight,
          target,
          bias,
+         ce_weight,
          ignore_index,
          lse_square_scale,
          label_smoothing,
          reduction,
          softcap,
+         return_z_loss,
      )
+     if not return_z_loss:
+         return loss
+     return loss, z_loss


  def liger_fused_linear_jsd(
@@ -159,9 +165,7 @@ def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
      return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)


- def liger_rms_norm(
-     X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
- ):
+ def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
      return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)

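The net effect for callers: `liger_fused_linear_cross_entropy` now also accepts per-class `ce_weight` and can return the auxiliary z-loss. A minimal sketch of the updated functional API (tensor sizes and the CUDA device are illustrative assumptions; the kernels need a GPU):

# Hypothetical sketch: with return_z_loss=True the function returns (loss, z_loss);
# with the default return_z_loss=False it returns the loss alone, as before.
import torch
from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy

B, H, V = 4, 16, 32  # illustrative batch, hidden and vocab sizes
hidden = torch.randn(B, H, device="cuda", requires_grad=True)
lm_head_weight = torch.randn(V, H, device="cuda", requires_grad=True)
targets = torch.randint(0, V, (B,), device="cuda")

loss, z_loss = liger_fused_linear_cross_entropy(
    hidden,
    lm_head_weight,
    targets,
    lse_square_scale=1e-4,  # weight of the z-loss regularizer
    return_z_loss=True,
)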
 
liger_kernel/transformers/fused_linear_cross_entropy.py

@@ -2,19 +2,19 @@ from typing import Optional

  import torch

- from liger_kernel.ops.fused_linear_cross_entropy import (
-     LigerFusedLinearCrossEntropyFunction,
- )
+ from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction


  class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
      def __init__(
          self,
+         ce_weight: Optional[torch.FloatTensor] = None,
          ignore_index: int = -100,
          lse_square_scale: float = 0.0,
          label_smoothing: float = 0.0,
          reduction: str = "mean",
          softcap: Optional[float] = None,
+         return_z_loss: bool = False,
      ):
          super().__init__()
          assert (label_smoothing >= 0) and (
@@ -25,24 +25,29 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
              "sum",
              "none",
          }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
-         assert (
-             softcap is None or softcap > 0
-         ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+         self.ce_weight = ce_weight
          self.ignore_index = ignore_index
          self.lse_square_scale = lse_square_scale
          self.label_smoothing = label_smoothing
          self.reduction = reduction
          self.softcap = softcap
+         self.return_z_loss = return_z_loss

      def forward(self, lin_weight, _input, target, bias=None):
-         return LigerFusedLinearCrossEntropyFunction.apply(
+         loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
              _input,
              lin_weight,
              target,
              bias,
+             self.ce_weight,
              self.ignore_index,
              self.lse_square_scale,
              self.label_smoothing,
              self.reduction,
              self.softcap,
+             self.return_z_loss,
          )
+         if not self.return_z_loss:
+             return loss
+         return loss, z_loss
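The module-level API mirrors the functional change above. A minimal sketch (sizes and the CUDA device are illustrative assumptions):

# Hypothetical sketch: per-class weighting via ce_weight plus the optional z-loss output.
import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

B, H, V = 4, 16, 32  # illustrative sizes
hidden = torch.randn(B, H, device="cuda", requires_grad=True)
lm_head_weight = torch.randn(V, H, device="cuda", requires_grad=True)
targets = torch.randint(0, V, (B,), device="cuda")
class_weights = torch.rand(V, device="cuda")

lce = LigerFusedLinearCrossEntropyLoss(ce_weight=class_weights, return_z_loss=True)
loss, z_loss = lce(lm_head_weight, hidden, targets)  # note: the lm_head weight comes first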
liger_kernel/transformers/geglu.py

@@ -19,7 +19,4 @@ class LigerGEGLUMLP(nn.Module):
          # So we can safely assume we use tanh approximation form all the time

      def forward(self, x):
-
-         return self.down_proj(
-             LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-         )
+         return self.down_proj(LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
liger_kernel/transformers/group_norm.py

@@ -27,19 +27,13 @@ class LigerGroupNorm(nn.Module):
          self.num_channels = num_channels
          self.num_groups = num_groups
          self.eps = eps
-         self.weight = nn.Parameter(
-             torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels)
-         )
-         self.bias = nn.Parameter(
-             torch.randn(num_channels) if bias else torch.zeros(num_channels)
-         )
+         self.weight = nn.Parameter(torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels))
+         self.bias = nn.Parameter(torch.randn(num_channels) if bias else torch.zeros(num_channels))
          self.variance_epsilon = eps

      def forward(self, hidden_states):
          # hidden_states: (batch_size, num_channels, *)
-         assert (
-             hidden_states.dim() >= 3
-         ), f"Input must have atleast 3 dimensions, got {hidden_states.dim()}"
+         assert hidden_states.dim() >= 3, f"Input must have atleast 3 dimensions, got {hidden_states.dim()}"
          assert (
              hidden_states.size(1) == self.num_channels
          ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}"
liger_kernel/transformers/jsd.py

@@ -67,6 +67,4 @@ class LigerJSD(torch.nn.Module):
          log_p: torch.Tensor,
          shift_labels: Optional[torch.LongTensor] = None,
      ):
-         return LigerJSDFunction.apply(
-             log_q, log_p, shift_labels, self.beta, self.ignore_index
-         )
+         return LigerJSDFunction.apply(log_q, log_p, shift_labels, self.beta, self.ignore_index)
liger_kernel/transformers/kl_div.py

@@ -9,6 +9,4 @@ class LigerKLDIVLoss(nn.KLDivLoss):
          self.eps = eps

      def forward(self, y_pred, y_true):
-         return LigerKLDivLossFunction.apply(
-             y_pred, y_true, self.reduction, self.log_target, self.eps
-         )
+         return LigerKLDivLossFunction.apply(y_pred, y_true, self.reduction, self.log_target, self.eps)
liger_kernel/transformers/layer_norm.py

@@ -13,18 +13,12 @@ class LigerLayerNorm(nn.Module):
          ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
          self.hidden_size = hidden_size
          self.eps = eps
-         self.weight = nn.Parameter(
-             torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-         )
-         self.bias = nn.Parameter(
-             torch.randn(hidden_size) if bias else torch.zeros(hidden_size)
-         )
+         self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+         self.bias = nn.Parameter(torch.randn(hidden_size) if bias else torch.zeros(hidden_size))
          self.variance_epsilon = eps

      def forward(self, hidden_states):
-         return LigerLayerNormFunction.apply(
-             hidden_states, self.weight, self.bias, self.variance_epsilon
-         )
+         return LigerLayerNormFunction.apply(hidden_states, self.weight, self.bias, self.variance_epsilon)

      def extra_repr(self):
          return f"{self.hidden_size}, eps={self.eps}"
liger_kernel/transformers/model/gemma.py

@@ -1,27 +1,23 @@
- from typing import List, Optional, Tuple, Union
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union

  import torch
+
  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.gemma.modeling_gemma import (
-     _CONFIG_FOR_DOC,
-     GEMMA_INPUTS_DOCSTRING,
- )
- from transformers.utils import (
-     add_start_docstrings_to_model_forward,
-     replace_return_docstrings,
- )
+ from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
+ from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
+ from transformers.utils import add_start_docstrings_to_model_forward
+ from transformers.utils import replace_return_docstrings

- from liger_kernel.transformers.fused_linear_cross_entropy import (
-     LigerFusedLinearCrossEntropyLoss,
- )
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


  @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
- )
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -64,19 +60,11 @@ def lce_forward_deprecated(
      >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
      "What is your favorite condiment?"
      ```"""
-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
      outputs = self.model(
@@ -139,9 +127,7 @@ def lce_forward_deprecated(


  @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
- )
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -188,19 +174,11 @@ def lce_forward(
      >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
      "What is your favorite condiment?"
      ```"""
-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
      outputs = self.model(
liger_kernel/transformers/model/gemma2.py

@@ -1,22 +1,20 @@
  import logging
- from typing import Optional, Tuple, Union
+
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union

  import torch
+
  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import HybridCache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.gemma2.modeling_gemma2 import (
-     _CONFIG_FOR_DOC,
-     GEMMA2_INPUTS_DOCSTRING,
- )
- from transformers.utils import (
-     add_start_docstrings_to_model_forward,
-     replace_return_docstrings,
- )
-
- from liger_kernel.transformers.fused_linear_cross_entropy import (
-     LigerFusedLinearCrossEntropyLoss,
- )
+ from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
+ from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
+ from transformers.utils import add_start_docstrings_to_model_forward
+ from transformers.utils import replace_return_docstrings
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

  logger = logging.getLogger(__name__)

@@ -63,19 +61,11 @@ def lce_forward_deprecated(
              "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
              f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
          )
-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
      # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
      outputs = self.model(
          input_ids=input_ids,
@@ -104,9 +94,7 @@ def lce_forward_deprecated(
          shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
          shift_labels = shift_labels.view(-1)

-         lce = LigerFusedLinearCrossEntropyLoss(
-             softcap=self.config.final_logit_softcapping
-         )
+         lce = LigerFusedLinearCrossEntropyLoss(softcap=self.config.final_logit_softcapping)
          loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

      else:
@@ -146,9 +134,7 @@ def lce_forward_deprecated(


  @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
- )
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -201,19 +187,11 @@ def lce_forward(
              "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
              f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
          )
-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
      # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
      outputs = self.model(
          input_ids=input_ids,
liger_kernel/transformers/model/llama.py

@@ -1,30 +1,27 @@
- from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union

  import torch
  import torch.nn.functional as F
+
  from torch.nn import CrossEntropyLoss
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.llama.modeling_llama import (
-     _CONFIG_FOR_DOC,
-     LLAMA_INPUTS_DOCSTRING,
- )
- from transformers.utils import (
-     add_start_docstrings_to_model_forward,
-     replace_return_docstrings,
- )
-
- from liger_kernel.transformers.fused_linear_cross_entropy import (
-     LigerFusedLinearCrossEntropyLoss,
- )
+ from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
+ from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
+ from transformers.utils import add_start_docstrings_to_model_forward
+ from transformers.utils import replace_return_docstrings
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

  if TYPE_CHECKING:
      from transformers.cache_utils import Cache


  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
- )
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -67,19 +64,11 @@ def lce_forward_deprecated(
      >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
      "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
      ```"""
-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
      outputs = self.model(
@@ -113,13 +102,8 @@ def lce_forward_deprecated(

      else:
          if self.config.pretraining_tp > 1:
-             lm_head_slices = self.lm_head.weight.split(
-                 self.vocab_size // self.config.pretraining_tp, dim=0
-             )
-             logits = [
-                 F.linear(hidden_states, lm_head_slices[i])
-                 for i in range(self.config.pretraining_tp)
-             ]
+             lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+             logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
              logits = torch.cat(logits, dim=-1)
          else:
              logits = self.lm_head(hidden_states)
@@ -151,9 +135,7 @@ def lce_forward_deprecated(


  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
- )
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -201,19 +183,11 @@ def lce_forward(
      "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
      ```"""

-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
      outputs = self.model(