liger-kernel 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/__init__.py +4 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +107 -0
  4. liger_kernel/chunked_loss/dpo_loss.py +135 -0
  5. liger_kernel/chunked_loss/functional.py +9 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +252 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +386 -0
  8. liger_kernel/chunked_loss/orpo_loss.py +113 -0
  9. liger_kernel/chunked_loss/simpo_loss.py +115 -0
  10. liger_kernel/env_report.py +22 -0
  11. liger_kernel/ops/cross_entropy.py +17 -10
  12. liger_kernel/ops/fused_linear_cross_entropy.py +1 -11
  13. liger_kernel/ops/fused_linear_jsd.py +1 -1
  14. liger_kernel/ops/jsd.py +19 -10
  15. liger_kernel/ops/layer_norm.py +6 -1
  16. liger_kernel/ops/qwen2vl_mrope.py +238 -0
  17. liger_kernel/ops/rms_norm.py +6 -1
  18. liger_kernel/ops/utils.py +5 -2
  19. liger_kernel/transformers/__init__.py +1 -0
  20. liger_kernel/transformers/functional.py +128 -11
  21. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  22. liger_kernel/transformers/jsd.py +1 -4
  23. liger_kernel/transformers/model/qwen2_vl.py +43 -17
  24. liger_kernel/transformers/monkey_patch.py +11 -6
  25. liger_kernel/transformers/orpo_trainer.py +171 -0
  26. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  27. liger_kernel/utils.py +13 -0
  28. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/METADATA +80 -123
  29. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/RECORD +33 -20
  30. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/WHEEL +1 -1
  31. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/LICENSE +0 -0
  32. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/NOTICE +0 -0
  33. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/qwen2_vl.py CHANGED
@@ -1,7 +1,9 @@
  from typing import List, Optional, Tuple, Union

  import torch
+ from packaging import version
  from torch.nn import CrossEntropyLoss
+ from transformers import __version__ as transformers_version
  from transformers.models.qwen2_vl.modeling_qwen2_vl import (
      _CONFIG_FOR_DOC,
      QWEN2_VL_INPUTS_DOCSTRING,
@@ -80,8 +82,6 @@ def lce_forward(
      >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
      "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
      ```"""
-     # FIXME: The code is outdated and not compatible with transformer >= 4.46.1
-
      output_attentions = (
          output_attentions
          if output_attentions is not None
@@ -100,27 +100,53 @@ def lce_forward(
          inputs_embeds = self.model.embed_tokens(input_ids)
          if pixel_values is not None:
              pixel_values = pixel_values.type(self.visual.get_dtype())
-             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(
-                 inputs_embeds.device
+             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+             n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+             n_image_features = image_embeds.shape[0]
+             if n_image_tokens != n_image_features:
+                 raise ValueError(
+                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                 )
+             image_mask = (
+                 (input_ids == self.config.image_token_id)
+                 .unsqueeze(-1)
+                 .expand_as(inputs_embeds)
+                 .to(inputs_embeds.device)
              )
-             image_mask = input_ids == self.config.image_token_id
-             if self.training:
-                 inputs_embeds = inputs_embeds.clone()
-             inputs_embeds[image_mask] = image_embeds
+             image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+             inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
          if pixel_values_videos is not None:
              pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-             video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(
-                 inputs_embeds.device
+             video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+             n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+             n_video_features = video_embeds.shape[0]
+             if n_video_tokens != n_video_features:
+                 raise ValueError(
+                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                 )
+             video_mask = (
+                 (input_ids == self.config.video_token_id)
+                 .unsqueeze(-1)
+                 .expand_as(inputs_embeds)
+                 .to(inputs_embeds.device)
              )
-             video_mask = input_ids == self.config.video_token_id
-             inputs_embeds[video_mask] = video_embeds
+             video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+             inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
          if attention_mask is not None:
              attention_mask = attention_mask.to(inputs_embeds.device)
-     # The code is copied from https://github.com/huggingface/transformers/pull/33487
-     if position_ids is None and input_ids is not None:
-         position_ids, _ = self.get_rope_index(
-             input_ids, image_grid_thw, video_grid_thw, attention_mask
-         )
+
+     if version.parse(transformers_version) > version.parse("4.46.2"):
+         # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
+         # https://github.com/huggingface/transformers/issues/33401
+         # While correct, this breaks equivalence with past versions of Qwen2-VL from
+         # transformers and leads to failed tests or users noticing differences in results.
+         # TODO: remove above conditional when liger drops support for transformers<4.47.0
+         if position_ids is None and input_ids is not None:
+             position_ids, _ = self.get_rope_index(
+                 input_ids, image_grid_thw, video_grid_thw, attention_mask
+             )

      outputs = self.model(
          input_ids=None,
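The embedding-placement change above replaces in-place boolean-index assignment with an out-of-place `masked_scatter`. A minimal standalone PyTorch sketch of the same pattern (toy shapes and a made-up `image_token_id`, not taken from the package):

```python
import torch

# Scatter per-image embedding rows into the token sequence wherever the
# image placeholder token appears. All sizes below are illustrative.
image_token_id = 7
hidden_size = 4

input_ids = torch.tensor([[1, 7, 7, 2]])          # (bsz=1, seq_len=4), two image tokens
inputs_embeds = torch.zeros(1, 4, hidden_size)    # (bsz, seq_len, hidden)
image_embeds = torch.arange(2 * hidden_size, dtype=torch.float32).reshape(2, hidden_size)

# Old approach: in-place boolean-index assignment (needed a .clone() during training).
old = inputs_embeds.clone()
old[input_ids == image_token_id] = image_embeds

# New approach: expand the mask to the embedding dim, then out-of-place masked_scatter.
image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
new = inputs_embeds.masked_scatter(image_mask, image_embeds)

assert torch.equal(old, new)
```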
liger_kernel/transformers/monkey_patch.py CHANGED
@@ -36,6 +36,7 @@ from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forwa
  from liger_kernel.transformers.model.qwen2 import (
      lce_forward_deprecated as qwen2_lce_forward_deprecated,
  )
+ from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
  from liger_kernel.transformers.swiglu import (
@@ -56,12 +57,15 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
      module.__dict__[method_name] = new_method.__get__(module, module.__class__)


- def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"):
+ def _patch_rms_norm_module(
+     module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
+ ):
      module.offset = offset
      module.casting_mode = casting_mode
      module.variance_epsilon = (
          getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
      )
+     module.in_place = in_place
      _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
      _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)

@@ -510,7 +514,7 @@ def apply_liger_kernel_to_gemma2(
          LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
      )
      _patch_rms_norm_module_for_gemma2 = partial(
-         _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
+         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
      )

      if rope:
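The `_patch_rms_norm_module` helper patched above relies on `_bind_method_to_module`, which attaches a plain function to a single module instance through the descriptor protocol, so only that instance's `forward` is swapped. A small standalone sketch of the binding trick (`nn.LayerNorm` and the toy forward are purely illustrative):

```python
import torch
import torch.nn as nn


def patched_forward(self, x):
    # Toy replacement; the real patch binds LigerRMSNorm.forward instead.
    return x * self.weight


norm = nn.LayerNorm(4)
# function.__get__(instance, cls) produces a method bound to this one instance;
# storing it in the instance __dict__ shadows the class-level forward.
norm.__dict__["forward"] = patched_forward.__get__(norm, norm.__class__)

x = torch.ones(2, 4)
print(norm(x).shape)              # patched instance uses the new forward
print(nn.LayerNorm(4)(x).shape)   # other instances remain untouched
```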
@@ -607,9 +611,7 @@ def apply_liger_kernel_to_qwen2(
          logger.warning(TRANSFORMER_DEPRECATION_WARNING)
          modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss

-     # import pdb; pdb.set_trace()
      if fused_linear_cross_entropy:
-
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
              modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
          else:  # if version < 4.46.1
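The `transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION)` branch above is a version-gated dispatch on the installed transformers release. A hedged sketch of the pattern; the `"4.46.1"` threshold is inferred from the `# if version < 4.46.1` comment and may not match the exact constant in the package:

```python
from packaging import version

import transformers

# Assumed threshold for illustration only.
SUPPORTED_TRANSFORMER_VERSION = "4.46.1"

if version.parse(transformers.__version__) >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
    print("patch with the current lce_forward")
else:
    print("patch with lce_forward_deprecated")
```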
@@ -641,6 +643,7 @@ def apply_liger_kernel_to_qwen2(


  def apply_liger_kernel_to_qwen2_vl(
+     rope: bool = True,
      cross_entropy: bool = False,
      fused_linear_cross_entropy: bool = True,
      rms_norm: bool = True,
@@ -675,8 +678,10 @@ def apply_liger_kernel_to_qwen2_vl(
          lce_forward as qwen2_vl_lce_forward,
      )

-     # TODO: Support Qwen2-VL's multimodal RoPE implementation
-
+     if rope:
+         modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
+             liger_multimodal_rotary_pos_emb
+         )
      if rms_norm:
          # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
          modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
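A minimal usage sketch (not part of the diff) for the new `rope` flag: patch Qwen2-VL before instantiating the model so the Liger multimodal RoPE and RMSNorm replacements take effect. The model id and dtype are illustrative assumptions; only flags visible in the diff above are passed:

```python
import torch
from transformers import Qwen2VLForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

# Apply the monkey patch first, then load the model as usual.
apply_liger_kernel_to_qwen2_vl(
    rope=True,                        # new in 0.5.0: Liger multimodal RoPE
    rms_norm=True,
    cross_entropy=False,
    fused_linear_cross_entropy=True,
)

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16
)
```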
liger_kernel/transformers/orpo_trainer.py ADDED
@@ -0,0 +1,171 @@
+ from typing import Any, Callable, Dict, List, Literal, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from torch.distributed.fsdp import FullyShardedDataParallel
+ from trl.trainer import ORPOTrainer
+
+ from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
+
+
+ class _FSDPForwardRedirection:
+     """
+     Modified based on
+     https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
+     Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
+     post-forward can be properly executed around the method call.
+     This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
+     the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
+     GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
+     will not work because the first `nn.Emebedding` layer is not independently wrapped as a FSDP module (because of
+     the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
+     its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
+     the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
+     """
+
+     def __call__(
+         self,
+         wrapper_module: FullyShardedDataParallel,
+         method: Callable,
+         *args: Any,
+         **kwargs: Any,
+     ):
+         """Reroutes a method call through the `wrapper_module`'s `forward` method.
+         Args:
+             wrapper_module: The module that has `original_module` wrapped.
+             original_module: The module that was wrapped inside `wrapper_module`.
+             method_name: The name of the method that should be called on the `original_module` after inputs get
+                 redirected through the `wrapper_module`'s `forward` method.
+             *args: The positional arguments to the method `method_name`. They will get passed to a patched
+                 `forward` method instead.
+             **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
+                 `forward` method instead.
+         """
+         assert isinstance(wrapper_module, FullyShardedDataParallel)
+         original_module = wrapper_module._fsdp_wrapped_module
+         original_forward = original_module.forward
+
+         def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
+             # Unpatch ourselves immediately before calling the method `method_name`
+             # because itself may want to call the real `forward`
+             original_module.forward = original_forward  # type: ignore[method-assign]
+             # Call the actual method e.g. `.training_step(...)`
+             out = method(*_args, **_kwargs)
+             return out
+
+         # Patch the original_module's forward so we can redirect the arguments back to the real method
+         original_module.forward = wrapped_forward  # type: ignore[method-assign]
+         wrapper_output = wrapper_module(*args, **kwargs)
+         return wrapper_output
+
+
+ class LigerORPOTrainer(ORPOTrainer):
+     def concatenated_forward(
+         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+     ) -> Tuple[
+         torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
+     ]:
+         """
+         Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+         We do this to avoid doing two forward passes, because it's faster for FSDP.
+         """
+         concatenated_batch = self.concatenated_inputs(
+             batch,
+             is_encoder_decoder=self.is_encoder_decoder,
+             label_pad_token_id=self.label_pad_token_id,
+             padding_value=self.padding_value,
+             device=self.accelerator.device,
+         )
+         # if self.accelerator.is_main_process:
+         #     import pdb; pdb.set_trace()
+         # torch.distributed.barrier()
+         model_kwargs = (
+             {
+                 "decoder_input_ids": self._shift_right(
+                     concatenated_batch["concatenated_labels"]
+                 ),
+             }
+             if self.is_encoder_decoder
+             else {}
+         )
+
+         if self.aux_loss_enabled:
+             model_kwargs["output_router_logits"] = True
+
+         if isinstance(model, FullyShardedDataParallel):
+             outputs = _FSDPForwardRedirection()(
+                 model,
+                 model._fsdp_wrapped_module.model,
+                 concatenated_batch["concatenated_input_ids"],
+                 attention_mask=concatenated_batch["concatenated_attention_mask"],
+                 use_cache=False,
+                 **model_kwargs,
+             )
+         else:
+             if isinstance(model, torch.nn.DataParallel):
+                 model = model.module
+             outputs = model.model(
+                 concatenated_batch["concatenated_input_ids"],
+                 attention_mask=concatenated_batch["concatenated_attention_mask"],
+                 use_cache=False,
+                 **model_kwargs,
+             )
+
+         orpo_loss_fn = LigerFusedLinearORPOLoss(
+             ignore_index=self.label_pad_token_id, beta=self.beta
+         )
+
+         def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
+             return orpo_loss_fn(
+                 lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias
+             )
+
+         orpo_loss, aux_outputs = _FSDPForwardRedirection()(
+             model,
+             orpo_partial,
+             model.lm_head,
+             outputs.last_hidden_state,
+             concatenated_batch["concatenated_labels"],
+         )
+         return orpo_loss, aux_outputs
+
+     def get_batch_loss_metrics(
+         self,
+         model,
+         batch: Dict[str, Union[List, torch.LongTensor]],
+         train_eval: Literal["train", "eval"] = "train",
+     ):
+         """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
+         metrics = {}
+         loss, aux_outputs = self.concatenated_forward(model, batch)
+         (
+             policy_chosen_logps,
+             policy_rejected_logps,
+             policy_chosen_logits,
+             policy_rejected_logits,
+             policy_nll_loss,
+         ) = aux_outputs[:5]
+
+         # return loss, metrics
+         chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[
+             5:
+         ]
+
+         reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+         prefix = "eval_" if train_eval == "eval" else ""
+         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
+         metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
+         metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
+         metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
+         metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
+         metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
+         metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
+         metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
+         metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
+         metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
+         metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
+         for k, v in metrics.items():
+             metrics[k] = v.item()
+
+         return loss, metrics
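The `_FSDPForwardRedirection` docstring above describes routing an arbitrary method call through the wrapper's `forward` so its pre- and post-forward hooks still run. A toy, non-FSDP sketch of the same redirection mechanics on a plain wrapper module (`Wrapper` and `redirect` are made-up names for illustration, not part of the package):

```python
import torch
import torch.nn as nn


class Wrapper(nn.Module):
    """Stands in for FullyShardedDataParallel: forward delegates to the inner module."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, *args, **kwargs):
        return self.inner(*args, **kwargs)


def redirect(wrapper: Wrapper, method, *args, **kwargs):
    inner = wrapper.inner
    original_forward = inner.forward

    def wrapped_forward(*a, **kw):
        inner.forward = original_forward  # restore before running the real target
        return method(*a, **kw)

    # Temporarily swap the inner forward, then call through the wrapper so its
    # own forward (and any hooks) run around the redirected method.
    inner.forward = wrapped_forward
    return wrapper(*args, **kwargs)


inner = nn.Linear(4, 4)
wrapper = Wrapper(inner)
x = torch.randn(2, 4)

# Call only a piece of the wrapped module (here a weight-only matmul) via the wrapper.
out = redirect(wrapper, lambda t: t @ inner.weight.T, x)
print(out.shape)  # torch.Size([2, 4])
```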
liger_kernel/transformers/qwen2vl_mrope.py ADDED
@@ -0,0 +1,20 @@
+ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
+
+
+ def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+     """
+     Applies Multimodal Rotary Positional Embedding (M-RoPE) operation to query and key states.
+
+     Args:
+         q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
+         k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
+         cos (torch.Tensor): The cosine tensor of shape (3, 1, seq_len, head_dim).
+         sin (torch.Tensor): The sine tensor of shape (3, 1, seq_len, head_dim).
+         mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation.
+         unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the M-RoPE operation.
+     """
+
+     return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
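A hypothetical shape-level usage sketch for `liger_multimodal_rotary_pos_emb` (assumes a CUDA-capable GPU, since the underlying op is a Triton kernel; the sizes and the temporal/height/width split in `mrope_section` are illustrative, chosen so the sections sum to `head_dim // 2` as in Qwen2-VL):

```python
import torch

from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb

bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 4, 2, 8, 64
mrope_section = [8, 12, 12]  # 8 + 12 + 12 == head_dim // 2

q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")
cos = torch.randn(3, 1, seq_len, head_dim, device="cuda")
sin = torch.randn(3, 1, seq_len, head_dim, device="cuda")

q_rot, k_rot = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
print(q_rot.shape, k_rot.shape)  # same shapes as q and k
```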
liger_kernel/utils.py ADDED
@@ -0,0 +1,13 @@
+ import torch
+
+
+ def infer_device():
+     """
+     Get current device name based on available devices
+     """
+     if torch.cuda.is_available():
+         return "cuda"
+     elif torch.xpu.is_available():
+         return "xpu"
+     else:
+         return "cpu"
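A minimal usage sketch for the new `infer_device` helper (illustrative only): pick the best available accelerator and allocate a tensor there.

```python
import torch

from liger_kernel.utils import infer_device

device = infer_device()  # "cuda", "xpu", or "cpu"
x = torch.ones(2, 2, device=device)
print(device, x.device)
```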