foreblocks 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (371)
  1. examples/rodrigo.py +351 -0
  2. flash-attention/benchmarks/benchmark_alibi.py +275 -0
  3. flash-attention/benchmarks/benchmark_causal.py +225 -0
  4. flash-attention/benchmarks/benchmark_flash_attention.py +180 -0
  5. flash-attention/benchmarks/benchmark_gemm.py +47 -0
  6. flash-attention/csrc/composable_kernel/docs/conf.py +50 -0
  7. flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/__init__.py +0 -0
  8. flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/cmake_config.py +5 -0
  9. flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +128 -0
  10. flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/__init__.py +0 -0
  11. flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +902 -0
  12. flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +574 -0
  13. flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +359 -0
  14. flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +855 -0
  15. flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/generate.py +136 -0
  16. flash-attention/csrc/composable_kernel/example/ck_tile/02_layernorm2d/generate.py +730 -0
  17. flash-attention/csrc/composable_kernel/example/ck_tile/10_rmsnorm2d/generate.py +715 -0
  18. flash-attention/csrc/composable_kernel/example/ck_tile/remod.py +21 -0
  19. flash-attention/csrc/composable_kernel/include/ck_tile/remod.py +93 -0
  20. flash-attention/csrc/composable_kernel/python/ck4inductor/__init__.py +0 -0
  21. flash-attention/csrc/composable_kernel/python/ck4inductor/batched_universal_gemm/gen_instances.py +149 -0
  22. flash-attention/csrc/composable_kernel/python/ck4inductor/batched_universal_gemm/op.py +99 -0
  23. flash-attention/csrc/composable_kernel/python/ck4inductor/grouped_conv_fwd/gen_instances.py +165 -0
  24. flash-attention/csrc/composable_kernel/python/ck4inductor/grouped_conv_fwd/op.py +93 -0
  25. flash-attention/csrc/composable_kernel/python/ck4inductor/universal_gemm/gen_instances.py +572 -0
  26. flash-attention/csrc/composable_kernel/python/ck4inductor/universal_gemm/op.py +99 -0
  27. flash-attention/csrc/composable_kernel/python/ck4inductor/util.py +10 -0
  28. flash-attention/csrc/composable_kernel/python/test/test_gen_instances.py +46 -0
  29. flash-attention/csrc/composable_kernel/script/convert_miopen_driver_to_profiler.py +413 -0
  30. flash-attention/csrc/composable_kernel/script/process_perf_data.py +382 -0
  31. flash-attention/csrc/composable_kernel/tile_engine/ops/gemm/gemm_instance_builder.py +654 -0
  32. flash-attention/csrc/cutlass/examples/40_cutlass_py/conv2d.py +177 -0
  33. flash-attention/csrc/cutlass/examples/40_cutlass_py/customizable/conv2d.py +331 -0
  34. flash-attention/csrc/cutlass/examples/40_cutlass_py/customizable/gemm.py +331 -0
  35. flash-attention/csrc/cutlass/examples/40_cutlass_py/customizable/gemm_grouped.py +298 -0
  36. flash-attention/csrc/cutlass/examples/40_cutlass_py/gemm.py +153 -0
  37. flash-attention/csrc/cutlass/examples/40_cutlass_py/gemm_grouped.py +172 -0
  38. flash-attention/csrc/cutlass/examples/41_fused_multi_head_attention/fmha_backward_test.py +232 -0
  39. flash-attention/csrc/cutlass/examples/41_fused_multi_head_attention/piped_subprocess.py +144 -0
  40. flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py +129 -0
  41. flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py +131 -0
  42. flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py +120 -0
  43. flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py +469 -0
  44. flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py +249 -0
  45. flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py +476 -0
  46. flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py +232 -0
  47. flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py +1013 -0
  48. flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py +456 -0
  49. flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py +92 -0
  50. flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py +135 -0
  51. flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py +67 -0
  52. flash-attention/csrc/cutlass/python/cutlass/__init__.py +190 -0
  53. flash-attention/csrc/cutlass/python/cutlass/backend/__init__.py +48 -0
  54. flash-attention/csrc/cutlass/python/cutlass/backend/arguments.py +133 -0
  55. flash-attention/csrc/cutlass/python/cutlass/backend/c_types.py +622 -0
  56. flash-attention/csrc/cutlass/python/cutlass/backend/compiler.py +459 -0
  57. flash-attention/csrc/cutlass/python/cutlass/backend/conv2d_operation.py +698 -0
  58. flash-attention/csrc/cutlass/python/cutlass/backend/epilogue.py +541 -0
  59. flash-attention/csrc/cutlass/python/cutlass/backend/evt/__init__.py +34 -0
  60. flash-attention/csrc/cutlass/python/cutlass/backend/evt/backend/__init__.py +36 -0
  61. flash-attention/csrc/cutlass/python/cutlass/backend/evt/backend/emitter_base.py +158 -0
  62. flash-attention/csrc/cutlass/python/cutlass/backend/evt/backend/sm80_emitter.py +47 -0
  63. flash-attention/csrc/cutlass/python/cutlass/backend/evt/backend/sm80_nodes.py +258 -0
  64. flash-attention/csrc/cutlass/python/cutlass/backend/evt/backend/sm90_emitter.py +98 -0
  65. flash-attention/csrc/cutlass/python/cutlass/backend/evt/backend/sm90_nodes.py +329 -0
  66. flash-attention/csrc/cutlass/python/cutlass/backend/evt/epilogue.py +167 -0
  67. flash-attention/csrc/cutlass/python/cutlass/backend/evt/frontend/__init__.py +33 -0
  68. flash-attention/csrc/cutlass/python/cutlass/backend/evt/frontend/frontend_base.py +262 -0
  69. flash-attention/csrc/cutlass/python/cutlass/backend/evt/frontend/python_ast.py +187 -0
  70. flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/__init__.py +53 -0
  71. flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/compute_nodes.py +91 -0
  72. flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/dag_ir.py +236 -0
  73. flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/layout_algorithm.py +324 -0
  74. flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/layout_nodes.py +336 -0
  75. flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/load_nodes.py +294 -0
  76. flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/node.py +293 -0
  77. flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/store_nodes.py +277 -0
  78. flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/tensor.py +130 -0
  79. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/__init__.py +42 -0
  80. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/graph_drawer.py +142 -0
  81. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_argument_type.py +116 -0
  82. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_dag_2_tree.py +147 -0
  83. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_fix_element_d.py +64 -0
  84. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_get_impl.py +90 -0
  85. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_layout_elimination.py +217 -0
  86. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_manager.py +164 -0
  87. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_no_op_elimination.py +53 -0
  88. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_preprocess_red.py +97 -0
  89. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_shape_type_propagation.py +59 -0
  90. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/smem_size_calculator.py +204 -0
  91. flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/util.py +43 -0
  92. flash-attention/csrc/cutlass/python/cutlass/backend/frontend.py +107 -0
  93. flash-attention/csrc/cutlass/python/cutlass/backend/gemm_operation.py +2138 -0
  94. flash-attention/csrc/cutlass/python/cutlass/backend/library.py +488 -0
  95. flash-attention/csrc/cutlass/python/cutlass/backend/memory_manager.py +120 -0
  96. flash-attention/csrc/cutlass/python/cutlass/backend/operation.py +133 -0
  97. flash-attention/csrc/cutlass/python/cutlass/backend/reduction_operation.py +452 -0
  98. flash-attention/csrc/cutlass/python/cutlass/backend/type_hint.py +35 -0
  99. flash-attention/csrc/cutlass/python/cutlass/backend/utils/__init__.py +33 -0
  100. flash-attention/csrc/cutlass/python/cutlass/backend/utils/device.py +123 -0
  101. flash-attention/csrc/cutlass/python/cutlass/emit/__init__.py +33 -0
  102. flash-attention/csrc/cutlass/python/cutlass/emit/common.py +267 -0
  103. flash-attention/csrc/cutlass/python/cutlass/emit/pytorch.py +936 -0
  104. flash-attention/csrc/cutlass/python/cutlass/epilogue/__init__.py +55 -0
  105. flash-attention/csrc/cutlass/python/cutlass/epilogue/epilogue.py +158 -0
  106. flash-attention/csrc/cutlass/python/cutlass/epilogue/evt_ops.py +92 -0
  107. flash-attention/csrc/cutlass/python/cutlass/library_defaults.py +580 -0
  108. flash-attention/csrc/cutlass/python/cutlass/op/__init__.py +36 -0
  109. flash-attention/csrc/cutlass/python/cutlass/op/conv.py +983 -0
  110. flash-attention/csrc/cutlass/python/cutlass/op/gemm.py +715 -0
  111. flash-attention/csrc/cutlass/python/cutlass/op/gemm_grouped.py +264 -0
  112. flash-attention/csrc/cutlass/python/cutlass/op/op.py +430 -0
  113. flash-attention/csrc/cutlass/python/cutlass/shape.py +184 -0
  114. flash-attention/csrc/cutlass/python/cutlass/swizzle.py +65 -0
  115. flash-attention/csrc/cutlass/python/cutlass/utils/__init__.py +41 -0
  116. flash-attention/csrc/cutlass/python/cutlass/utils/check.py +269 -0
  117. flash-attention/csrc/cutlass/python/cutlass/utils/datatypes.py +362 -0
  118. flash-attention/csrc/cutlass/python/cutlass/utils/profiler.py +185 -0
  119. flash-attention/csrc/cutlass/python/cutlass_library/__init__.py +63 -0
  120. flash-attention/csrc/cutlass/python/cutlass_library/conv2d_operation.py +621 -0
  121. flash-attention/csrc/cutlass/python/cutlass_library/conv3d_operation.py +482 -0
  122. flash-attention/csrc/cutlass/python/cutlass_library/conv3x_emitter.py +250 -0
  123. flash-attention/csrc/cutlass/python/cutlass_library/emit_kernel_listing.py +880 -0
  124. flash-attention/csrc/cutlass/python/cutlass_library/gemm_operation.py +1520 -0
  125. flash-attention/csrc/cutlass/python/cutlass_library/generator.py +10851 -0
  126. flash-attention/csrc/cutlass/python/cutlass_library/library.py +1317 -0
  127. flash-attention/csrc/cutlass/python/cutlass_library/manifest.py +870 -0
  128. flash-attention/csrc/cutlass/python/cutlass_library/rank_2k_operation.py +438 -0
  129. flash-attention/csrc/cutlass/python/cutlass_library/rank_k_operation.py +427 -0
  130. flash-attention/csrc/cutlass/python/cutlass_library/sm90_shapes.py +212 -0
  131. flash-attention/csrc/cutlass/python/cutlass_library/sm90_utils.py +703 -0
  132. flash-attention/csrc/cutlass/python/cutlass_library/symm_operation.py +440 -0
  133. flash-attention/csrc/cutlass/python/cutlass_library/trmm_operation.py +447 -0
  134. flash-attention/csrc/cutlass/python/docs_src/source/conf.py +132 -0
  135. flash-attention/csrc/cutlass/python/pycute/__init__.py +36 -0
  136. flash-attention/csrc/cutlass/python/pycute/int_tuple.py +225 -0
  137. flash-attention/csrc/cutlass/python/pycute/layout.py +367 -0
  138. flash-attention/csrc/cutlass/python/pycute/swizzle.py +129 -0
  139. flash-attention/csrc/cutlass/python/pycute/typing.py +42 -0
  140. flash-attention/csrc/cutlass/python/setup_cutlass.py +74 -0
  141. flash-attention/csrc/cutlass/python/setup_library.py +46 -0
  142. flash-attention/csrc/cutlass/python/setup_pycute.py +46 -0
  143. flash-attention/csrc/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py +661 -0
  144. flash-attention/csrc/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py +146 -0
  145. flash-attention/csrc/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py +428 -0
  146. flash-attention/csrc/cutlass/test/python/cutlass/conv2d/run_all_tests.py +44 -0
  147. flash-attention/csrc/cutlass/test/python/cutlass/emit/pytorch.py +309 -0
  148. flash-attention/csrc/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py +122 -0
  149. flash-attention/csrc/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py +173 -0
  150. flash-attention/csrc/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py +142 -0
  151. flash-attention/csrc/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py +274 -0
  152. flash-attention/csrc/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py +155 -0
  153. flash-attention/csrc/cutlass/test/python/cutlass/evt/run_all_tests.py +44 -0
  154. flash-attention/csrc/cutlass/test/python/cutlass/evt/utils/evt_testbed.py +230 -0
  155. flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_batched.py +134 -0
  156. flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py +128 -0
  157. flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py +146 -0
  158. flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py +104 -0
  159. flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py +103 -0
  160. flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py +71 -0
  161. flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py +112 -0
  162. flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py +75 -0
  163. flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py +103 -0
  164. flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py +98 -0
  165. flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_testbed.py +423 -0
  166. flash-attention/csrc/cutlass/test/python/cutlass/gemm/run_all_tests.py +44 -0
  167. flash-attention/csrc/cutlass/test/python/cutlass/gemm/utils.py +260 -0
  168. flash-attention/csrc/cutlass/test/python/cutlass/installation.py +57 -0
  169. flash-attention/csrc/cutlass/test/python/cutlass/interface/conv2d_interface.py +284 -0
  170. flash-attention/csrc/cutlass/test/python/cutlass/interface/evt_interface.py +254 -0
  171. flash-attention/csrc/cutlass/test/python/cutlass/interface/gemm_interface.py +351 -0
  172. flash-attention/csrc/cutlass/test/python/cutlass/interface/utils.py +69 -0
  173. flash-attention/csrc/cutlass/test/python/pycute/run_all_tests.py +75 -0
  174. flash-attention/csrc/cutlass/test/python/pycute/test_coalesce.py +95 -0
  175. flash-attention/csrc/cutlass/test/python/pycute/test_complement.py +92 -0
  176. flash-attention/csrc/cutlass/test/python/pycute/test_composition.py +213 -0
  177. flash-attention/csrc/cutlass/test/python/pycute/test_int_tuple.py +80 -0
  178. flash-attention/csrc/cutlass/test/python/pycute/test_left_inverse.py +87 -0
  179. flash-attention/csrc/cutlass/test/python/pycute/test_right_inverse.py +96 -0
  180. flash-attention/csrc/cutlass/test/python/pycute/test_typing.py +59 -0
  181. flash-attention/csrc/cutlass/test/unit/gemm/device/simt_sm50.py +341 -0
  182. flash-attention/csrc/flash_attn/src/generate_kernels.py +110 -0
  183. flash-attention/csrc/ft_attention/setup.py +153 -0
  184. flash-attention/csrc/fused_dense_lib/setup.py +42 -0
  185. flash-attention/csrc/fused_softmax/setup.py +50 -0
  186. flash-attention/csrc/layer_norm/setup.py +205 -0
  187. flash-attention/csrc/rotary/setup.py +126 -0
  188. flash-attention/csrc/xentropy/setup.py +139 -0
  189. flash-attention/flash_attn/__init__.py +11 -0
  190. flash-attention/flash_attn/bert_padding.py +218 -0
  191. flash-attention/flash_attn/flash_attn_interface.py +1606 -0
  192. flash-attention/flash_attn/flash_attn_triton.py +1160 -0
  193. flash-attention/flash_attn/flash_attn_triton_amd/__init__.py +0 -0
  194. flash-attention/flash_attn/flash_attn_triton_amd/bench.py +1223 -0
  195. flash-attention/flash_attn/flash_attn_triton_amd/bwd_prefill.py +814 -0
  196. flash-attention/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py +3266 -0
  197. flash-attention/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py +1091 -0
  198. flash-attention/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py +1354 -0
  199. flash-attention/flash_attn/flash_attn_triton_amd/bwd_ref.py +478 -0
  200. flash-attention/flash_attn/flash_attn_triton_amd/fp8.py +716 -0
  201. flash-attention/flash_attn/flash_attn_triton_amd/fwd_decode.py +814 -0
  202. flash-attention/flash_attn/flash_attn_triton_amd/fwd_prefill.py +648 -0
  203. flash-attention/flash_attn/flash_attn_triton_amd/fwd_ref.py +387 -0
  204. flash-attention/flash_attn/flash_attn_triton_amd/interface_fa.py +798 -0
  205. flash-attention/flash_attn/flash_attn_triton_amd/test.py +932 -0
  206. flash-attention/flash_attn/flash_attn_triton_amd/train.py +403 -0
  207. flash-attention/flash_attn/flash_attn_triton_amd/utils.py +776 -0
  208. flash-attention/flash_attn/flash_attn_triton_og.py +365 -0
  209. flash-attention/flash_attn/flash_blocksparse_attention.py +197 -0
  210. flash-attention/flash_attn/flash_blocksparse_attn_interface.py +200 -0
  211. flash-attention/flash_attn/fused_softmax.py +201 -0
  212. flash-attention/flash_attn/layers/__init__.py +0 -0
  213. flash-attention/flash_attn/layers/patch_embed.py +67 -0
  214. flash-attention/flash_attn/layers/rotary.py +482 -0
  215. flash-attention/flash_attn/losses/__init__.py +0 -0
  216. flash-attention/flash_attn/losses/cross_entropy.py +85 -0
  217. flash-attention/flash_attn/models/__init__.py +0 -0
  218. flash-attention/flash_attn/models/baichuan.py +151 -0
  219. flash-attention/flash_attn/models/bert.py +764 -0
  220. flash-attention/flash_attn/models/bigcode.py +233 -0
  221. flash-attention/flash_attn/models/btlm.py +102 -0
  222. flash-attention/flash_attn/models/falcon.py +143 -0
  223. flash-attention/flash_attn/models/gpt.py +1080 -0
  224. flash-attention/flash_attn/models/gpt_neox.py +124 -0
  225. flash-attention/flash_attn/models/gptj.py +109 -0
  226. flash-attention/flash_attn/models/llama.py +422 -0
  227. flash-attention/flash_attn/models/opt.py +116 -0
  228. flash-attention/flash_attn/models/vit.py +373 -0
  229. flash-attention/flash_attn/modules/__init__.py +0 -0
  230. flash-attention/flash_attn/modules/block.py +397 -0
  231. flash-attention/flash_attn/modules/embedding.py +216 -0
  232. flash-attention/flash_attn/modules/mha.py +993 -0
  233. flash-attention/flash_attn/modules/mlp.py +191 -0
  234. flash-attention/flash_attn/ops/__init__.py +0 -0
  235. flash-attention/flash_attn/ops/activations.py +135 -0
  236. flash-attention/flash_attn/ops/fused_dense.py +688 -0
  237. flash-attention/flash_attn/ops/layer_norm.py +800 -0
  238. flash-attention/flash_attn/ops/rms_norm.py +174 -0
  239. flash-attention/flash_attn/ops/triton/__init__.py +1 -0
  240. flash-attention/flash_attn/ops/triton/cross_entropy.py +330 -0
  241. flash-attention/flash_attn/ops/triton/k_activations.py +162 -0
  242. flash-attention/flash_attn/ops/triton/layer_norm.py +1252 -0
  243. flash-attention/flash_attn/ops/triton/linear.py +594 -0
  244. flash-attention/flash_attn/ops/triton/mlp.py +149 -0
  245. flash-attention/flash_attn/ops/triton/rotary.py +185 -0
  246. flash-attention/flash_attn/utils/__init__.py +0 -0
  247. flash-attention/flash_attn/utils/benchmark.py +268 -0
  248. flash-attention/flash_attn/utils/distributed.py +144 -0
  249. flash-attention/flash_attn/utils/generation.py +740 -0
  250. flash-attention/flash_attn/utils/library.py +66 -0
  251. flash-attention/flash_attn/utils/pretrained.py +79 -0
  252. flash-attention/flash_attn/utils/torch.py +21 -0
  253. flash-attention/hopper/__init__.py +1 -0
  254. flash-attention/hopper/benchmark_attn.py +411 -0
  255. flash-attention/hopper/benchmark_flash_attention_fp8.py +353 -0
  256. flash-attention/hopper/benchmark_mla_decode.py +129 -0
  257. flash-attention/hopper/benchmark_split_kv.py +331 -0
  258. flash-attention/hopper/flash_attn_interface.py +834 -0
  259. flash-attention/hopper/generate_kernels.py +222 -0
  260. flash-attention/hopper/padding.py +53 -0
  261. flash-attention/hopper/setup.py +659 -0
  262. flash-attention/hopper/test_attn_kvcache.py +486 -0
  263. flash-attention/hopper/test_flash_attn.py +1130 -0
  264. flash-attention/hopper/test_kvcache.py +234 -0
  265. flash-attention/hopper/test_util.py +348 -0
  266. flash-attention/setup.py +561 -0
  267. flash-attention/tests/layers/test_rotary.py +134 -0
  268. flash-attention/tests/losses/test_cross_entropy.py +83 -0
  269. flash-attention/tests/losses/test_cross_entropy_parallel.py +104 -0
  270. flash-attention/tests/models/test_baichuan.py +460 -0
  271. flash-attention/tests/models/test_bert.py +324 -0
  272. flash-attention/tests/models/test_bigcode.py +204 -0
  273. flash-attention/tests/models/test_btlm.py +245 -0
  274. flash-attention/tests/models/test_falcon.py +408 -0
  275. flash-attention/tests/models/test_gpt.py +478 -0
  276. flash-attention/tests/models/test_gpt_generation_parallel.py +172 -0
  277. flash-attention/tests/models/test_gpt_neox.py +104 -0
  278. flash-attention/tests/models/test_gpt_parallel.py +236 -0
  279. flash-attention/tests/models/test_gptj.py +184 -0
  280. flash-attention/tests/models/test_llama.py +633 -0
  281. flash-attention/tests/models/test_opt.py +237 -0
  282. flash-attention/tests/models/test_vit.py +48 -0
  283. flash-attention/tests/modules/test_block_parallel.py +273 -0
  284. flash-attention/tests/modules/test_embedding_parallel.py +106 -0
  285. flash-attention/tests/modules/test_mha_parallel.py +160 -0
  286. flash-attention/tests/modules/test_mlp_parallel.py +143 -0
  287. flash-attention/tests/ops/test_dropout_layer_norm.py +1189 -0
  288. flash-attention/tests/ops/test_fused_dense.py +172 -0
  289. flash-attention/tests/ops/test_fused_dense_parallel.py +237 -0
  290. flash-attention/tests/ops/triton/test_layer_norm.py +374 -0
  291. flash-attention/tests/test_flash_attn.py +2525 -0
  292. flash-attention/tests/test_flash_attn_ck.py +1618 -0
  293. flash-attention/tests/test_flash_attn_triton_amd.py +2547 -0
  294. flash-attention/tests/test_rotary.py +321 -0
  295. flash-attention/tests/test_util.py +274 -0
  296. flash-attention/training/run.py +68 -0
  297. flash-attention/training/src/callbacks/__init__.py +0 -0
  298. flash-attention/training/src/callbacks/causality_monitor.py +61 -0
  299. flash-attention/training/src/callbacks/ema.py +82 -0
  300. flash-attention/training/src/callbacks/flop_count.py +43 -0
  301. flash-attention/training/src/callbacks/gpu_affinity.py +40 -0
  302. flash-attention/training/src/callbacks/loss_scale_monitor.py +32 -0
  303. flash-attention/training/src/callbacks/model_checkpoint.py +36 -0
  304. flash-attention/training/src/callbacks/norm_monitor.py +79 -0
  305. flash-attention/training/src/callbacks/params_log.py +34 -0
  306. flash-attention/training/src/callbacks/speed_monitor.py +95 -0
  307. flash-attention/training/src/callbacks/wandb_callbacks.py +289 -0
  308. flash-attention/training/src/datamodules/datasets/detokenizer.py +53 -0
  309. flash-attention/training/src/datamodules/datasets/lm_dataset.py +32 -0
  310. flash-attention/training/src/datamodules/fault_tolerant_sampler.py +123 -0
  311. flash-attention/training/src/datamodules/imagenet.py +283 -0
  312. flash-attention/training/src/datamodules/language_modeling_hf.py +299 -0
  313. flash-attention/training/src/datamodules/timm_mixup.py +20 -0
  314. flash-attention/training/src/distributed/ddp_comm_hooks.py +43 -0
  315. flash-attention/training/src/eval.py +129 -0
  316. flash-attention/training/src/metrics/accuracy.py +11 -0
  317. flash-attention/training/src/metrics/num_tokens.py +45 -0
  318. flash-attention/training/src/metrics/perplexity.py +70 -0
  319. flash-attention/training/src/models/modules/seq_common.py +342 -0
  320. flash-attention/training/src/optim/param_grouping.py +114 -0
  321. flash-attention/training/src/optim/timm_lr_scheduler.py +30 -0
  322. flash-attention/training/src/tasks/seq.py +192 -0
  323. flash-attention/training/src/train.py +136 -0
  324. flash-attention/training/src/utils/checkpoint.py +76 -0
  325. flash-attention/training/src/utils/ddp_zero1.py +106 -0
  326. flash-attention/training/src/utils/ddp_zero2.py +146 -0
  327. flash-attention/training/src/utils/distributed.py +111 -0
  328. flash-attention/training/src/utils/ema.py +280 -0
  329. flash-attention/training/src/utils/flops.py +45 -0
  330. flash-attention/training/src/utils/gpu_affinity.py +142 -0
  331. flash-attention/training/src/utils/utils.py +146 -0
  332. flash-attention/training/tests/datamodules/test_language_modeling_hf.py +218 -0
  333. foreblocks/__init__.py +42 -0
  334. foreblocks/att.py +299 -0
  335. foreblocks/aux.py +45 -0
  336. foreblocks/blocks/__init__.py +64 -0
  337. foreblocks/blocks/attention.py +287 -0
  338. foreblocks/blocks/famous.py +529 -0
  339. foreblocks/blocks/fourier.py +464 -0
  340. foreblocks/blocks/graph.py +1466 -0
  341. foreblocks/blocks/mamba.py +119 -0
  342. foreblocks/blocks/multiscale.py +124 -0
  343. foreblocks/blocks/nha.py +476 -0
  344. foreblocks/blocks/ode.py +184 -0
  345. foreblocks/blocks/simple.py +204 -0
  346. foreblocks/blocks/wavelets.py +439 -0
  347. foreblocks/blocks.py +0 -0
  348. foreblocks/core.py +698 -0
  349. foreblocks/darts/darts.py +557 -0
  350. foreblocks/darts/darts_run.py +2386 -0
  351. foreblocks/enc_dec.py +218 -0
  352. foreblocks/pipeline.py +389 -0
  353. foreblocks/pre/__init__.py +0 -0
  354. foreblocks/pre/ewt.py +104 -0
  355. foreblocks/pre/filters.py +147 -0
  356. foreblocks/pre/impute.py +428 -0
  357. foreblocks/pre/outlier.py +399 -0
  358. foreblocks/preprocessing.py +978 -0
  359. foreblocks/tf/embeddings.py +395 -0
  360. foreblocks/tf/fed.py +322 -0
  361. foreblocks/tf/transformer.py +690 -0
  362. foreblocks/tf/transformer_att.py +535 -0
  363. foreblocks/tf/transformer_aux.py +230 -0
  364. foreblocks/tf/transformer_moe.py +1137 -0
  365. foreblocks/third_party/flash_softpick_attn.py +796 -0
  366. foreblocks/third_party/vsgd.py +212 -0
  367. foreblocks/utils.py +576 -0
  368. foreblocks-0.1.0.dist-info/METADATA +484 -0
  369. foreblocks-0.1.0.dist-info/RECORD +371 -0
  370. foreblocks-0.1.0.dist-info/WHEEL +5 -0
  371. foreblocks-0.1.0.dist-info/top_level.txt +3 -0
