liger-kernel-nightly 0.4.1.dev20241115012952__py3-none-any.whl → 0.4.1.dev20241115210858__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/fused_linear_cross_entropy.py +1 -0
- liger_kernel/transformers/model/qwen2_vl.py +43 -17
- {liger_kernel_nightly-0.4.1.dev20241115012952.dist-info → liger_kernel_nightly-0.4.1.dev20241115210858.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.4.1.dev20241115012952.dist-info → liger_kernel_nightly-0.4.1.dev20241115210858.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.4.1.dev20241115012952.dist-info → liger_kernel_nightly-0.4.1.dev20241115210858.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.1.dev20241115012952.dist-info → liger_kernel_nightly-0.4.1.dev20241115210858.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.1.dev20241115012952.dist-info → liger_kernel_nightly-0.4.1.dev20241115210858.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.4.1.dev20241115012952.dist-info → liger_kernel_nightly-0.4.1.dev20241115210858.dist-info}/top_level.txt +0 -0
@@ -229,6 +229,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
229
229
|
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
|
230
230
|
reduction: reduction to apply
|
231
231
|
"""
|
232
|
+
|
232
233
|
loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
|
233
234
|
_input,
|
234
235
|
weight,
|
@@ -1,7 +1,9 @@
|
|
1
1
|
from typing import List, Optional, Tuple, Union
|
2
2
|
|
3
3
|
import torch
|
4
|
+
from packaging import version
|
4
5
|
from torch.nn import CrossEntropyLoss
|
6
|
+
from transformers import __version__ as transformers_version
|
5
7
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
6
8
|
_CONFIG_FOR_DOC,
|
7
9
|
QWEN2_VL_INPUTS_DOCSTRING,
|
@@ -80,8 +82,6 @@ def lce_forward(
|
|
80
82
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
81
83
|
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
82
84
|
```"""
|
83
|
-
# FIXME: The code is outdated and not compatible with transformer >= 4.46.1
|
84
|
-
|
85
85
|
output_attentions = (
|
86
86
|
output_attentions
|
87
87
|
if output_attentions is not None
|
@@ -100,27 +100,53 @@ def lce_forward(
|
|
100
100
|
inputs_embeds = self.model.embed_tokens(input_ids)
|
101
101
|
if pixel_values is not None:
|
102
102
|
pixel_values = pixel_values.type(self.visual.get_dtype())
|
103
|
-
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
104
|
-
|
103
|
+
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
104
|
+
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
105
|
+
n_image_features = image_embeds.shape[0]
|
106
|
+
if n_image_tokens != n_image_features:
|
107
|
+
raise ValueError(
|
108
|
+
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
109
|
+
)
|
110
|
+
image_mask = (
|
111
|
+
(input_ids == self.config.image_token_id)
|
112
|
+
.unsqueeze(-1)
|
113
|
+
.expand_as(inputs_embeds)
|
114
|
+
.to(inputs_embeds.device)
|
105
115
|
)
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
inputs_embeds[image_mask] = image_embeds
|
116
|
+
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
117
|
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
118
|
+
|
110
119
|
if pixel_values_videos is not None:
|
111
120
|
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
112
|
-
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
113
|
-
|
121
|
+
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
122
|
+
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
123
|
+
n_video_features = video_embeds.shape[0]
|
124
|
+
if n_video_tokens != n_video_features:
|
125
|
+
raise ValueError(
|
126
|
+
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
127
|
+
)
|
128
|
+
video_mask = (
|
129
|
+
(input_ids == self.config.video_token_id)
|
130
|
+
.unsqueeze(-1)
|
131
|
+
.expand_as(inputs_embeds)
|
132
|
+
.to(inputs_embeds.device)
|
114
133
|
)
|
115
|
-
|
116
|
-
inputs_embeds
|
134
|
+
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
135
|
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
136
|
+
|
117
137
|
if attention_mask is not None:
|
118
138
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
119
|
-
|
120
|
-
if
|
121
|
-
|
122
|
-
|
123
|
-
|
139
|
+
|
140
|
+
if version.parse(transformers_version) > version.parse("4.46.2"):
|
141
|
+
# NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
|
142
|
+
# https://github.com/huggingface/transformers/issues/33401
|
143
|
+
# While correct, this breaks equivalence with past versions of Qwen2-VL from
|
144
|
+
# transformers and leads to failed tests or users noticing differences in results.
|
145
|
+
# TODO: remove above conditional when liger drops support for transformers<4.47.0
|
146
|
+
if position_ids is None and input_ids is not None:
|
147
|
+
position_ids, _ = self.get_rope_index(
|
148
|
+
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
149
|
+
)
|
124
150
|
|
125
151
|
outputs = self.model(
|
126
152
|
input_ids=None,
|
@@ -5,7 +5,7 @@ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ayx-dmAx1TW9sThHJ_wU
|
|
5
5
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=DNifPpzGV_t3dfOPlPy2XKDM6M1Qne0kCbIPztvFY9U,2179
|
6
6
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
7
|
liger_kernel/ops/cross_entropy.py,sha256=sfUb7-jIZp0EKXjg1DYy2Wdzw_Mg-mHmGoR5bpdm4tw,15526
|
8
|
-
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
|
8
|
+
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=ib7M3AjJE164yMfuS9R39k-5qnDgYOXptIT146lqYbg,9964
|
9
9
|
liger_kernel/ops/fused_linear_jsd.py,sha256=5D_obamh08lGGTMyh85kBJD_aNjPhOYf4-TmCZ6m4s4,9626
|
10
10
|
liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,4127
|
11
11
|
liger_kernel/ops/group_norm.py,sha256=VaRErVJGR4JqgXXvuIjNGTn3E2egjLtU1y3ymwIf4d8,10961
|
@@ -44,12 +44,12 @@ liger_kernel/transformers/model/mixtral.py,sha256=nyDS1dBpsOXYC2DuW59Hgu7ZrGftrH
|
|
44
44
|
liger_kernel/transformers/model/mllama.py,sha256=mesNCgj0Ea1O-fqRD4LVxDJ1CR2abY_zAzK_bfVzkiU,11222
|
45
45
|
liger_kernel/transformers/model/phi3.py,sha256=xUZPlaPKwknLjHc3uUW3EPodm1h0vD3G7Qnhh51v-Io,10332
|
46
46
|
liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5PBO3q0MoCs00,9619
|
47
|
-
liger_kernel/transformers/model/qwen2_vl.py,sha256=
|
47
|
+
liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
|
48
48
|
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
49
49
|
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
50
|
-
liger_kernel_nightly-0.4.1.
|
51
|
-
liger_kernel_nightly-0.4.1.
|
52
|
-
liger_kernel_nightly-0.4.1.
|
53
|
-
liger_kernel_nightly-0.4.1.
|
54
|
-
liger_kernel_nightly-0.4.1.
|
55
|
-
liger_kernel_nightly-0.4.1.
|
50
|
+
liger_kernel_nightly-0.4.1.dev20241115210858.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
51
|
+
liger_kernel_nightly-0.4.1.dev20241115210858.dist-info/METADATA,sha256=VsDMgGO6VdbcC6qFTtPSALLozMM_bwcOl-MgZTzZKLY,21556
|
52
|
+
liger_kernel_nightly-0.4.1.dev20241115210858.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
53
|
+
liger_kernel_nightly-0.4.1.dev20241115210858.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
54
|
+
liger_kernel_nightly-0.4.1.dev20241115210858.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
55
|
+
liger_kernel_nightly-0.4.1.dev20241115210858.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|