liger-kernel-nightly 0.5.2.dev20250119025401__py3-none-any.whl → 0.5.2.dev20250121233718__py3-none-any.whl
- liger_kernel/transformers/model/qwen2_vl.py +21 -3
- {liger_kernel_nightly-0.5.2.dev20250119025401.dist-info → liger_kernel_nightly-0.5.2.dev20250121233718.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20250119025401.dist-info → liger_kernel_nightly-0.5.2.dev20250121233718.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.5.2.dev20250119025401.dist-info → liger_kernel_nightly-0.5.2.dev20250121233718.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250119025401.dist-info → liger_kernel_nightly-0.5.2.dev20250121233718.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250119025401.dist-info → liger_kernel_nightly-0.5.2.dev20250121233718.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20250119025401.dist-info → liger_kernel_nightly-0.5.2.dev20250121233718.dist-info}/top_level.txt +0 -0
@@ -36,6 +36,7 @@ def lce_forward(
     image_grid_thw: Optional[torch.LongTensor] = None,
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
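The new cache_position parameter brings the patched signature in line with upstream Qwen2VLForConditionalGeneration.forward on transformers >= 4.47, where generate() supplies it on every decoding step. Below is a minimal sketch of how this forward is typically installed; it assumes the apply_liger_kernel_to_qwen2_vl entry point and a public Qwen2-VL checkpoint name, so treat it as illustrative rather than a verbatim recipe:

    # Illustrative sketch, not a verbatim liger-kernel recipe: assumes the
    # apply_liger_kernel_to_qwen2_vl() patching entry point and a public
    # Qwen2-VL checkpoint; verify names against your installed versions.
    import torch
    from transformers import Qwen2VLForConditionalGeneration
    from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

    # Monkey-patch the Qwen2-VL classes; with default options this swaps the
    # stock forward for the lce_forward shown in this diff, so the LM head and
    # cross entropy run as one fused linear-cross-entropy kernel.
    apply_liger_kernel_to_qwen2_vl()

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
    )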
@@ -125,14 +126,30 @@ def lce_forward(
     if attention_mask is not None:
         attention_mask = attention_mask.to(inputs_embeds.device)

-    if version.parse(transformers_version) > version.parse("4.46.
+    if version.parse(transformers_version) > version.parse("4.46.3"):
         # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
         # https://github.com/huggingface/transformers/issues/33401
         # While correct, this breaks equivalence with past versions of Qwen2-VL from
         # transformers and leads to failed tests or users noticing differences in results.
         # TODO: remove above conditional when liger drops support for transformers<4.47.0
-        if
-
+        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

     outputs = self.model(
         input_ids=None,
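To make the cached-delta branch concrete, here is a self-contained numeric sketch of the decode-step arithmetic above, with made-up shapes (batch of 2, one new token, pre-cached rope deltas):

    # Standalone numeric sketch of the else-branch above; values are invented.
    import torch

    batch_size, seq_length = 2, 1                  # decoding one new token
    cache_position = torch.tensor([10])            # index of the first new token
    rope_deltas = torch.tensor([[-3], [-5]])       # cached at pre-fill, shape (2, 1)

    delta = cache_position[0] + rope_deltas        # per-sequence shift, shape (2, 1)
    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)  # no-op here

    position_ids = torch.arange(seq_length)        # (1,)
    position_ids = position_ids.view(1, -1).expand(batch_size, -1)        # (2, 1)
    position_ids = position_ids.add(delta)         # shifted ids: [[7], [5]]
    position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)            # (3, 2, 1)

    # The leading 3 matches Qwen2-VL's multimodal RoPE axes (temporal/height/width).
    print(position_ids[:, 0, 0])                   # tensor([7, 7, 7])

The point of caching rope_deltas at pre-fill is visible here: later steps derive the 3D position ids from a plain arange plus the stored offset instead of re-running get_rope_index over the image/video grids.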
@@ -144,6 +161,7 @@ def lce_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        cache_position=cache_position,
     )

     hidden_states = outputs[0]
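The version gate wrapping this logic keys off the installed transformers release, since the upstream fix (huggingface/transformers issue 33401, linked in the diff) only landed in 4.47.0. A hedged sketch of the pattern, using only the public packaging and transformers version APIs:

    # Sketch of the version-gate pattern used in this file: behavior must match
    # upstream Qwen2-VL both before and after the transformers 4.47.0 fix.
    from packaging import version
    from transformers import __version__ as transformers_version

    if version.parse(transformers_version) > version.parse("4.46.3"):
        # transformers >= 4.47.0: rope deltas are computed once at pre-fill,
        # then reused via cache_position on each decode step.
        print("using the cached rope-delta path")
    else:
        # older transformers: keep the legacy path so results stay equivalent.
        print("using the legacy path")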
@@ -53,14 +53,14 @@ liger_kernel/transformers/model/mixtral.py,sha256=jpZJkpl625Q-JHWarj2MqT5mRaSsiC
 liger_kernel/transformers/model/mllama.py,sha256=qWexBdskuN3gPJvPUwt4J0nU675tGD6W7wxgRZ9Bifg,11145
 liger_kernel/transformers/model/phi3.py,sha256=biRa8fph9qdnQmkD9I21t5XIjpIt1i6UKU4uk8Up8pU,10292
 liger_kernel/transformers/model/qwen2.py,sha256=14UuPjxB-tjqWn85Tn4fqBFvVhVsth5iPEt8kJSMiew,9581
-liger_kernel/transformers/model/qwen2_vl.py,sha256=
+liger_kernel/transformers/model/qwen2_vl.py,sha256=yMLqsfSYcvhClUpTUjGoADiOxfLB2B8240VdrPP0c8s,9851
 liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.2.dev20250119025401.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.2.dev20250119025401.dist-info/METADATA,sha256=
-liger_kernel_nightly-0.5.2.dev20250119025401.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.2.dev20250119025401.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-liger_kernel_nightly-0.5.2.dev20250119025401.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.2.dev20250119025401.dist-info/RECORD,,
+liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/METADATA,sha256=QUXjV3q15U4bHHBeStGdZVlcf9xzck0d-aOHsLdr9nE,21055
+liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/RECORD,,
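Each RECORD line has the form path,sha256=<digest>,<size>, where the digest is the urlsafe-base64-encoded SHA-256 of the file with the trailing = padding stripped (the standard wheel RECORD format). A small sketch for recomputing one entry, e.g. to confirm the new qwen2_vl.py hash; the relative path assumes an unpacked wheel:

    # Recompute a RECORD-style digest: urlsafe base64 of the SHA-256, unpadded.
    import base64
    import hashlib

    def record_digest(path: str) -> str:
        with open(path, "rb") as f:
            digest = hashlib.sha256(f.read()).digest()
        return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

    # For the new wheel this should print:
    # sha256=yMLqsfSYcvhClUpTUjGoADiOxfLB2B8240VdrPP0c8s
    print(record_digest("liger_kernel/transformers/model/qwen2_vl.py"))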