liger-kernel-nightly 0.4.0.dev20241107052928__tar.gz → 0.4.0.dev20241107194223__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (53)
  1. {liger_kernel_nightly-0.4.0.dev20241107052928/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.0.dev20241107194223}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/monkey_patch.py +24 -51
  4. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  5. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/LICENSE +0 -0
  6. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/NOTICE +0 -0
  7. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/README.md +0 -0
  8. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/setup.cfg +0 -0
  9. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/env_report.py +0 -0
  10. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/__init__.py +0 -0
  11. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/cross_entropy.py +0 -0
  12. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  13. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  14. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  15. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  16. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/geglu.py +0 -0
  17. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/jsd.py +0 -0
  18. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/kl_div.py +0 -0
  19. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/layer_norm.py +0 -0
  20. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/rms_norm.py +0 -0
  21. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/rope.py +0 -0
  22. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/swiglu.py +0 -0
  23. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/ops/utils.py +0 -0
  24. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/__init__.py +0 -0
  25. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/auto_model.py +0 -0
  26. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  27. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  28. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/functional.py +0 -0
  29. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  30. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  31. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/geglu.py +0 -0
  32. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/jsd.py +0 -0
  33. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/kl_div.py +0 -0
  34. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/layer_norm.py +0 -0
  35. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/model/__init__.py +0 -0
  36. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/model/gemma.py +0 -0
  37. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/model/llama.py +0 -0
  38. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/model/mistral.py +0 -0
  39. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  40. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/model/mllama.py +0 -0
  41. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/model/phi3.py +0 -0
  42. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  43. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  44. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/rms_norm.py +0 -0
  45. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/rope.py +0 -0
  46. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/swiglu.py +0 -0
  47. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  48. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/triton/__init__.py +0 -0
  49. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel/triton/monkey_patch.py +0 -0
  50. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  51. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  52. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  53. {liger_kernel_nightly-0.4.0.dev20241107052928 → liger_kernel_nightly-0.4.0.dev20241107194223}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.0.dev20241107052928
