liger-kernel 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/__init__.py +4 -0
- liger_kernel/chunked_loss/cpo_loss.py +107 -0
- liger_kernel/chunked_loss/dpo_loss.py +135 -0
- liger_kernel/chunked_loss/functional.py +9 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +252 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +386 -0
- liger_kernel/chunked_loss/orpo_loss.py +113 -0
- liger_kernel/chunked_loss/simpo_loss.py +115 -0
- liger_kernel/env_report.py +22 -0
- liger_kernel/ops/cross_entropy.py +17 -10
- liger_kernel/ops/fused_linear_cross_entropy.py +1 -11
- liger_kernel/ops/fused_linear_jsd.py +1 -1
- liger_kernel/ops/jsd.py +19 -10
- liger_kernel/ops/layer_norm.py +6 -1
- liger_kernel/ops/qwen2vl_mrope.py +238 -0
- liger_kernel/ops/rms_norm.py +6 -1
- liger_kernel/ops/utils.py +5 -2
- liger_kernel/transformers/__init__.py +1 -0
- liger_kernel/transformers/functional.py +128 -11
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/jsd.py +1 -4
- liger_kernel/transformers/model/qwen2_vl.py +43 -17
- liger_kernel/transformers/monkey_patch.py +11 -6
- liger_kernel/transformers/orpo_trainer.py +171 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/utils.py +13 -0
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/METADATA +80 -123
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/RECORD +33 -20
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/WHEEL +1 -1
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/LICENSE +0 -0
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/NOTICE +0 -0
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/qwen2_vl.py

@@ -1,7 +1,9 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+from packaging import version
 from torch.nn import CrossEntropyLoss
+from transformers import __version__ as transformers_version
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
     _CONFIG_FOR_DOC,
     QWEN2_VL_INPUTS_DOCSTRING,
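The two imports added above support a version gate introduced later in `lce_forward`. As a quick illustration (a standalone snippet, not part of the diff): `packaging.version.parse` compares release numbers semantically, where plain string comparison fails on multi-digit components.

from packaging import version

assert version.parse("4.47.0") > version.parse("4.46.2")
# Lexicographic string comparison would get this one wrong ('1' < '2'):
assert version.parse("4.46.10") > version.parse("4.46.2")
assert not ("4.46.10" > "4.46.2")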
@@ -80,8 +82,6 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
-    # FIXME: The code is outdated and not compatible with transformer >= 4.46.1
-
     output_attentions = (
         output_attentions
         if output_attentions is not None
@@ -100,27 +100,53 @@ def lce_forward(
         inputs_embeds = self.model.embed_tokens(input_ids)
         if pixel_values is not None:
             pixel_values = pixel_values.type(self.visual.get_dtype())
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-
+            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+            n_image_features = image_embeds.shape[0]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            image_mask = (
+                (input_ids == self.config.image_token_id)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
             )
-
-
-
-            inputs_embeds[image_mask] = image_embeds
+            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
         if pixel_values_videos is not None:
             pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-
+            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+            n_video_features = video_embeds.shape[0]
+            if n_video_tokens != n_video_features:
+                raise ValueError(
+                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                )
+            video_mask = (
+                (input_ids == self.config.video_token_id)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
             )
-
-            inputs_embeds[video_mask] = video_embeds
+            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)
-
-    if position_ids is None and input_ids is not None:
-        position_ids, _ = self.get_rope_index(
-            input_ids, image_grid_thw, video_grid_thw, attention_mask
-        )
+
+    if version.parse(transformers_version) > version.parse("4.46.2"):
+        # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
+        # https://github.com/huggingface/transformers/issues/33401
+        # While correct, this breaks equivalence with past versions of Qwen2-VL from
+        # transformers and leads to failed tests or users noticing differences in results.
+        # TODO: remove above conditional when liger drops support for transformers<4.47.0
+        if position_ids is None and input_ids is not None:
+            position_ids, _ = self.get_rope_index(
+                input_ids, image_grid_thw, video_grid_thw, attention_mask
+            )
 
     outputs = self.model(
         input_ids=None,
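The substantive change in this hunk, besides the new token-count validation, is replacing the in-place splice `inputs_embeds[image_mask] = image_embeds` with `masked_scatter` on a mask expanded to the hidden dimension, mirroring upstream transformers. A toy illustration of the new splice (a standalone sketch, not liger code; 99 stands in for the image token id):

import torch

inputs_embeds = torch.zeros(1, 4, 8)        # (bsz, seq_len, hidden_dim)
input_ids = torch.tensor([[5, 99, 99, 7]])  # two image-token positions
image_embeds = torch.randn(2, 8)            # one embedding row per image token

# The expanded mask marks every hidden channel at each image-token position;
# masked_scatter consumes image_embeds row by row into those slots and
# returns a new tensor instead of mutating inputs_embeds in place.
image_mask = (input_ids == 99).unsqueeze(-1).expand_as(inputs_embeds)
out = inputs_embeds.masked_scatter(image_mask, image_embeds)
assert torch.equal(out[0, 1], image_embeds[0])
assert torch.equal(out[0, 2], image_embeds[1])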
liger_kernel/transformers/monkey_patch.py

@@ -36,6 +36,7 @@ from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.qwen2 import (
     lce_forward_deprecated as qwen2_lce_forward_deprecated,
 )
+from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.transformers.swiglu import (
@@ -56,12 +57,15 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
-def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"):
+def _patch_rms_norm_module(
+    module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
+):
     module.offset = offset
     module.casting_mode = casting_mode
     module.variance_epsilon = (
         getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     )
+    module.in_place = in_place
     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
 
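For context, `_bind_method_to_module` (used by `_patch_rms_norm_module` above) overrides `forward` on a single module instance by binding a plain function through the descriptor protocol. A minimal sketch of that trick with toy classes (not liger code):

class Greeter:
    def hello(self):
        return "original"

def patched_hello(self):
    # __get__ binds `self`, so the patch behaves like a real method
    return f"patched on {type(self).__name__}"

g = Greeter()
g.__dict__["hello"] = patched_hello.__get__(g, Greeter)  # instance-level override
assert g.hello() == "patched on Greeter"
assert Greeter().hello() == "original"  # other instances are untouched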
@@ -510,7 +514,7 @@ def apply_liger_kernel_to_gemma2(
         LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
     )
     _patch_rms_norm_module_for_gemma2 = partial(
-        _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
+        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
 
     if rope:
@@ -607,9 +611,7 @@ def apply_liger_kernel_to_qwen2(
         logger.warning(TRANSFORMER_DEPRECATION_WARNING)
         modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
 
-    # import pdb; pdb.set_trace()
     if fused_linear_cross_entropy:
-
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
             modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
         else:  # if version < 4.46.1
@@ -641,6 +643,7 @@ def apply_liger_kernel_to_qwen2(
 
 
 def apply_liger_kernel_to_qwen2_vl(
+    rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
@@ -675,8 +678,10 @@ def apply_liger_kernel_to_qwen2_vl(
         lce_forward as qwen2_vl_lce_forward,
     )
 
-
-
+    if rope:
+        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
+            liger_multimodal_rotary_pos_emb
+        )
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
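With the `rope` flag wired in above, enabling the Qwen2-VL patches looks roughly like the sketch below (a usage sketch, assuming the function is re-exported from `liger_kernel.transformers` as in prior releases; patches must be applied before the model is instantiated):

from transformers import Qwen2VLForConditionalGeneration
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

# rope=True (new in 0.5.0) swaps apply_multimodal_rotary_pos_emb for
# liger_multimodal_rotary_pos_emb; the other flags keep their defaults.
apply_liger_kernel_to_qwen2_vl(rope=True, rms_norm=True, fused_linear_cross_entropy=True)

model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")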
liger_kernel/transformers/orpo_trainer.py (new file)

@@ -0,0 +1,171 @@
+from typing import Any, Callable, Dict, List, Literal, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.distributed.fsdp import FullyShardedDataParallel
+from trl.trainer import ORPOTrainer
+
+from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
+
+
+class _FSDPForwardRedirection:
+    """
+    Modified based on
+    https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
+    Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
+    post-forward can be properly executed around the method call.
+    This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
+    the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
+    GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
+    will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
+    the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
+    its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
+    the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
+    """
+
+    def __call__(
+        self,
+        wrapper_module: FullyShardedDataParallel,
+        method: Callable,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        """Reroutes a method call through the `wrapper_module`'s `forward` method.
+        Args:
+            wrapper_module: The module that has `original_module` wrapped.
+            original_module: The module that was wrapped inside `wrapper_module`.
+            method_name: The name of the method that should be called on the `original_module` after inputs get
+                redirected through the `wrapper_module`'s `forward` method.
+            *args: The positional arguments to the method `method_name`. They will get passed to a patched
+                `forward` method instead.
+            **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
+                `forward` method instead.
+        """
+        assert isinstance(wrapper_module, FullyShardedDataParallel)
+        original_module = wrapper_module._fsdp_wrapped_module
+        original_forward = original_module.forward
+
+        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
+            # Unpatch ourselves immediately before calling the method `method_name`
+            # because itself may want to call the real `forward`
+            original_module.forward = original_forward  # type: ignore[method-assign]
+            # Call the actual method e.g. `.training_step(...)`
+            out = method(*_args, **_kwargs)
+            return out
+
+        # Patch the original_module's forward so we can redirect the arguments back to the real method
+        original_module.forward = wrapped_forward  # type: ignore[method-assign]
+        wrapper_output = wrapper_module(*args, **kwargs)
+        return wrapper_output
+
+
+class LigerORPOTrainer(ORPOTrainer):
+    def concatenated_forward(
+        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+    ) -> Tuple[
+        torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
+    ]:
+        """
+        Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+        We do this to avoid doing two forward passes, because it's faster for FSDP.
+        """
+        concatenated_batch = self.concatenated_inputs(
+            batch,
+            is_encoder_decoder=self.is_encoder_decoder,
+            label_pad_token_id=self.label_pad_token_id,
+            padding_value=self.padding_value,
+            device=self.accelerator.device,
+        )
+        # if self.accelerator.is_main_process:
+        #     import pdb; pdb.set_trace()
+        # torch.distributed.barrier()
+        model_kwargs = (
+            {
+                "decoder_input_ids": self._shift_right(
+                    concatenated_batch["concatenated_labels"]
+                ),
+            }
+            if self.is_encoder_decoder
+            else {}
+        )
+
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
+        if isinstance(model, FullyShardedDataParallel):
+            outputs = _FSDPForwardRedirection()(
+                model,
+                model._fsdp_wrapped_module.model,
+                concatenated_batch["concatenated_input_ids"],
+                attention_mask=concatenated_batch["concatenated_attention_mask"],
+                use_cache=False,
+                **model_kwargs,
+            )
+        else:
+            if isinstance(model, torch.nn.DataParallel):
+                model = model.module
+            outputs = model.model(
+                concatenated_batch["concatenated_input_ids"],
+                attention_mask=concatenated_batch["concatenated_attention_mask"],
+                use_cache=False,
+                **model_kwargs,
+            )
+
+        orpo_loss_fn = LigerFusedLinearORPOLoss(
+            ignore_index=self.label_pad_token_id, beta=self.beta
+        )
+
+        def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
+            return orpo_loss_fn(
+                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias
+            )
+
+        orpo_loss, aux_outputs = _FSDPForwardRedirection()(
+            model,
+            orpo_partial,
+            model.lm_head,
+            outputs.last_hidden_state,
+            concatenated_batch["concatenated_labels"],
+        )
+        return orpo_loss, aux_outputs
+
+    def get_batch_loss_metrics(
+        self,
+        model,
+        batch: Dict[str, Union[List, torch.LongTensor]],
+        train_eval: Literal["train", "eval"] = "train",
+    ):
+        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
+        metrics = {}
+        loss, aux_outputs = self.concatenated_forward(model, batch)
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_nll_loss,
+        ) = aux_outputs[:5]
+
+        # return loss, metrics
+        chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[
+            5:
+        ]
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
+        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
+        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
+        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
+        for k, v in metrics.items():
+            metrics[k] = v.item()
+
+        return loss, metrics
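`LigerORPOTrainer` is meant as a drop-in replacement for trl's `ORPOTrainer` that routes the `lm_head` projection and loss through `LigerFusedLinearORPOLoss`. A minimal usage sketch (the model name and dataset are placeholders, and the keyword names follow trl's `ORPOTrainer`, which may differ across trl versions):

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig
from liger_kernel.transformers.orpo_trainer import LigerORPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# Any preference dataset with prompt/chosen/rejected columns should work here.
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:64]")

trainer = LigerORPOTrainer(
    model=model,
    args=ORPOConfig(output_dir="orpo-out", beta=0.1, max_steps=10),
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()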
liger_kernel/transformers/qwen2vl_mrope.py (new file)

@@ -0,0 +1,20 @@
+from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
+
+
+def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+    """
+    Applies Multimodal Rotary Positional Embedding (M-RoPE) operation to query and key states.
+
+    Args:
+        q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
+        k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
+        cos (torch.Tensor): The cosine tensor of shape (3, 1, seq_len, head_dim).
+        sin (torch.Tensor): The sine tensor of shape (3, 1, seq_len, head_dim).
+        mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation.
+        unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the M-RoPE operation.
+    """
+
+    return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
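Given the documented shapes, a quick shape-check sketch for this wrapper (it assumes a CUDA device, since `LigerQwen2VLMRopeFunction` is a Triton kernel, and uses `mrope_section=[16, 24, 24]`, Qwen2-VL's temporal/height/width split, which must sum to head_dim // 2):

import torch
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb

bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 12, 2, 16, 128
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")
cos = torch.randn(3, 1, seq_len, head_dim, device="cuda")  # one row each for t/h/w
sin = torch.randn(3, 1, seq_len, head_dim, device="cuda")

q_rot, k_rot = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section=[16, 24, 24])
assert q_rot.shape == q.shape and k_rot.shape == k.shape  # M-RoPE preserves shapes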
|