liger-kernel-nightly 0.5.2.dev20241211231633__tar.gz → 0.5.2.dev20241212000548__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199) hide show
  1. liger_kernel_nightly-0.5.2.dev20241212000548/.flake8 +10 -0
  2. liger_kernel_nightly-0.5.2.dev20241212000548/.github/ISSUE_TEMPLATE/bug_report.yaml +48 -0
  3. liger_kernel_nightly-0.5.2.dev20241212000548/.github/ISSUE_TEMPLATE/feature_request.yaml +25 -0
  4. liger_kernel_nightly-0.5.2.dev20241212000548/.github/pull_request_template.md +22 -0
  5. liger_kernel_nightly-0.5.2.dev20241212000548/.github/workflows/amd-ci.yml +71 -0
  6. liger_kernel_nightly-0.5.2.dev20241212000548/.github/workflows/nvi-ci.yml +96 -0
  7. liger_kernel_nightly-0.5.2.dev20241212000548/.github/workflows/publish-nightly.yml +49 -0
  8. liger_kernel_nightly-0.5.2.dev20241212000548/.github/workflows/publish-release.yml +38 -0
  9. liger_kernel_nightly-0.5.2.dev20241212000548/.gitignore +20 -0
  10. liger_kernel_nightly-0.5.2.dev20241212000548/.isort.cfg +2 -0
  11. liger_kernel_nightly-0.5.2.dev20241212000548/Makefile +42 -0
  12. {liger_kernel_nightly-0.5.2.dev20241211231633/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.5.2.dev20241212000548}/PKG-INFO +11 -15
  13. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/README.md +6 -4
  14. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/benchmarks_visualizer.py +169 -0
  15. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/data/all_benchmark_data.csv +717 -0
  16. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_cpo_loss.py +191 -0
  17. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_cross_entropy.py +126 -0
  18. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_dpo_loss.py +227 -0
  19. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_embedding.py +129 -0
  20. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +183 -0
  21. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_fused_linear_jsd.py +273 -0
  22. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_geglu.py +180 -0
  23. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_group_norm.py +150 -0
  24. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_jsd.py +157 -0
  25. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_kl_div.py +124 -0
  26. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_layer_norm.py +130 -0
  27. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_orpo_loss.py +191 -0
  28. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_qwen2vl_mrope.py +252 -0
  29. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_rms_norm.py +163 -0
  30. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_rope.py +230 -0
  31. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_simpo_loss.py +191 -0
  32. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/benchmark_swiglu.py +178 -0
  33. liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts/utils.py +356 -0
  34. liger_kernel_nightly-0.5.2.dev20241212000548/dev/modal/tests.py +30 -0
  35. liger_kernel_nightly-0.5.2.dev20241212000548/dev/modal/tests_bwd.py +35 -0
  36. liger_kernel_nightly-0.5.2.dev20241212000548/docs/Acknowledgement.md +27 -0
  37. liger_kernel_nightly-0.5.2.dev20241212000548/docs/CONTRIBUTING.md +107 -0
  38. liger_kernel_nightly-0.5.2.dev20241212000548/docs/License.md +8 -0
  39. liger_kernel_nightly-0.5.2.dev20241212000548/docs/images/banner.GIF +0 -0
  40. liger_kernel_nightly-0.5.2.dev20241212000548/docs/images/compose.gif +0 -0
  41. liger_kernel_nightly-0.5.2.dev20241212000548/docs/images/e2e-memory.png +0 -0
  42. liger_kernel_nightly-0.5.2.dev20241212000548/docs/images/e2e-tps.png +0 -0
  43. liger_kernel_nightly-0.5.2.dev20241212000548/docs/images/logo-banner.png +0 -0
  44. liger_kernel_nightly-0.5.2.dev20241212000548/docs/images/patch.gif +0 -0
  45. liger_kernel_nightly-0.5.2.dev20241212000548/examples/alignment/accelerate_config.yaml +26 -0
  46. liger_kernel_nightly-0.5.2.dev20241212000548/examples/alignment/run_orpo.py +35 -0
  47. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/README.md +55 -0
  48. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/callback.py +275 -0
  49. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/config/fsdp_config.json +5 -0
  50. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/img/gemma_7b_mem.png +0 -0
  51. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/img/gemma_7b_tp.png +0 -0
  52. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/img/llama_mem_alloc.png +0 -0
  53. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/img/llama_tps.png +0 -0
  54. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  55. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/img/qwen_tps.png +0 -0
  56. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/launch_on_modal.py +72 -0
  57. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/requirements.txt +6 -0
  58. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/run_benchmarks.sh +52 -0
  59. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/run_gemma.sh +22 -0
  60. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/run_llama.sh +21 -0
  61. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/run_qwen.sh +22 -0
  62. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/run_qwen2_vl.sh +22 -0
  63. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/training.py +81 -0
  64. liger_kernel_nightly-0.5.2.dev20241212000548/examples/huggingface/training_multimodal.py +171 -0
  65. liger_kernel_nightly-0.5.2.dev20241212000548/examples/lightning/README.md +21 -0
  66. liger_kernel_nightly-0.5.2.dev20241212000548/examples/lightning/requirements.txt +8 -0
  67. liger_kernel_nightly-0.5.2.dev20241212000548/examples/lightning/training.py +295 -0
  68. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/README.md +72 -0
  69. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/callback.py +410 -0
  70. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  71. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  72. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  73. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  74. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  75. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  76. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  77. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  78. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/fsdp/acc-fsdp.conf +24 -0
  79. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/medusa_util.py +289 -0
  80. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/requirements.txt +3 -0
  81. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/scripts/llama3_8b_medusa.sh +56 -0
  82. liger_kernel_nightly-0.5.2.dev20241212000548/examples/medusa/train.py +403 -0
  83. liger_kernel_nightly-0.5.2.dev20241212000548/licenses/LICENSE-Apache-2.0 +201 -0
  84. liger_kernel_nightly-0.5.2.dev20241212000548/licenses/LICENSE-MIT-AutoAWQ +21 -0
  85. liger_kernel_nightly-0.5.2.dev20241212000548/licenses/LICENSE-MIT-Efficient-Cross-Entropy +21 -0
  86. liger_kernel_nightly-0.5.2.dev20241212000548/licenses/LICENSE-MIT-llmc +21 -0
  87. liger_kernel_nightly-0.5.2.dev20241212000548/licenses/LICENSE-MIT-triton +23 -0
  88. liger_kernel_nightly-0.5.2.dev20241212000548/pyproject.toml +26 -0
  89. liger_kernel_nightly-0.5.2.dev20241212000548/setup.py +71 -0
  90. liger_kernel_nightly-0.5.2.dev20241212000548/src/liger_kernel/chunked_loss/README.md +25 -0
  91. liger_kernel_nightly-0.5.2.dev20241212000548/src/liger_kernel/ops/__init__.py +0 -0
  92. liger_kernel_nightly-0.5.2.dev20241212000548/src/liger_kernel/transformers/model/__init__.py +0 -0
  93. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548/src/liger_kernel_nightly.egg-info}/PKG-INFO +11 -15
  94. liger_kernel_nightly-0.5.2.dev20241212000548/src/liger_kernel_nightly.egg-info/SOURCES.txt +195 -0
  95. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel_nightly.egg-info/requires.txt +4 -12
  96. liger_kernel_nightly-0.5.2.dev20241212000548/test/__init__.py +0 -0
  97. liger_kernel_nightly-0.5.2.dev20241212000548/test/chunked_loss/__init__.py +0 -0
  98. liger_kernel_nightly-0.5.2.dev20241212000548/test/chunked_loss/test_cpo_loss.py +288 -0
  99. liger_kernel_nightly-0.5.2.dev20241212000548/test/chunked_loss/test_dpo_loss.py +287 -0
  100. liger_kernel_nightly-0.5.2.dev20241212000548/test/chunked_loss/test_orpo_loss.py +269 -0
  101. liger_kernel_nightly-0.5.2.dev20241212000548/test/chunked_loss/test_simpo_loss.py +203 -0
  102. liger_kernel_nightly-0.5.2.dev20241212000548/test/conftest.py +8 -0
  103. liger_kernel_nightly-0.5.2.dev20241212000548/test/convergence/__init__.py +0 -0
  104. liger_kernel_nightly-0.5.2.dev20241212000548/test/convergence/test_mini_models.py +716 -0
  105. liger_kernel_nightly-0.5.2.dev20241212000548/test/convergence/test_mini_models_multimodal.py +490 -0
  106. liger_kernel_nightly-0.5.2.dev20241212000548/test/convergence/test_mini_models_with_logits.py +715 -0
  107. liger_kernel_nightly-0.5.2.dev20241212000548/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +63 -0
  108. liger_kernel_nightly-0.5.2.dev20241212000548/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +31 -0
  109. liger_kernel_nightly-0.5.2.dev20241212000548/test/resources/scripts/generate_tokenized_dataset.py +79 -0
  110. liger_kernel_nightly-0.5.2.dev20241212000548/test/resources/tiny_shakespeare.txt +40000 -0
  111. liger_kernel_nightly-0.5.2.dev20241212000548/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  112. liger_kernel_nightly-0.5.2.dev20241212000548/test/resources/tiny_shakespeare_tokenized/dataset_info.json +48 -0
  113. liger_kernel_nightly-0.5.2.dev20241212000548/test/resources/tiny_shakespeare_tokenized/state.json +13 -0
  114. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_auto_model.py +57 -0
  115. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_cross_entropy.py +795 -0
  116. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_embedding.py +74 -0
  117. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_fused_linear_cross_entropy.py +326 -0
  118. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_fused_linear_jsd.py +475 -0
  119. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_geglu.py +147 -0
  120. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_group_norm.py +70 -0
  121. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_jsd.py +345 -0
  122. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_kl_div.py +110 -0
  123. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_layer_norm.py +101 -0
  124. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_mm_int8int2.py +109 -0
  125. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_monkey_patch.py +875 -0
  126. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_qwen2vl_mrope.py +166 -0
  127. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_rms_norm.py +194 -0
  128. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_rope.py +157 -0
  129. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_swiglu.py +221 -0
  130. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_trainer_integration.py +10 -0
  131. liger_kernel_nightly-0.5.2.dev20241212000548/test/transformers/test_transformers.py +18 -0
  132. liger_kernel_nightly-0.5.2.dev20241212000548/test/triton/test_triton_monkey_patch.py +22 -0
  133. liger_kernel_nightly-0.5.2.dev20241212000548/test/utils.py +633 -0
  134. liger_kernel_nightly-0.5.2.dev20241211231633/pyproject.toml +0 -61
  135. liger_kernel_nightly-0.5.2.dev20241211231633/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -68
  136. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/LICENSE +0 -0
  137. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/NOTICE +0 -0
  138. {liger_kernel_nightly-0.5.2.dev20241211231633/src/liger_kernel → liger_kernel_nightly-0.5.2.dev20241212000548/benchmark}/__init__.py +0 -0
  139. {liger_kernel_nightly-0.5.2.dev20241211231633/src/liger_kernel/ops → liger_kernel_nightly-0.5.2.dev20241212000548/benchmark/scripts}/__init__.py +0 -0
  140. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/setup.cfg +0 -0
  141. {liger_kernel_nightly-0.5.2.dev20241211231633/src/liger_kernel/transformers/model → liger_kernel_nightly-0.5.2.dev20241212000548/src/liger_kernel}/__init__.py +0 -0
  142. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  143. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  144. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  145. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/chunked_loss/functional.py +0 -0
  146. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  147. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  148. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  149. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  150. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/env_report.py +0 -0
  151. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/cross_entropy.py +0 -0
  152. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  153. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  154. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  155. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  156. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/geglu.py +0 -0
  157. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/group_norm.py +0 -0
  158. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/jsd.py +0 -0
  159. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/kl_div.py +0 -0
  160. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/layer_norm.py +0 -0
  161. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  162. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/rms_norm.py +0 -0
  163. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/rope.py +0 -0
  164. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/swiglu.py +0 -0
  165. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/ops/utils.py +0 -0
  166. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/__init__.py +0 -0
  167. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/auto_model.py +0 -0
  168. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  169. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  170. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/functional.py +0 -0
  171. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  172. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  173. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/geglu.py +0 -0
  174. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/group_norm.py +0 -0
  175. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/jsd.py +0 -0
  176. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/kl_div.py +0 -0
  177. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/layer_norm.py +0 -0
  178. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/model/gemma.py +0 -0
  179. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  180. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/model/llama.py +0 -0
  181. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/model/mistral.py +0 -0
  182. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  183. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/model/mllama.py +0 -0
  184. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/model/phi3.py +0 -0
  185. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  186. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  187. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  188. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  189. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/rms_norm.py +0 -0
  190. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/rope.py +0 -0
  191. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/swiglu.py +0 -0
  192. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  193. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  194. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  195. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/triton/__init__.py +0 -0
  196. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/triton/monkey_patch.py +0 -0
  197. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel/utils.py +0 -0
  198. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  199. {liger_kernel_nightly-0.5.2.dev20241211231633 → liger_kernel_nightly-0.5.2.dev20241212000548}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -0,0 +1,10 @@
1
+ # .flake8
2
+ [flake8]
3
+ max-line-length = 120
4
+ exclude =
5
+ .git,
6
+ __pycache__,
7
+ benchmark_internal/others,
8
+ .venv
9
+ # E203: https://github.com/psf/black/issues/315
10
+ extend-ignore=E501,B006,E731,A002,E203
@@ -0,0 +1,48 @@
1
+ name: 🐛 Bug Report
2
+ description: Create a report to help us reproduce and fix the bug
3
+
4
+ body:
5
+ - type: markdown
6
+ attributes:
7
+ value: >
8
+ #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/linkedin/Liger-Kernel/issues).
9
+ - type: textarea
10
+ attributes:
11
+ label: 🐛 Describe the bug
12
+ description: |
13
+ Please provide a clear and concise description of what the bug is.
14
+ placeholder: |
15
+ A clear and concise description of what the bug is.
16
+ validations:
17
+ required: true
18
+
19
+ - type: textarea
20
+ attributes:
21
+ label: Reproduce
22
+ description: |
23
+ If applicable, add a minimal example so that we can reproduce the error by running the code.
24
+ The snippet needs to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently.
25
+ We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc.
26
+ If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
27
+ Please also paste or describe the results you observe instead of the expected results.
28
+ If you observe an error, please paste the error message including the **full** traceback of the exception.
29
+ validations:
30
+ required: false
31
+
32
+ - type: textarea
33
+ attributes:
34
+ label: Versions
35
+ description: |
36
+ Please provide triton, torch, hardware, and other necessary versions to reproduce the bug.
37
+
38
+ For convenience, you can run the following command to get the versions of important software dependencies:
39
+ ```bash
40
+ python -m liger_kernel.env_report
41
+ ```
42
+ validations:
43
+ required: true
44
+
45
+ - type: markdown
46
+ attributes:
47
+ value: >
48
+ Thanks for contributing 🎉!
@@ -0,0 +1,25 @@
1
+ name: 🚀 Feature request
2
+ description: Submit a proposal/request for a new Liger feature
3
+
4
+ body:
5
+ - type: textarea
6
+ attributes:
7
+ label: 🚀 The feature, motivation and pitch
8
+ description: >
9
+ A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
10
+ validations:
11
+ required: true
12
+ - type: textarea
13
+ attributes:
14
+ label: Alternatives
15
+ description: >
16
+ A description of any alternative solutions or features you've considered, if any.
17
+ - type: textarea
18
+ attributes:
19
+ label: Additional context
20
+ description: >
21
+ Add any other context or screenshots about the feature request.
22
+ - type: markdown
23
+ attributes:
24
+ value: >
25
+ Thanks for contributing 🎉!
@@ -0,0 +1,22 @@
1
+ ## Summary
2
+ <!--- This is a required section; please describe the main purpose of this proposed code change. --->
3
+
4
+ <!---
5
+ ## Details
6
+ This is an optional section; is there anything specific that reviewers should be aware of?
7
+ --->
8
+
9
+ ## Testing Done
10
+ <!--- This is a required section; please describe how this change was tested. --->
11
+
12
+ <!--
13
+ Replace BLANK with your device type. For example, A100-80G-PCIe
14
+
15
+ Complete the following tasks before sending your PR, and replace `[ ]` with
16
+ `[x]` to indicate you have done them.
17
+ -->
18
+
19
+ - Hardware Type: <BLANK>
20
+ - [ ] run `make test` to ensure correctness
21
+ - [ ] run `make checkstyle` to ensure code style
22
+ - [ ] run `make test-convergence` to ensure convergence
@@ -0,0 +1,71 @@
1
+ name: AMD GPU
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ paths:
8
+ - "src/**"
9
+ - "test/**"
10
+ pull_request:
11
+ branches:
12
+ - main
13
+ paths:
14
+ - "src/**"
15
+ - "test/**"
16
+ schedule:
17
+ # Runs at 00:00 UTC daily
18
+ - cron: '0 0 * * *'
19
+ workflow_dispatch: # Enables manual trigger
20
+
21
+ concurrency:
22
+ # Cancel any previous in-progress run of this workflow on the same PR / branch.
23
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
24
+ cancel-in-progress: true
25
+
26
+ jobs:
27
+ checkstyle:
28
+ runs-on: ubuntu-latest
29
+
30
+ steps:
31
+ - name: Checkout code
32
+ uses: actions/checkout@v3
33
+
34
+ - name: Set up Python
35
+ uses: actions/setup-python@v3
36
+ with:
37
+ python-version: '3.10'
38
+
39
+ - name: Install dependencies
40
+ run: |
41
+ python -m pip install --upgrade pip
42
+ pip install --no-deps .[fmt]
43
+
44
+ - name: Run checkstyle
45
+ run: make checkstyle
46
+
47
+ tests:
48
+ runs-on: linux-mi300-gpu-1
49
+ needs: [checkstyle]
50
+
51
+ steps:
52
+ - name: Checkout code
53
+ uses: actions/checkout@v3
54
+
55
+ - name: Set up Python
56
+ uses: actions/setup-python@v3
57
+ with:
58
+ python-version: '3.10'
59
+
60
+ - name: Setup Dependencies
61
+ run: |
62
+ python -m pip install --upgrade pip
63
+ pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
64
+
65
+ - name: List Python Environments
66
+ run: python -m pip list
67
+
68
+ - name: Run Unit Tests
69
+ run: |
70
+ make test
71
+ make test-convergence
@@ -0,0 +1,96 @@
1
+ name: NVIDIA GPU
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ paths:
8
+ - "src/**"
9
+ - "test/**"
10
+ pull_request:
11
+ branches:
12
+ - main
13
+ paths:
14
+ - "src/**"
15
+ - "test/**"
16
+ schedule:
17
+ # Runs at 00:00 UTC daily
18
+ - cron: '0 0 * * *'
19
+ workflow_dispatch: # Enables manual trigger
20
+
21
+ concurrency:
22
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
23
+ cancel-in-progress: true
24
+
25
+ jobs:
26
+ checkstyle:
27
+ runs-on: ubuntu-latest
28
+
29
+ steps:
30
+ - name: Checkout code
31
+ uses: actions/checkout@v3
32
+
33
+ - name: Set up Python
34
+ uses: actions/setup-python@v3
35
+ with:
36
+ python-version: '3.10'
37
+
38
+ - name: Install dependencies
39
+ run: |
40
+ python -m pip install --upgrade pip
41
+ pip install --no-deps .[fmt]
42
+
43
+ - name: Run checkstyle
44
+ run: make checkstyle
45
+
46
+ tests:
47
+ runs-on: ubuntu-latest
48
+ needs: [checkstyle]
49
+ env:
50
+ MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
51
+ MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
52
+ REBUILD_IMAGE: ${{ github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' }}
53
+
54
+ steps:
55
+ - name: Checkout code
56
+ uses: actions/checkout@v3
57
+
58
+ - name: Set up Python
59
+ uses: actions/setup-python@v3
60
+ with:
61
+ python-version: '3.10'
62
+
63
+ - name: Install dependencies
64
+ run: |
65
+ python -m pip install --upgrade pip
66
+ pip install modal
67
+
68
+ - name: Run tests
69
+ run: |
70
+ modal run dev.modal.tests
71
+
72
+ tests-bwd:
73
+ runs-on: ubuntu-latest
74
+ needs: [checkstyle]
75
+ env:
76
+ MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
77
+ MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
78
+ REBUILD_IMAGE: ${{ github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' }}
79
+
80
+ steps:
81
+ - name: Checkout code
82
+ uses: actions/checkout@v3
83
+
84
+ - name: Set up Python
85
+ uses: actions/setup-python@v3
86
+ with:
87
+ python-version: '3.10'
88
+
89
+ - name: Install dependencies
90
+ run: |
91
+ python -m pip install --upgrade pip
92
+ pip install modal
93
+
94
+ - name: Run tests
95
+ run: |
96
+ modal run dev.modal.tests_bwd
@@ -0,0 +1,49 @@
1
+ name: Publish Liger Kernel Nightly
2
+
3
+ # Though it is named "nightly", we trigger this workflow on every push to the main branch for convenience.
4
+
5
+ on:
6
+ push:
7
+ branches:
8
+ - main # Trigger on push to the main branch
9
+
10
+ jobs:
11
+ build:
12
+ runs-on: ubuntu-latest
13
+
14
+ steps:
15
+ - name: Checkout repository
16
+ uses: actions/checkout@v3
17
+
18
+ - name: Set up Python
19
+ uses: actions/setup-python@v3
20
+ with:
21
+ python-version: '3.8'
22
+
23
+ - name: Install dependencies
24
+ run: |
25
+ python -m pip install --upgrade pip
26
+ pip install build twine wheel toml
27
+
28
+ - name: Update package name and version
29
+ run: |
30
+ VERSION=$(python -c "import toml; print(toml.load('pyproject.toml')['project']['version'])")
31
+ DATE=$(date +%Y%m%d%H%M%S)
32
+ NEW_VERSION="$VERSION.dev$DATE"
33
+ sed -i "s/name = \"liger_kernel\"/name = \"liger_kernel_nightly\"/" pyproject.toml
34
+ sed -i "s/version = \"$VERSION\"/version = \"$NEW_VERSION\"/" pyproject.toml
35
+
36
+ - name: Build package
37
+ run: |
38
+ python -m build
39
+
40
+ - name: Publish package to PyPI
41
+ env:
42
+ TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
43
+ TWINE_PASSWORD: ${{ secrets.PYPI_NIGHTLY_PASSWORD }}
44
+ run: |
45
+ twine upload dist/*
46
+
47
+ - name: Create release notes
48
+ run: |
49
+ echo "Nightly build published to PyPI with the name 'liger-kernel-nightly'."
@@ -0,0 +1,38 @@
1
+ name: Publish Liger Kernel on Release
2
+
3
+ on:
4
+ release:
5
+ types: [published]
6
+
7
+ jobs:
8
+ build:
9
+ runs-on: ubuntu-latest
10
+
11
+ steps:
12
+ - name: Checkout repository
13
+ uses: actions/checkout@v3
14
+
15
+ - name: Set up Python
16
+ uses: actions/setup-python@v3
17
+ with:
18
+ python-version: '3.10'
19
+
20
+ - name: Install dependencies
21
+ run: |
22
+ python -m pip install --upgrade pip
23
+ pip install build twine wheel toml
24
+
25
+ - name: Build package
26
+ run: |
27
+ python -m build
28
+
29
+ - name: Publish package to PyPI
30
+ env:
31
+ TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
32
+ TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
33
+ run: |
34
+ twine upload dist/*
35
+
36
+ - name: Create release notes
37
+ run: |
38
+ echo "Release published to PyPI with the name 'liger-kernel'."
@@ -0,0 +1,20 @@
1
+ __pycache__/
2
+ *.egg-info/
3
+ site/
4
+ .cache/
5
+ .venv/
6
+ venv/
7
+ .ipynb_checkpoints/
8
+
9
+ # Misc
10
+ .DS_Store
11
+
12
+ # Build
13
+ build/
14
+ dist/
15
+
16
+ # Lockfiles
17
+ uv.lock
18
+
19
+ # Benchmark images
20
+ benchmark/visualizations
@@ -0,0 +1,2 @@
1
+ [settings]
2
+ profile = black
@@ -0,0 +1,42 @@
1
+ .PHONY: test checkstyle test-convergence all
2
+
3
+
4
+ all: checkstyle test test-convergence
5
+
6
+ # Command to run pytest for correctness tests
7
+ test:
8
+ python -m pytest --disable-warnings test/ --ignore=test/convergence
9
+
10
+ # Command to run flake8 (code style check), isort (import ordering), and black (code formatting)
11
+ # Subsequent commands still run if the previous fails, but return failure at the end
12
+ checkstyle:
13
+ flake8 .; flake8_status=$$?; \
14
+ isort .; isort_status=$$?; \
15
+ black .; black_status=$$?; \
16
+ if [ $$flake8_status -ne 0 ] || [ $$isort_status -ne 0 ] || [ $$black_status -ne 0 ]; then \
17
+ exit 1; \
18
+ fi
19
+
20
+ # Command to run pytest for convergence tests
21
+ # We have to explicitly set HF_DATASETS_OFFLINE=1, or datasets will silently try to send metrics and time out (80s) https://github.com/huggingface/datasets/blob/37a603679f451826cfafd8aae00738b01dcb9d58/src/datasets/load.py#L286
22
+ test-convergence:
23
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
24
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_multimodal.py
25
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_with_logits.py
26
+
27
+ # Command to run all benchmark scripts and update benchmarking data file
28
+ # By default this doesn't overwrite existing data for the same benchmark experiment
29
+ # run with `make run-benchmarks OVERWRITE=1` to overwrite existing benchmark data
30
+ BENCHMARK_DIR = benchmark/scripts
31
+ BENCHMARK_SCRIPTS = $(wildcard $(BENCHMARK_DIR)/benchmark_*.py)
32
+ OVERWRITE ?= 0
33
+
34
+ run-benchmarks:
35
+ @for script in $(BENCHMARK_SCRIPTS); do \
36
+ echo "Running benchmark: $$script"; \
37
+ if [ $(OVERWRITE) -eq 1 ]; then \
38
+ python $$script --overwrite; \
39
+ else \
40
+ python $$script; \
41
+ fi; \
42
+ done
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241211231633
3
+ Version: 0.5.2.dev20241212000548
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -32,10 +32,6 @@ License-File: LICENSE
32
32
  License-File: NOTICE
33
33
  Requires-Dist: torch>=2.1.2
34
34
  Requires-Dist: triton>=2.3.1
35
- Provides-Extra: transformers
36
- Requires-Dist: transformers~=4.0; extra == "transformers"
37
- Provides-Extra: trl
38
- Requires-Dist: trl>=0.11.0; extra == "trl"
39
35
  Provides-Extra: dev
40
36
  Requires-Dist: transformers>=4.44.2; extra == "dev"
41
37
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
@@ -46,13 +42,11 @@ Requires-Dist: pytest>=7.1.2; extra == "dev"
46
42
  Requires-Dist: pytest-xdist; extra == "dev"
47
43
  Requires-Dist: pytest-rerunfailures; extra == "dev"
48
44
  Requires-Dist: datasets>=2.19.2; extra == "dev"
49
- Requires-Dist: torchvision>=0.16.2; extra == "dev"
50
45
  Requires-Dist: seaborn; extra == "dev"
51
- Provides-Extra: amd
52
- Requires-Dist: torch>=2.6.0.dev; extra == "amd"
53
- Requires-Dist: setuptools-scm>=8; extra == "amd"
54
- Requires-Dist: torchvision>=0.20.0.dev; extra == "amd"
55
- Requires-Dist: triton>=3.0.0; extra == "amd"
46
+ Provides-Extra: fmt
47
+ Requires-Dist: flake8; extra == "fmt"
48
+ Requires-Dist: isort; extra == "fmt"
49
+ Requires-Dist: black; extra == "fmt"
56
50
 
57
51
  <a name="readme-top"></a>
58
52
 
@@ -202,11 +196,13 @@ To install from source:
202
196
  ```bash
203
197
  git clone https://github.com/linkedin/Liger-Kernel.git
204
198
  cd Liger-Kernel
199
+
200
+ # Install Default Dependencies
201
+ # Setup.py will detect whether you are using AMD or NVIDIA
205
202
  pip install -e .
206
- # or if installing on amd platform
207
- pip install -e .[amd] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2 # rocm6.2
208
- # or if using transformers
209
- pip install -e .[transformers]
203
+
204
+ # Setup Development Dependencies
205
+ pip install -e ".[dev]"
210
206
  ```
211
207
 
212
208
 
@@ -146,11 +146,13 @@ To install from source:
146
146
  ```bash
147
147
  git clone https://github.com/linkedin/Liger-Kernel.git
148
148
  cd Liger-Kernel
149
+
150
+ # Install Default Dependencies
151
+ # Setup.py will detect whether you are using AMD or NVIDIA
149
152
  pip install -e .
150
- # or if installing on amd platform
151
- pip install -e .[amd] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2 # rocm6.2
152
- # or if using transformers
153
- pip install -e .[transformers]
153
+
154
+ # Setup Development Dependencies
155
+ pip install -e ".[dev]"
154
156
  ```
155
157
 
156
158
 
@@ -0,0 +1,169 @@
1
+ import json
2
+ import os
3
+ from argparse import ArgumentParser
4
+ from dataclasses import dataclass
5
+
6
+ import matplotlib.pyplot as plt
7
+ import pandas as pd
8
+ import seaborn as sns
9
+
10
+ DATA_PATH = "data/all_benchmark_data.csv"
11
+ VISUALIZATIONS_PATH = "visualizations/"
12
+
13
+
14
@dataclass
class VisualizationsConfig:
    """
    Configuration for the visualizations script.

    Args:
        kernel_name (str): Kernel whose benchmark rows are plotted; compared against
            the `kernel_name` column of the benchmark CSV.
        metric_name (str): Metric name to visualize (speed/memory)
        kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full"
        display (bool): Display the visualization. Defaults to False
        overwrite (bool): Overwrite an existing visualization file. When no file exists
            yet this flag has no effect, because the plot is always saved in that case.
            Defaults to False

    """

    kernel_name: str
    metric_name: str
    kernel_operation_mode: str = "full"
    display: bool = False
    overwrite: bool = False
33
+
34
+
35
def parse_args() -> VisualizationsConfig:
    """Parse command line arguments into a configuration object.

    `--kernel-name` and `--metric-name` are mandatory. `--kernel-operation-mode`
    is optional and falls back to "full", mirroring the default declared on
    `VisualizationsConfig`. `--display` and `--overwrite` are boolean switches
    that are off unless given.

    Returns:
        VisualizationsConfig: Configuration object for the visualizations script.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--kernel-name", type=str, required=True, help="Kernel name to benchmark"
    )
    parser.add_argument(
        "--metric-name",
        type=str,
        required=True,
        help="Metric name to visualize (speed/memory)",
    )
    parser.add_argument(
        "--kernel-operation-mode",
        type=str,
        # Optional on purpose: the configuration object documents "full" as
        # the default, so the CLI must not force callers to pass it.
        default="full",
        help='Kernel operation mode to visualize (forward/backward/full). Defaults to "full"',
    )
    parser.add_argument(
        "--display", action="store_true", help="Display the visualization"
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing visualization, if none exist this flag has no effect as one are always created",
    )

    args = parser.parse_args()

    # vars() is the public way to view a Namespace as a dict; the argparse
    # destination names line up one-to-one with the dataclass field names.
    return VisualizationsConfig(**vars(args))
69
+
70
+
71
def load_data(config: VisualizationsConfig) -> pd.DataFrame:
    """Read the benchmark CSV and keep only the rows selected by `config`.

    The file is read from `DATA_PATH`, a relative path that is therefore resolved
    against the current working directory. Before filtering, an
    `extra_benchmark_config` column is added by JSON-decoding
    `extra_benchmark_config_str` for every row.

    Args:
        config (VisualizationsConfig): Supplies the kernel name, metric name and
            kernel operation mode that a row must match exactly to be kept.

    Raises:
        ValueError: If no row matches all three filters.

    Returns:
        pd.DataFrame: The matching rows of the benchmark dataframe.
    """
    df = pd.read_csv(DATA_PATH)
    df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads)

    matches_kernel = df["kernel_name"] == config.kernel_name
    matches_metric = df["metric_name"] == config.metric_name
    matches_mode = df["kernel_operation_mode"] == config.kernel_operation_mode
    # To narrow further by an extra benchmark configuration property, AND in e.g.
    #   df["extra_benchmark_config"].apply(lambda x: x.get("H") == 4096)
    # FIXME: maybe add a way to filter using some configuration instead of hardcoding it
    selected = df[matches_kernel & matches_metric & matches_mode]

    if selected.empty:
        raise ValueError("No data found for the given filters")

    return selected
99
+
100
+
101
def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
    """Plots the benchmark data, saving the result if needed.

    Draws one line per `kernel_provider` through the `y_value_50` column and
    overlays error bars running from `y_value_20` to `y_value_80`. The figure is
    written under `VISUALIZATIONS_PATH` when the target file does not exist yet
    or when `config.overwrite` is set, and is shown when `config.display` is set.

    Args:
        df (pd.DataFrame): Filtered benchmark dataframe.
        config (VisualizationsConfig): Configuration object for the visualizations script.
    """
    xlabel = df["x_label"].iloc[0]
    ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})"
    # Sort by "kernel_provider" to ensure consistent color assignment
    df = df.sort_values(by="kernel_provider")

    plt.figure(figsize=(10, 6))
    sns.set(style="whitegrid")
    ax = sns.lineplot(
        data=df,
        x="x_value",
        y="y_value_50",
        hue="kernel_provider",
        marker="o",
        palette="tab10",
        errorbar=("ci", None),
    )

    # Seaborn can't plot pre-computed error bars, so we need to do it manually.
    # NOTE(review): this pairs ax.get_lines() with the groupby iteration order,
    # assuming both enumerate the providers in the same order - confirm for the
    # installed seaborn version.
    colors = [line.get_color() for line in ax.get_lines()]

    for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
        y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
        y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
        y_error = [y_error_lower, y_error_upper]

        plt.errorbar(
            group_data["x_value"],
            group_data["y_value_50"],
            yerr=y_error,
            fmt="o",
            color=color,
            capsize=5,
        )
    plt.legend(title="Kernel Provider")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.tight_layout()

    # NOTE(review): the file name does not contain config.kernel_operation_mode,
    # so plots of different modes for the same kernel/metric share one path.
    out_path = os.path.join(
        VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png"
    )

    # Save BEFORE showing: with a GUI backend plt.show() blocks until the window
    # is closed, and the figure may be gone afterwards, so a later savefig could
    # write an empty image.
    if config.overwrite or not os.path.exists(out_path):
        # Save the plot if it doesn't exist or if we want to overwrite it
        os.makedirs(VISUALIZATIONS_PATH, exist_ok=True)
        plt.savefig(out_path)
    if config.display:
        plt.show()
    plt.close()
160
+
161
+
162
def main():
    """Script entry point: parse the CLI options, load the matching rows, plot them."""
    config = parse_args()
    plot_data(load_data(config), config)


if __name__ == "__main__":
    main()