examples/rodrigo.py ADDED
@@ -0,0 +1,351 @@
1
+ import os
2
+ import pickle
3
+ import sys
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
# Make the current working directory importable so the `foreblocks` imports that
# follow can be resolved from it (assumes the script/notebook is launched from the
# repository root — confirm).
# NOTE: despite its name, `parent_dir` here is the working directory itself; the
# original joined "." onto it, which is a no-op after normalisation. The names are
# kept because a later cell rebinds both of them.
notebook_dir = os.getcwd()
parent_dir = os.path.abspath(notebook_dir)
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
17
+
18
+ from foreblocks.att import AttentionLayer
19
+ from foreblocks.blocks import GRU
20
+ from foreblocks.blocks.fourier import FNO1DLayer, FourierFeatures
21
+ from foreblocks.blocks.graph import LatentGraphNetwork
22
+ from foreblocks.tf.embeddings import LearnablePositionalEncoding
23
+ from foreblocks.tf.transformer import TransformerDecoder, TransformerEncoder
24
+ from torch.jit import script
25
+ from torch.utils.data import DataLoader, TensorDataset
26
+
27
+ from foreblocks import ForecastingModel, LSTMDecoder, LSTMEncoder, Trainer
28
+
29
+ # ---
30
+ # jupyter:
31
+ # jupytext:
32
+ # text_representation:
33
+ # extension: .py
34
+ # format_name: hydrogen
35
+ # format_version: '1.3'
36
+ # jupytext_version: 1.17.1
37
+ # kernelspec:
38
+ # display_name: .venv
39
+ # language: python
40
+ # name: python3
41
+ # ---
42
+
43
+ # %%
44
+
45
# Training length in epochs; scheduled_sampling_fn also uses it as its decay horizon.
total_epochs: int = 500
46
+
47
+
48
def scheduled_sampling_fn(epoch, *, start_ratio=0.8, total=None):
    """Return the teacher-forcing ratio to use at ``epoch``.

    The ratio starts at ``start_ratio`` and decreases linearly by ``1 / total``
    per epoch, clamped below at 0.0. With the defaults (0.8, total_epochs=500)
    it reaches 0.0 at epoch 400.

    Args:
        epoch: Current epoch number (presumably 0-based — confirm with the trainer).
        start_ratio: Ratio returned at epoch 0; defaults to the original 0.8.
        total: Decay horizon. Defaults to the module-level ``total_epochs``, which
            is looked up at call time so later reassignments are honoured.

    Returns:
        The teacher-forcing ratio, a float in ``[0.0, start_ratio]`` for a
        non-negative ``epoch`` and positive ``total``.
    """
    if total is None:
        # Late lookup keeps the original behaviour of reading the global each call.
        total = total_epochs
    return max(0.0, start_ratio - epoch / total)
52
+
53
+
54
# Also put the directory one level above the working directory on sys.path.
# NOTE(review): every `foreblocks` import at the top of the file has already been
# resolved by this point, so this append appears redundant for them; it also
# rebinds `notebook_dir` and `parent_dir`. Kept unchanged pending confirmation
# that later cells rely on it.
notebook_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(notebook_dir, os.pardir))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
61
+
62
+
63
+ # df = pd.read_csv('df_demmand_without_category_2025_05_13.csv')
64
+
65
+
66
+ # %%
67
+
68
+ # # read df_demmand_without_category_2025_05_13.csv
69
+ # import pandas as pd
70
+ # df = pd.read_csv('df_demmand_without_category_2025_05_13.csv')
71
+ # from foreblocks import TimeSeriesPreprocessor
72
+
73
+ # import numpy as np
74
+ # import pandas as pd
75
+ # import matplotlib.pyplot as plt
76
+
77
+
78
+ # # Generate synthetic time series data
79
+ # np.random.seed(42)
80
+ # n_samples = 200
81
+ # timestamps = df.date
82
+
83
+ # # convert df to numpy
84
+ # data = df.drop(columns=['date']).values
85
+ # # Create preprocessor with various techniques enabled
86
+ # preprocessor = TimeSeriesPreprocessor(
87
+ # normalize=True,
88
+ # differencing=False,
89
+ # detrend=True,
90
+ # apply_ewt=True,
91
+ # window_size=24,
92
+ # horizon=12,
93
+ # remove_outliers=True,
94
+ # outlier_threshold=2.5,
95
+ # outlier_method="iqr",
96
+ # impute_method="iterative",
97
+ # ewt_bands=5,
98
+ # trend_imf_idx=0,
99
+ # log_transform=False,
100
+ # filter_window=5,
101
+ # filter_polyorder=2,
102
+ # apply_filter=True,
103
+ # self_tune=True,
104
+ # apply_imputation=True,
105
+ # generate_time_features=False,
106
+ # )
107
+
108
+ # # Fit and transform the data
109
+ # X, y, processed_data = preprocessor.fit_transform(data, time_stamps=timestamps)
110
+
111
+ # # Visualize the results
112
+ # plt.figure(figsize=(15, 10))
113
+
114
+ # plt.subplot(3, 1, 1)
115
+ # plt.title('Original Data with Outliers and Missing Values')
116
+ # plt.plot(data)
117
+
118
+ # plt.subplot(3, 1, 2)
119
+ # plt.title('Processed Data')
120
+ # print("Processed data shape:", processed_data.shape)
121
+ # plt.plot(processed_data)
122
+
123
+ # plt.subplot(3, 1, 3)
124
+ # plt.title('EWT Components')
125
+ # ewt_components = preprocessor.get_ewt_components()
126
+ # if ewt_components:
127
+ # for i, imf in enumerate(ewt_components[0].T):
128
+ # plt.plot(imf, label=f'IMF {i}')
129
+ # plt.legend()
130
+
131
+ # plt.tight_layout()
132
+ # plt.show()
133
+
134
+ # print(f"Input sequence shape: {X.shape}")
135
+ # print(f"Target sequence shape: {y.shape}")
136
+
137
+ # %%
138
+ # # load the processed data
139
+ # # save X and y to pickle
140
+ # import pickle
141
+ # with open('X.pkl', 'wb') as f:
142
+ # pickle.dump(X, f)
143
+ # with open('y.pkl', 'wb') as f:
144
+ # pickle.dump(y, f)
145
+
146
+ # %%
147
+
148
# Load the precomputed arrays: X (input sequences), y (target sequences) and
# time_feat (presumably time features — confirm). Paths are relative to the
# current working directory (assumes the script is run from the repo root).
# SECURITY: pickle.load can execute arbitrary code; only load files you trust.
def _load_pickle(path):
    """Unpickle and return the single object stored in the file at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


X = _load_pickle("examples/X.pkl")
y = _load_pickle("examples/y.pkl")
time_feat = _load_pickle("examples/time.pkl")
155
+
156
+ # %%
157
+
158
+
159
+ # Parameters
160
+ input_size = X.shape[2] # Number of features
161
+ hidden_size = 64
162
+ num_layers = 2
163
+ output_size = X.shape[2] # Number of features
164
+ target_len = 12
165
+ seq_len = 24
166
+ total_len = 300 # Total synthetic time series length
167
+
168
+
169
+ # fourier_preprocessor = LatentCorrelationGraphLayer(
170
+ # #conv_type='sgconv',
171
+ # input_size=input_size, # Input dimension
172
+ # output_size=hidden_size, # Output dimension (same as hidden_size)
173
+ # )
174
+ preprocessor = LatentGraphNetwork(
175
+ input_size=input_size,
176
+ output_size=input_size,
177
+ hidden_size=input_size,
178
+ strategy="vanilla",
179
+ aggregation="mean",
180
+ )
181
+
182
+ # preprocessor = FourierFeatures(
183
+ # input_size=input_size,
184
+ # output_size=input_size,
185
+ # num_frequencies=8,
186
+ # )
187
+ # 1. Create encoder and decoder
188
+ # encoder = LSTMEncoder(input_size, hidden_size, num_layers)
189
+ # decoder = LSTMDecoder(output_size, hidden_size, output_size, num_layers)
190
+
191
+ model_params = {
192
+ "input_processor_output_size": input_size,
193
+ "hidden_size": 64,
194
+ "nhead": 4,
195
+ "num_encoder_layers": 1,
196
+ "num_decoder_layers": 1,
197
+ "dropout": 0.1,
198
+ "dim_feedforward": 2048,
199
+ "seq_len": 24,
200
+ "target_len": 12,
201
+ "total_len": 1000,
202
+ "input_size": input_size,
203
+ "output_size": output_size,
204
+ }
205
+
206
+ from foreblocks.third_party.flash_softpick_attn import parallel_softpick_attn
207
+
208
+
209
+ def warmup_softpick(device, d_model=512, n_heads=4, seq_len=16):
210
+ q = torch.randn(
211
+ 1, seq_len, n_heads, d_model // n_heads, device=device, dtype=torch.float16
212
+ )
213
+ k = q.clone()
214
+ v = q.clone()
215
+ _ = parallel_softpick_attn(q, k, v, head_first=False)
216
+
217
+
218
+ from foreblocks.blocks.nha import NHA
219
+
220
+ embedding_size = 12 # Size of the output embeddings
221
+ # 1. Create the NHA input preprocessor
222
+ nha_preprocessor = NHA(
223
+ input_dim=input_size, # Input dimension
224
+ embedding_dim=embedding_size, # Output embedding dimension
225
+ hidden_dim=12, # Hidden dimension for processing
226
+ num_blocks=2, # Number of hierarchical blocks
227
+ num_levels_per_block=3, # Number of hierarchical levels per block
228
+ kernel_size=3, # Kernel size for convolutions
229
+ attention_heads=4, # Number of attention heads
230
+ dropout=0.1, # Dropout probability
231
+ )
232
+
233
+ from foreblocks.blocks.famous import TimesBlock, TimesBlockPreprocessor
234
+
235
+ times_wrapper = TimesBlockPreprocessor(d_model=input_size)
236
+ # warmup_softpick(device=torch.device("cuda"))
237
+
238
+ pos_encoder = LearnablePositionalEncoding(512)
239
+ pos_decoder = LearnablePositionalEncoding(512)
240
+
241
+ encoder = TransformerEncoder(
242
+ input_size=model_params.get("input_processor_output_size", 1),
243
+ nhead=model_params.get("nhead", 4),
244
+ num_layers=model_params.get("num_encoder_layers", 1),
245
+ dropout=model_params.get("dropout", 0.1),
246
+ dim_feedforward=model_params.get("dim_feedforward", 2048),
247
+ use_moe=True,
248
+ pos_encoder=pos_encoder,
249
+ att_type="autocor",
250
+ )
251
+
252
+ # Create transformer decoder
253
+ decoder = TransformerDecoder(
254
+ input_size=model_params.get("input_processor_output_size", 1),
255
+ output_size=output_size,
256
+ nhead=model_params.get("nhead", 4),
257
+ num_layers=model_params.get("num_decoder_layers", 1),
258
+ dropout=model_params.get("dropout", 0.1),
259
+ dim_feedforward=model_params.get("dim_feedforward", 2048),
260
+ informer_like=True,
261
+ use_moe=True,
262
+ # att_type="prob_sparse",
263
+ pos_encoder=pos_decoder,
264
+ )
265
+
266
+ # from foreblocks.blocks.mamba import MambaDecoder, MambaEncoder
267
+
268
+ # encoder = MambaEncoder(
269
+ # input_size=input_size, hidden_size=hidden_size, num_layers=num_layers
270
+ # )
271
+ # decoder = MambaDecoder(
272
+ # input_size=output_size,
273
+ # hidden_size=hidden_size,
274
+ # num_layers=num_layers,
275
+ # output_size=output_size,
276
+ # )
277
+
278
+
279
+ # attention_module = AttentionLayer(
280
+ # method="mha",
281
+ # attention_backend="flash",
282
+ # encoder_hidden_size=hidden_size,
283
+ # decoder_hidden_size=hidden_size,
284
+ # nhead=16,
285
+ # )
286
+
287
+ total_epochs = 500
288
+
289
+
290
+ # create scheduled_sampling_fn for teacher forcing
291
+ def scheduled_sampling_fn(epoch):
292
+ # Use a linear decay from 1.0 to 0.0 over the epochs
293
+ tf_ratio = max(0.0, 0.95 - (epoch / total_epochs))
294
+
295
+ return tf_ratio
296
+
297
+
298
+ outprocessor = nn.Sequential(
299
+ GRU(input_size=output_size, hidden_size=32, output_size=output_size),
300
+ )
301
+
302
+ print("Using timewrapper")
303
+ outnorm = nn.LayerNorm(output_size)
304
+ model = ForecastingModel(
305
+ encoder=encoder,
306
+ decoder=decoder,
307
+ target_len=target_len,
308
+ forecasting_strategy="seq2seq",
309
+ model_type="informer-like",
310
+ scheduled_sampling_fn=scheduled_sampling_fn,
311
+ output_size=output_size,
312
+ # attention_module=attention_module,
313
+ input_preprocessor=times_wrapper,
314
+ output_block=outprocessor,
315
+ # output_normalization=outnorm,
316
+ input_skip_connection=False,
317
+ )
318
+
319
+ # model = script(model) # Convert to TorchScript for optimization
320
+ trainer = Trainer(
321
+ model,
322
+ optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
323
+ criterion=nn.MSELoss(),
324
+ )
325
+
326
+
327
+ train_size = int(0.8 * len(X))
328
+ X_train, Y_train = X[:train_size], y[:train_size]
329
+ X_val, Y_val = X[train_size:], y[train_size:]
330
+
331
+ X_train = torch.tensor(X_train, dtype=torch.float32)
332
+ Y_train = torch.tensor(Y_train, dtype=torch.float32)
333
+ Y_val = torch.tensor(Y_val, dtype=torch.float32)
334
+ X_val = torch.tensor(X_val, dtype=torch.float32)
335
+ time_train = torch.tensor(time_feat[:train_size], dtype=torch.float32)
336
+
337
+ # create DataLoader
338
+
339
+ train_dataset = TensorDataset(X_train, Y_train, time_train)
340
+ print(time_train.shape)
341
+ train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
342
+ data = trainer.train(train_loader, epochs=500)
343
+ metrics = trainer.metrics(X_val, Y_val)
344
+
345
+
346
+ # %%
347
+ X = torch.tensor(X, dtype=torch.float32)
348
+ fig = trainer.plot_prediction(X_val, Y_val, full_series=X, offset=train_size)
349
+
350
+
351
+ # %%
@@ -0,0 +1,275 @@
1
+ # Copyright (c) 2024, Sanghun Cho, Tri Dao.
2
+
3
+ import pickle
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, repeat
10
+ from flash_attn.layers.rotary import apply_rotary_emb
11
+
12
+ from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
13
+ from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
14
+
15
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
16
+
17
+ try:
18
+ import xformers.ops as xops
19
+ except ImportError:
20
+ xops = None
21
+
22
+
23
+ def generate_cos_sin(seqlen, rotary_dim, device, dtype):
24
+ assert rotary_dim % 2 == 0
25
+ angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
26
+ cos = torch.cos(angle).to(dtype=dtype)
27
+ sin = torch.sin(angle).to(dtype=dtype)
28
+ return cos, sin
29
+
30
+
31
+ def flash_rotary(q, k, v, cos, sin, causal=False):
32
+ # corrected by @tridao comments
33
+ q = apply_rotary_emb(
34
+ q, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
35
+ )
36
+ k = apply_rotary_emb(
37
+ k, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
38
+ )
39
+
40
+ return flash_attn_func(q, k, v, causal=causal)
41
+
42
+
43
+ def attn_bias_from_alibi_slopes(
44
+ slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
45
+ ):
46
+ batch, nheads = slopes.shape
47
+ device = slopes.device
48
+ slopes = rearrange(slopes, "b h -> b h 1 1")
49
+ if causal:
50
+ return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
51
+ else:
52
+ row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
53
+ col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
54
+ sk = (
55
+ seqlen_k
56
+ if key_padding_mask is None
57
+ else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
58
+ )
59
+ sq = (
60
+ seqlen_q
61
+ if query_padding_mask is None
62
+ else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
63
+ )
64
+ relative_pos = torch.abs(row_idx + sk - sq - col_idx)
65
+ return -slopes * relative_pos.to(dtype=slopes.dtype)
66
+
67
+
68
+ def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
69
+ assert mode in ["fwd", "bwd", "fwd_bwd"]
70
+ f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
71
+ return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
72
+
73
+
74
+ def efficiency(flop, time):
75
+ return (flop / time / 10**12) if not math.isnan(time) else 0.0
76
+
77
+
78
+ def attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None):
79
+ """
80
+ Arguments:
81
+ q, k, v: (batch_size, seqlen, nheads, head_dim)
82
+ dropout_p: float
83
+ attn_bias: (batch_size, nheads, seqlen, seqlen) or (1, nheads, seqlen, seqlen)
84
+ Output:
85
+ output: (batch_size, seqlen, nheads, head_dim)
86
+ """
87
+ batch_size, seqlen, nheads, d = q.shape
88
+ q = rearrange(q, 'b t h d -> (b h) t d')
89
+ k = rearrange(k, 'b s h d -> (b h) d s')
90
+ softmax_scale = 1.0 / math.sqrt(d)
91
+ # Preallocate attn_weights for `baddbmm`
92
+ if attn_bias is not None:
93
+ scores = rearrange(attn_bias, 'b h t s -> (b h) t s')
94
+ else:
95
+ scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device)
96
+ scores = rearrange(torch.baddbmm(scores, q, k, beta=1.0, alpha=softmax_scale),
97
+ '(b h) t s -> b h t s', h=nheads)
98
+ if causal:
99
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
100
+ # So we have to construct the mask in float
101
+ causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
102
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
103
+ scores = scores + causal_mask.to(dtype=scores.dtype)
104
+ attention = torch.softmax(scores, dim=-1)
105
+ attention_drop = F.dropout(attention, dropout_p)
106
+ output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
107
+ return output.to(dtype=q.dtype)
108
+
109
+
110
+ def time_fwd_bwd(func, *args, **kwargs):
111
+ time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
112
+ return time_f[1].mean, time_b[1].mean
113
+
114
+
115
+ repeats = 30
116
+ device = 'cuda'
117
+ dtype = torch.float16
118
+
119
+ bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
120
+ causal_vals = [False, True]
121
+ headdim_vals = [64, 128]
122
+ dim = 2048
123
+ dropout_p = 0.0
124
+
125
+ methods = (["fa2_alibi", "torch"]
126
+ + (["xformers"] if xops is not None else [])
127
+ + ["sdpa"]
128
+ + ["fa2_baseline"]
129
+ + ["fa2_rotary"])
130
+
131
+ time_f = {}
132
+ time_b = {}
133
+ time_f_b = {}
134
+ speed_f = {}
135
+ speed_b = {}
136
+ speed_f_b = {}
137
+ for causal in causal_vals:
138
+ for headdim in headdim_vals:
139
+ for batch_size, seqlen in bs_seqlen_vals:
140
+ config = (causal, headdim, batch_size, seqlen)
141
+ nheads = dim // headdim
142
+ q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
143
+ requires_grad=True) for _ in range(3)]
144
+ # alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
145
+ alibi_slopes = torch.rand(1, nheads, device=device, dtype=torch.float32) * 0.3
146
+ attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal).to(dtype)
147
+ attn_bias = repeat(attn_bias, "1 ... -> b ...", b=batch_size)
148
+ f, b = time_fwd_bwd(
149
+ flash_attn_func,
150
+ q, k, v,
151
+ dropout_p,
152
+ causal=causal,
153
+ # alibi_slopes=alibi_slopes,
154
+ alibi_slopes=None,
155
+ repeats=repeats,
156
+ verbose=False
157
+ )
158
+ time_f[config, "fa2_baseline"] = f
159
+ time_b[config, "fa2_baseline"] = b
160
+
161
+ q = q.detach().requires_grad_(True)
162
+ k = k.detach().requires_grad_(True)
163
+ v = v.detach().requires_grad_(True)
164
+ f, b = time_fwd_bwd(
165
+ flash_attn_func,
166
+ q, k, v,
167
+ dropout_p,
168
+ causal=causal,
169
+ alibi_slopes=rearrange(alibi_slopes, "1 h -> h"),
170
+ # alibi_slopes=None,
171
+ repeats=repeats,
172
+ verbose=False
173
+ )
174
+ time_f[config, "fa2_alibi"] = f
175
+ time_b[config, "fa2_alibi"] = b
176
+
177
+ try:
178
+ q = q.detach().requires_grad_(True)
179
+ k = k.detach().requires_grad_(True)
180
+ v = v.detach().requires_grad_(True)
181
+ f, b = time_fwd_bwd(
182
+ attention_pytorch,
183
+ q, k, v,
184
+ dropout_p,
185
+ causal=causal,
186
+ attn_bias=attn_bias,
187
+ repeats=repeats,
188
+ verbose=False
189
+ )
190
+ except: # Skip if OOM
191
+ f, b = float('nan'), float('nan')
192
+ time_f[config, "torch"] = f
193
+ time_b[config, "torch"] = b
194
+
195
+ # F.sdpa doesn't currently (torch 2.1) dispatch to flash-attn but just to be safe
196
+ with torch.backends.cuda.sdp_kernel(enable_flash=False):
197
+ q_pt = q.detach().requires_grad_(True).transpose(1, 2)
198
+ k_pt = k.detach().requires_grad_(True).transpose(1, 2)
199
+ v_pt = v.detach().requires_grad_(True).transpose(1, 2)
200
+ f, b = time_fwd_bwd(
201
+ F.scaled_dot_product_attention,
202
+ q_pt, k_pt, v_pt,
203
+ attn_mask=attn_bias,
204
+ dropout_p=dropout_p,
205
+ is_causal=causal,
206
+ repeats=repeats,
207
+ verbose=False
208
+ )
209
+ time_f[config, "sdpa"] = f
210
+ time_b[config, "sdpa"] = b
211
+
212
+ if xops is not None:
213
+ q = q.detach().requires_grad_(True)
214
+ k = k.detach().requires_grad_(True)
215
+ v = v.detach().requires_grad_(True)
216
+ if causal:
217
+ attn_bias_xops = xops.LowerTriangularMask().add_bias(attn_bias.expand(-1, -1, seqlen, -1).to(dtype=q.dtype))
218
+ # NotImplementedError: No operator found for `memory_efficient_attention_backward` with inputs:
219
+ # `flshattB@v2.3.6` is not supported because:
220
+ # attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
221
+ # `cutlassB` is not supported because:
222
+ # attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
223
+ attn_bias_xops = attn_bias_xops.materialize((batch_size, nheads, seqlen, seqlen), dtype=q.dtype, device=device)
224
+ else:
225
+ attn_bias_xops = attn_bias.to(dtype=q.dtype)
226
+ f, b = time_fwd_bwd(
227
+ xops.memory_efficient_attention,
228
+ q, k, v,
229
+ attn_bias_xops,
230
+ dropout_p,
231
+ repeats=repeats,
232
+ verbose=False
233
+ )
234
+ time_f[config, "xformers"] = f
235
+ time_b[config, "xformers"] = b
236
+
237
+ q = q.detach().requires_grad_(True)
238
+ k = k.detach().requires_grad_(True)
239
+ v = v.detach().requires_grad_(True)
240
+ cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
241
+ f, b = time_fwd_bwd(
242
+ flash_rotary,
243
+ q, k, v,
244
+ cos, sin,
245
+ causal,
246
+ repeats=repeats,
247
+ verbose=False
248
+ )
249
+ time_f[config, "fa2_rotary"] = f
250
+ time_b[config, "fa2_rotary"] = b
251
+
252
+ print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
253
+ csv_output = ""
254
+ csv_output += f"{causal},{headdim},{batch_size},{seqlen},"
255
+ for method in methods:
256
+ time_f_b[config, method] = time_f[config, method] + time_b[config, method]
257
+ speed_f[config, method] = efficiency(
258
+ flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
259
+ time_f[config, method]
260
+ )
261
+ speed_b[config, method] = efficiency(
262
+ flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
263
+ time_b[config, method]
264
+ )
265
+ speed_f_b[config, method] = efficiency(
266
+ flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
267
+ time_f_b[config, method]
268
+ )
269
+ print(
270
+ f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
271
+ f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
272
+ f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
273
+ )
274
+ csv_output += f"{speed_f[config, method]:.2f},{speed_b[config, method]:.2f},{speed_f_b[config, method]:.2f},"
275
+ print(csv_output)