3
+ Version: 0.4.0.dev20241107194223
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.4.0.dev20241107052928"
7
+ version = "0.4.0.dev20241107194223"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -99,6 +99,7 @@ def apply_liger_kernel_to_llama(
99
99
  ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
100
100
 
101
101
  from transformers.models.llama import modeling_llama
102
+ from transformers.models.llama.modeling_llama import LlamaModel
102
103
 
103
104
  if rope:
104
105
  modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -119,15 +120,8 @@ def apply_liger_kernel_to_llama(
119
120
  # The model instance already exists, so we need to additionally patch the
120
121
  # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
121
122
 
122
- if hasattr(model, "model"):
123
- # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example
124
- base_model = model.model
125
- elif hasattr(model, "transformer"):
126
- # LlamaForQuestionAnswering uses "transformer" instead of "model"
127
- base_model = model.transformer
128
- else:
129
- # Direct LlamaModel
130
- base_model = model
123
+ # get the base model from the model instance
124
+ base_model: LlamaModel = getattr(model, model.base_model_prefix, model)
131
125
 
132
126
  if rms_norm:
133
127
  _patch_rms_norm_module(base_model.norm)
@@ -275,6 +269,7 @@ def apply_liger_kernel_to_mistral(
275
269
  ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
276
270
 
277
271
  from transformers.models.mistral import modeling_mistral
272
+ from transformers.models.mistral.modeling_mistral import MistralModel
278
273
 
279
274
  if rope:
280
275
  modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -291,12 +286,8 @@ def apply_liger_kernel_to_mistral(
291
286
  # The model instance already exists, so we need to additionally patch the
292
287
  # instance variables that reference already-instantiated modules
293
288
 
294
- if hasattr(model, "model"):
295
- # The case for MistralForCausalLM, MistralForTokenClassification for example
296
- base_model = model.model
297
- else:
298
- # Direct MistralModel
299
- base_model = model
289
+ # get the base model from the model instance
290
+ base_model: MistralModel = getattr(model, model.base_model_prefix, model)
300
291
 
301
292
  if rms_norm:
302
293
  _patch_rms_norm_module(base_model.norm)
@@ -340,6 +331,7 @@ def apply_liger_kernel_to_mixtral(
340
331
  ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
341
332
 
342
333
  from transformers.models.mixtral import modeling_mixtral
334
+ from transformers.models.mixtral.modeling_mixtral import MixtralModel
343
335
 
344
336
  if rope:
345
337
  modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -360,12 +352,8 @@ def apply_liger_kernel_to_mixtral(
360
352
  # The model instance already exists, so we need to additionally patch the
361
353
  # instance variables that reference already-instantiated modules
362
354
 
363
- if hasattr(model, "model"):
364
- # The case for MixtralForCausalLM, MixtralForTokenClassification for example
365
- base_model = model.model
366
- else:
367
- # Direct MixtralModel
368
- base_model = model
355
+ # get the base model from the model instance
356
+ base_model: MixtralModel = getattr(model, model.base_model_prefix, model)
369
357
 
370
358
  if rms_norm:
371
359
  _patch_rms_norm_module(base_model.norm)
@@ -410,6 +398,7 @@ def apply_liger_kernel_to_gemma(
410
398
  ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
411
399
 
412
400
  from transformers.models.gemma import modeling_gemma
401
+ from transformers.models.gemma.modeling_gemma import GemmaModel
413
402
 
414
403
  # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
415
404
  LigerRMSNormForGemma = partial(
@@ -438,12 +427,8 @@ def apply_liger_kernel_to_gemma(
438
427
  # The model instance already exists, so we need to additionally patch the
439
428
  # instance variables that reference already-instantiated modules
440
429
 
441
- if hasattr(model, "model"):
442
- # The case for GemmaForCausalLM, GemmaForTokenClassification for example
443
- base_model = model.model
444
- else:
445
- # Direct GemmaModel
446
- base_model = model
430
+ # get the base model from the model instance
431
+ base_model: GemmaModel = getattr(model, model.base_model_prefix, model)
447
432
 
448
433
  if rms_norm:
449
434
  _patch_rms_norm_module_for_gemma(base_model.norm)
@@ -478,6 +463,7 @@ def apply_liger_kernel_to_gemma2(
478
463
  loaded. Default is None.
479
464
  """
480
465
  from transformers.models.gemma2 import modeling_gemma2
466
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
481
467
 
482
468
  LigerRMSNormForGemma2 = partial(
483
469
  LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
@@ -500,12 +486,8 @@ def apply_liger_kernel_to_gemma2(
500
486
  # The model instance already exists, so we need to additionally patch the
501
487
  # instance variables that reference already-instantiated modules
502
488
 
503
- if hasattr(model, "model"):
504
- # The case for Gemma2ForCausalLM, Gemma2ForTokenClassification for example
505
- base_model = model.model
506
- else:
507
- # Direct Gemma2Model
508
- base_model = model
489
+ # get the base model from the model instance
490
+ base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)
509
491
 
510
492
  if rms_norm:
511
493
  _patch_rms_norm_module_for_gemma2(base_model.norm)
@@ -556,6 +538,7 @@ def apply_liger_kernel_to_qwen2(
556
538
  ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
557
539
 
558
540
  from transformers.models.qwen2 import modeling_qwen2
541
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
559
542
 
560
543
  if rope:
561
544
  modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -580,12 +563,8 @@ def apply_liger_kernel_to_qwen2(
580
563
  # The model instance already exists, so we need to additionally patch the
581
564
  # instance variables that reference already-instantiated modules
582
565
 
583
- if hasattr(model, "model"):
584
- # The case for Qwen2ForCausalLM, Qwen2ForTokenClassification for example
585
- base_model = model.model
586
- else:
587
- # Direct Qwen2Model
588
- base_model = model
566
+ # get the base model from the model instance
567
+ base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)
589
568
 
590
569
  if rms_norm:
591
570
  _patch_rms_norm_module(base_model.norm)
@@ -630,6 +609,7 @@ def apply_liger_kernel_to_qwen2_vl(
630
609
  ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
631
610
 
632
611
  from transformers.models.qwen2_vl import modeling_qwen2_vl
612
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
633
613
 
634
614
  from liger_kernel.transformers.model.qwen2_vl import (
635
615
  lce_forward as qwen2_vl_lce_forward,
@@ -653,12 +633,8 @@ def apply_liger_kernel_to_qwen2_vl(
653
633
  # The model instance already exists, so we need to additionally patch the
654
634
  # instance variables that reference already-instantiated modules
655
635
 
656
- if hasattr(model, "model"):
657
- # The case for Qwen2VLForConditionalGeneration.
658
- base_model = model.model
659
- else:
660
- # Direct Qwen2VLModel
661
- base_model = model
636
+ # get the base model from the model instance
637
+ base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)
662
638
 
663
639
  if hasattr(model, "visual"):
664
640
  # Patch Qwen2VisionTransformerPretrainedModel
@@ -707,6 +683,7 @@ def apply_liger_kernel_to_phi3(
707
683
  ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
708
684
 
709
685
  from transformers.models.phi3 import modeling_phi3
686
+ from transformers.models.phi3.modeling_phi3 import Phi3Model
710
687
 
711
688
  if rope:
712
689
  modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
@@ -727,12 +704,8 @@ def apply_liger_kernel_to_phi3(
727
704
  # The model instance already exists, so we need to additionally patch the
728
705
  # instance variables that reference already-instantiated modules
729
706
 
730
- if hasattr(model, "model"):
731
- # The case for Phi3ForCausalLM, Phi3ForTokenClassification for example
732
- base_model = model.model
733
- else:
734
- # Direct Phi3Model
735
- base_model = model
707
+ # get the base model from the model instance
708
+ base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
736
709
 
737
710
  if rms_norm:
738
711
  _patch_rms_norm_module(base_model.norm)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.0.dev20241107052928
3
+ Version: 0.4.0.dev20241107194223
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation