liger-kernel-nightly 0.4.1.dev20241115191733__tar.gz → 0.4.1.dev20241117192031__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {liger_kernel_nightly-0.4.1.dev20241115191733/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.1.dev20241117192031}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/model/qwen2_vl.py +43 -17
  4. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/monkey_patch.py +5 -2
  5. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  6. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/LICENSE +0 -0
  7. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/NOTICE +0 -0
  8. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/README.md +0 -0
  9. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/setup.cfg +0 -0
  10. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  11. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  12. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  13. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  14. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/env_report.py +0 -0
  15. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/__init__.py +0 -0
  16. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/cross_entropy.py +0 -0
  17. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  18. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  19. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  20. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  21. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/geglu.py +0 -0
  22. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/group_norm.py +0 -0
  23. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/jsd.py +0 -0
  24. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/kl_div.py +0 -0
  25. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/layer_norm.py +0 -0
  26. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/rms_norm.py +0 -0
  27. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/rope.py +0 -0
  28. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/swiglu.py +0 -0
  29. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/ops/utils.py +0 -0
  30. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/__init__.py +0 -0
  31. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/auto_model.py +0 -0
  32. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  33. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  34. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/functional.py +0 -0
  35. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  36. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  37. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/geglu.py +0 -0
  38. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/group_norm.py +0 -0
  39. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/jsd.py +0 -0
  40. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/kl_div.py +0 -0
  41. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/layer_norm.py +0 -0
  42. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/model/__init__.py +0 -0
  43. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/model/gemma.py +0 -0
  44. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  45. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/model/llama.py +0 -0
  46. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/model/mistral.py +0 -0
  47. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  48. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/model/mllama.py +0 -0
  49. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/model/phi3.py +0 -0
  50. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  51. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/rms_norm.py +0 -0
  52. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/rope.py +0 -0
  53. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/swiglu.py +0 -0
  54. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  55. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/triton/__init__.py +0 -0
  56. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel/triton/monkey_patch.py +0 -0
  57. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  58. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  59. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  60. {liger_kernel_nightly-0.4.1.dev20241115191733 → liger_kernel_nightly-0.4.1.dev20241117192031}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.1.dev20241115191733
3
+ Version: 0.4.1.dev20241117192031
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.4.1.dev20241115191733"
7
+ version = "0.4.1.dev20241117192031"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -1,7 +1,9 @@
1
1
  from typing import List, Optional, Tuple, Union
2
2
 
3
3
  import torch
4
+ from packaging import version
4
5
  from torch.nn import CrossEntropyLoss
6
+ from transformers import __version__ as transformers_version
5
7
  from transformers.models.qwen2_vl.modeling_qwen2_vl import (
6
8
  _CONFIG_FOR_DOC,
7
9
  QWEN2_VL_INPUTS_DOCSTRING,
@@ -80,8 +82,6 @@ def lce_forward(
80
82
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
81
83
  "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
82
84
  ```"""
83
- # FIXME: The code is outdated and not compatible with transformer >= 4.46.1
84
-
85
85
  output_attentions = (
86
86
  output_attentions
87
87
  if output_attentions is not None
@@ -100,27 +100,53 @@ def lce_forward(
100
100
  inputs_embeds = self.model.embed_tokens(input_ids)
101
101
  if pixel_values is not None:
102
102
  pixel_values = pixel_values.type(self.visual.get_dtype())
103
- image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(
104
- inputs_embeds.device
103
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
104
+ n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
105
+ n_image_features = image_embeds.shape[0]
106
+ if n_image_tokens != n_image_features:
107
+ raise ValueError(
108
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
109
+ )
110
+ image_mask = (
111
+ (input_ids == self.config.image_token_id)
112
+ .unsqueeze(-1)
113
+ .expand_as(inputs_embeds)
114
+ .to(inputs_embeds.device)
105
115
  )
106
- image_mask = input_ids == self.config.image_token_id
107
- if self.training:
108
- inputs_embeds = inputs_embeds.clone()
109
- inputs_embeds[image_mask] = image_embeds
116
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
117
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
118
+
110
119
  if pixel_values_videos is not None:
111
120
  pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
112
- video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(
113
- inputs_embeds.device
121
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
122
+ n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
123
+ n_video_features = video_embeds.shape[0]
124
+ if n_video_tokens != n_video_features:
125
+ raise ValueError(
126
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
127
+ )
128
+ video_mask = (
129
+ (input_ids == self.config.video_token_id)
130
+ .unsqueeze(-1)
131
+ .expand_as(inputs_embeds)
132
+ .to(inputs_embeds.device)
114
133
  )
115
- video_mask = input_ids == self.config.video_token_id
116
- inputs_embeds[video_mask] = video_embeds
134
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
135
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
136
+
117
137
  if attention_mask is not None:
118
138
  attention_mask = attention_mask.to(inputs_embeds.device)
119
- # The code is copied from https://github.com/huggingface/transformers/pull/33487
120
- if position_ids is None and input_ids is not None:
121
- position_ids, _ = self.get_rope_index(
122
- input_ids, image_grid_thw, video_grid_thw, attention_mask
123
- )
139
+
140
+ if version.parse(transformers_version) > version.parse("4.46.2"):
141
+ # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
142
+ # https://github.com/huggingface/transformers/issues/33401
143
+ # While correct, this breaks equivalence with past versions of Qwen2-VL from
144
+ # transformers and leads to failed tests or users noticing differences in results.
145
+ # TODO: remove above conditional when liger drops support for transformers<4.47.0
146
+ if position_ids is None and input_ids is not None:
147
+ position_ids, _ = self.get_rope_index(
148
+ input_ids, image_grid_thw, video_grid_thw, attention_mask
149
+ )
124
150
 
125
151
  outputs = self.model(
126
152
  input_ids=None,
@@ -56,12 +56,15 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
56
56
  module.__dict__[method_name] = new_method.__get__(module, module.__class__)
57
57
 
58
58
 
59
- def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"):
59
+ def _patch_rms_norm_module(
60
+ module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
61
+ ):
60
62
  module.offset = offset
61
63
  module.casting_mode = casting_mode
62
64
  module.variance_epsilon = (
63
65
  getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
64
66
  )
67
+ module.in_place = in_place
65
68
  _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
66
69
  _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
67
70
 
@@ -510,7 +513,7 @@ def apply_liger_kernel_to_gemma2(
510
513
  LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
511
514
  )
512
515
  _patch_rms_norm_module_for_gemma2 = partial(
513
- _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
516
+ _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
514
517
  )
515
518
 
516
519
  if rope:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.1.dev20241115191733
3
+ Version: 0.4.1.dev20241117192031
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation