liger-kernel-nightly 0.4.1.dev20241115191733__py3-none-any.whl → 0.4.1.dev20241117192031__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/transformers/model/qwen2_vl.py +43 -17
- liger_kernel/transformers/monkey_patch.py +5 -2
- {liger_kernel_nightly-0.4.1.dev20241115191733.dist-info → liger_kernel_nightly-0.4.1.dev20241117192031.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.4.1.dev20241115191733.dist-info → liger_kernel_nightly-0.4.1.dev20241117192031.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.4.1.dev20241115191733.dist-info → liger_kernel_nightly-0.4.1.dev20241117192031.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.1.dev20241115191733.dist-info → liger_kernel_nightly-0.4.1.dev20241117192031.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.1.dev20241115191733.dist-info → liger_kernel_nightly-0.4.1.dev20241117192031.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.4.1.dev20241115191733.dist-info → liger_kernel_nightly-0.4.1.dev20241117192031.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,9 @@
|
|
1
1
|
from typing import List, Optional, Tuple, Union
|
2
2
|
|
3
3
|
import torch
|
4
|
+
from packaging import version
|
4
5
|
from torch.nn import CrossEntropyLoss
|
6
|
+
from transformers import __version__ as transformers_version
|
5
7
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
6
8
|
_CONFIG_FOR_DOC,
|
7
9
|
QWEN2_VL_INPUTS_DOCSTRING,
|
@@ -80,8 +82,6 @@ def lce_forward(
|
|
80
82
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
81
83
|
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
82
84
|
```"""
|
83
|
-
# FIXME: The code is outdated and not compatible with transformer >= 4.46.1
|
84
|
-
|
85
85
|
output_attentions = (
|
86
86
|
output_attentions
|
87
87
|
if output_attentions is not None
|
@@ -100,27 +100,53 @@ def lce_forward(
|
|
100
100
|
inputs_embeds = self.model.embed_tokens(input_ids)
|
101
101
|
if pixel_values is not None:
|
102
102
|
pixel_values = pixel_values.type(self.visual.get_dtype())
|
103
|
-
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
104
|
-
|
103
|
+
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
104
|
+
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
105
|
+
n_image_features = image_embeds.shape[0]
|
106
|
+
if n_image_tokens != n_image_features:
|
107
|
+
raise ValueError(
|
108
|
+
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
109
|
+
)
|
110
|
+
image_mask = (
|
111
|
+
(input_ids == self.config.image_token_id)
|
112
|
+
.unsqueeze(-1)
|
113
|
+
.expand_as(inputs_embeds)
|
114
|
+
.to(inputs_embeds.device)
|
105
115
|
)
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
inputs_embeds[image_mask] = image_embeds
|
116
|
+
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
117
|
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
118
|
+
|
110
119
|
if pixel_values_videos is not None:
|
111
120
|
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
112
|
-
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
113
|
-
|
121
|
+
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
122
|
+
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
123
|
+
n_video_features = video_embeds.shape[0]
|
124
|
+
if n_video_tokens != n_video_features:
|
125
|
+
raise ValueError(
|
126
|
+
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
127
|
+
)
|
128
|
+
video_mask = (
|
129
|
+
(input_ids == self.config.video_token_id)
|
130
|
+
.unsqueeze(-1)
|
131
|
+
.expand_as(inputs_embeds)
|
132
|
+
.to(inputs_embeds.device)
|
114
133
|
)
|
115
|
-
|
116
|
-
inputs_embeds
|
134
|
+
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
135
|
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
136
|
+
|
117
137
|
if attention_mask is not None:
|
118
138
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
119
|
-
|
120
|
-
if
|
121
|
-
|
122
|
-
|
123
|
-
|
139
|
+
|
140
|
+
if version.parse(transformers_version) > version.parse("4.46.2"):
|
141
|
+
# NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
|
142
|
+
# https://github.com/huggingface/transformers/issues/33401
|
143
|
+
# While correct, this breaks equivalence with past versions of Qwen2-VL from
|
144
|
+
# transformers and leads to failed tests or users noticing differences in results.
|
145
|
+
# TODO: remove above conditional when liger drops support for transformers<4.47.0
|
146
|
+
if position_ids is None and input_ids is not None:
|
147
|
+
position_ids, _ = self.get_rope_index(
|
148
|
+
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
149
|
+
)
|
124
150
|
|
125
151
|
outputs = self.model(
|
126
152
|
input_ids=None,
|
@@ -56,12 +56,15 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
|
|
56
56
|
module.__dict__[method_name] = new_method.__get__(module, module.__class__)
|
57
57
|
|
58
58
|
|
59
|
-
def _patch_rms_norm_module(
|
59
|
+
def _patch_rms_norm_module(
|
60
|
+
module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
|
61
|
+
):
|
60
62
|
module.offset = offset
|
61
63
|
module.casting_mode = casting_mode
|
62
64
|
module.variance_epsilon = (
|
63
65
|
getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
|
64
66
|
)
|
67
|
+
module.in_place = in_place
|
65
68
|
_bind_method_to_module(module, "forward", LigerRMSNorm.forward)
|
66
69
|
_bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
|
67
70
|
|
@@ -510,7 +513,7 @@ def apply_liger_kernel_to_gemma2(
|
|
510
513
|
LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
|
511
514
|
)
|
512
515
|
_patch_rms_norm_module_for_gemma2 = partial(
|
513
|
-
_patch_rms_norm_module, offset=1.0, casting_mode="gemma"
|
516
|
+
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
|
514
517
|
)
|
515
518
|
|
516
519
|
if rope:
|
@@ -29,7 +29,7 @@ liger_kernel/transformers/group_norm.py,sha256=FJ9R7mS9G1wO-GRIQ6QKSmIhnZ6nQ6GIk
|
|
29
29
|
liger_kernel/transformers/jsd.py,sha256=W-5CypO2mx4-bUWOxq1KScfCdoXlLoYbtt5xBnRzMs4,3056
|
30
30
|
liger_kernel/transformers/kl_div.py,sha256=qVhjBg6tjRyue5iZ3NFxo8uySY4JuIFJyv0IM_50F24,431
|
31
31
|
liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbUoBiANKX1ow,972
|
32
|
-
liger_kernel/transformers/monkey_patch.py,sha256=
|
32
|
+
liger_kernel/transformers/monkey_patch.py,sha256=L1IuGmFMWYgf-u3OXCg43BUxbZKTpd7ATjjDjYoFkEM,38268
|
33
33
|
liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
|
34
34
|
liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
|
35
35
|
liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
|
@@ -44,12 +44,12 @@ liger_kernel/transformers/model/mixtral.py,sha256=nyDS1dBpsOXYC2DuW59Hgu7ZrGftrH
|
|
44
44
|
liger_kernel/transformers/model/mllama.py,sha256=mesNCgj0Ea1O-fqRD4LVxDJ1CR2abY_zAzK_bfVzkiU,11222
|
45
45
|
liger_kernel/transformers/model/phi3.py,sha256=xUZPlaPKwknLjHc3uUW3EPodm1h0vD3G7Qnhh51v-Io,10332
|
46
46
|
liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5PBO3q0MoCs00,9619
|
47
|
-
liger_kernel/transformers/model/qwen2_vl.py,sha256=
|
47
|
+
liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
|
48
48
|
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
49
49
|
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
50
|
-
liger_kernel_nightly-0.4.1.
|
51
|
-
liger_kernel_nightly-0.4.1.
|
52
|
-
liger_kernel_nightly-0.4.1.
|
53
|
-
liger_kernel_nightly-0.4.1.
|
54
|
-
liger_kernel_nightly-0.4.1.
|
55
|
-
liger_kernel_nightly-0.4.1.
|
50
|
+
liger_kernel_nightly-0.4.1.dev20241117192031.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
51
|
+
liger_kernel_nightly-0.4.1.dev20241117192031.dist-info/METADATA,sha256=HE97eoTT33apEKjxw39NI2lolbsj49okNZImxATruEo,21556
|
52
|
+
liger_kernel_nightly-0.4.1.dev20241117192031.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
53
|
+
liger_kernel_nightly-0.4.1.dev20241117192031.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
54
|
+
liger_kernel_nightly-0.4.1.dev20241117192031.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
55
|
+
liger_kernel_nightly-0.4.1.dev20241117192031.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|