liger-kernel-nightly 0.5.2.dev20250120024510__py3-none-any.whl → 0.5.2.dev20250121233718__py3-none-any.whl

--- a/liger_kernel/transformers/model/qwen2_vl.py
+++ b/liger_kernel/transformers/model/qwen2_vl.py
@@ -36,6 +36,7 @@ def lce_forward(
     image_grid_thw: Optional[torch.LongTensor] = None,
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -125,14 +126,30 @@ def lce_forward(
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)
 
-    if version.parse(transformers_version) > version.parse("4.46.2"):
+    if version.parse(transformers_version) > version.parse("4.46.3"):
         # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
         # https://github.com/huggingface/transformers/issues/33401
         # While correct, this breaks equivalence with past versions of Qwen2-VL from
         # transformers and leads to failed tests or users noticing differences in results.
         # TODO: remove above conditional when liger drops support for transformers<4.47.0
-        if position_ids is None and input_ids is not None:
-            position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
+        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
 
     outputs = self.model(
         input_ids=None,
@@ -144,6 +161,7 @@ def lce_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        cache_position=cache_position,
     )
 
     hidden_states = outputs[0]
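
In short, the new branch calls get_rope_index only in the pre-fill step, caches the returned rope_deltas on the module, and in later decode steps just shifts a fresh torch.arange by cache_position[0] + rope_deltas. The standalone snippet below mirrors that decode-step arithmetic with made-up values (batch of 2, one new token); it is an illustrative sketch, not liger's or transformers' API.

import torch

# Hypothetical decode-step inputs: shapes and numbers are illustrative only.
batch_size, seq_length = 2, 1
inputs_embeds = torch.zeros(batch_size, seq_length, 8)   # stand-in embeddings
cache_position = torch.tensor([5])                        # index of the first new token in the KV cache
rope_deltas = torch.tensor([[3], [7]])                    # per-sample deltas cached during pre-fill

delta = cache_position[0] + rope_deltas                               # (batch, 1) shifted deltas
position_ids = torch.arange(seq_length, device=inputs_embeds.device)  # (seq,)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)        # (batch, seq)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)  # no-op when deltas are already per-sample
position_ids = position_ids.add(delta)                                # broadcast add of the deltas
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)            # 3 M-RoPE axes (temporal/height/width)
print(position_ids.shape)  # torch.Size([3, 2, 1])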
--- a/liger_kernel_nightly-0.5.2.dev20250120024510.dist-info/METADATA
+++ b/liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.2.dev20250120024510
+Version: 0.5.2.dev20250121233718
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
--- a/liger_kernel_nightly-0.5.2.dev20250120024510.dist-info/RECORD
+++ b/liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/RECORD
@@ -53,14 +53,14 @@ liger_kernel/transformers/model/mixtral.py,sha256=jpZJkpl625Q-JHWarj2MqT5mRaSsiC
 liger_kernel/transformers/model/mllama.py,sha256=qWexBdskuN3gPJvPUwt4J0nU675tGD6W7wxgRZ9Bifg,11145
 liger_kernel/transformers/model/phi3.py,sha256=biRa8fph9qdnQmkD9I21t5XIjpIt1i6UKU4uk8Up8pU,10292
 liger_kernel/transformers/model/qwen2.py,sha256=14UuPjxB-tjqWn85Tn4fqBFvVhVsth5iPEt8kJSMiew,9581
-liger_kernel/transformers/model/qwen2_vl.py,sha256=rZg3nU3YgF6wkB1UJ0a9IACSIlVOSCyLltyqw951MQQ,8609
+liger_kernel/transformers/model/qwen2_vl.py,sha256=yMLqsfSYcvhClUpTUjGoADiOxfLB2B8240VdrPP0c8s,9851
 liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.2.dev20250120024510.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.2.dev20250120024510.dist-info/METADATA,sha256=6rr1Qq6PM7sdCXXvN9tkrqqEhjzfwy6Ac2mfUlpc5n4,21055
-liger_kernel_nightly-0.5.2.dev20250120024510.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.2.dev20250120024510.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-liger_kernel_nightly-0.5.2.dev20250120024510.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.2.dev20250120024510.dist-info/RECORD,,
+liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/METADATA,sha256=QUXjV3q15U4bHHBeStGdZVlcf9xzck0d-aOHsLdr9nE,21055
+liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20250121233718.dist-info/RECORD,,
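
The RECORD changes above are mechanical: each entry is path,sha256=<urlsafe-base64 digest with padding stripped>,<size in bytes>, so the edited qwen2_vl.py gets a new digest and size (8609 → 9851 bytes) and the dist-info paths pick up the new version string. A minimal sketch of recomputing such an entry for verification, assuming a locally extracted wheel (the path argument is a placeholder):

import base64
import hashlib

def record_entry(path: str) -> str:
    # Rebuild one wheel RECORD line: SHA-256 of the file, urlsafe-base64
    # encoded with trailing '=' padding removed, then the size in bytes.
    data = open(path, "rb").read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"

# e.g. record_entry("liger_kernel/transformers/model/qwen2_vl.py")
# should reproduce "liger_kernel/transformers/model/qwen2_vl.py,sha256=yMLq...,9851"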