foreblocks 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/rodrigo.py +351 -0
- flash-attention/benchmarks/benchmark_alibi.py +275 -0
- flash-attention/benchmarks/benchmark_causal.py +225 -0
- flash-attention/benchmarks/benchmark_flash_attention.py +180 -0
- flash-attention/benchmarks/benchmark_gemm.py +47 -0
- flash-attention/csrc/composable_kernel/docs/conf.py +50 -0
- flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/__init__.py +0 -0
- flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/cmake_config.py +5 -0
- flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +128 -0
- flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/__init__.py +0 -0
- flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +902 -0
- flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +574 -0
- flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +359 -0
- flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +855 -0
- flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/generate.py +136 -0
- flash-attention/csrc/composable_kernel/example/ck_tile/02_layernorm2d/generate.py +730 -0
- flash-attention/csrc/composable_kernel/example/ck_tile/10_rmsnorm2d/generate.py +715 -0
- flash-attention/csrc/composable_kernel/example/ck_tile/remod.py +21 -0
- flash-attention/csrc/composable_kernel/include/ck_tile/remod.py +93 -0
- flash-attention/csrc/composable_kernel/python/ck4inductor/__init__.py +0 -0
- flash-attention/csrc/composable_kernel/python/ck4inductor/batched_universal_gemm/gen_instances.py +149 -0
- flash-attention/csrc/composable_kernel/python/ck4inductor/batched_universal_gemm/op.py +99 -0
- flash-attention/csrc/composable_kernel/python/ck4inductor/grouped_conv_fwd/gen_instances.py +165 -0
- flash-attention/csrc/composable_kernel/python/ck4inductor/grouped_conv_fwd/op.py +93 -0
- flash-attention/csrc/composable_kernel/python/ck4inductor/universal_gemm/gen_instances.py +572 -0
- flash-attention/csrc/composable_kernel/python/ck4inductor/universal_gemm/op.py +99 -0
- flash-attention/csrc/composable_kernel/python/ck4inductor/util.py +10 -0
- flash-attention/csrc/composable_kernel/python/test/test_gen_instances.py +46 -0
- flash-attention/csrc/composable_kernel/script/convert_miopen_driver_to_profiler.py +413 -0
- flash-attention/csrc/composable_kernel/script/process_perf_data.py +382 -0
- flash-attention/csrc/composable_kernel/tile_engine/ops/gemm/gemm_instance_builder.py +654 -0
- flash-attention/csrc/cutlass/examples/40_cutlass_py/conv2d.py +177 -0
- flash-attention/csrc/cutlass/examples/40_cutlass_py/customizable/conv2d.py +331 -0
- flash-attention/csrc/cutlass/examples/40_cutlass_py/customizable/gemm.py +331 -0
- flash-attention/csrc/cutlass/examples/40_cutlass_py/customizable/gemm_grouped.py +298 -0
- flash-attention/csrc/cutlass/examples/40_cutlass_py/gemm.py +153 -0
- flash-attention/csrc/cutlass/examples/40_cutlass_py/gemm_grouped.py +172 -0
- flash-attention/csrc/cutlass/examples/41_fused_multi_head_attention/fmha_backward_test.py +232 -0
- flash-attention/csrc/cutlass/examples/41_fused_multi_head_attention/piped_subprocess.py +144 -0
- flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py +129 -0
- flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py +131 -0
- flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py +120 -0
- flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py +469 -0
- flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py +249 -0
- flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py +476 -0
- flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py +232 -0
- flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py +1013 -0
- flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py +456 -0
- flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py +92 -0
- flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py +135 -0
- flash-attention/csrc/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py +67 -0
- flash-attention/csrc/cutlass/python/cutlass/__init__.py +190 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/__init__.py +48 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/arguments.py +133 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/c_types.py +622 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/compiler.py +459 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/conv2d_operation.py +698 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/epilogue.py +541 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/__init__.py +34 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/backend/__init__.py +36 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/backend/emitter_base.py +158 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/backend/sm80_emitter.py +47 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/backend/sm80_nodes.py +258 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/backend/sm90_emitter.py +98 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/backend/sm90_nodes.py +329 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/epilogue.py +167 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/frontend/__init__.py +33 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/frontend/frontend_base.py +262 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/frontend/python_ast.py +187 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/__init__.py +53 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/compute_nodes.py +91 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/dag_ir.py +236 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/layout_algorithm.py +324 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/layout_nodes.py +336 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/load_nodes.py +294 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/node.py +293 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/store_nodes.py +277 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/ir/tensor.py +130 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/__init__.py +42 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/graph_drawer.py +142 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_argument_type.py +116 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_dag_2_tree.py +147 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_fix_element_d.py +64 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_get_impl.py +90 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_layout_elimination.py +217 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_manager.py +164 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_no_op_elimination.py +53 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_preprocess_red.py +97 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/pass_shape_type_propagation.py +59 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/smem_size_calculator.py +204 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/evt/passes/util.py +43 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/frontend.py +107 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/gemm_operation.py +2138 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/library.py +488 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/memory_manager.py +120 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/operation.py +133 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/reduction_operation.py +452 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/type_hint.py +35 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/utils/__init__.py +33 -0
- flash-attention/csrc/cutlass/python/cutlass/backend/utils/device.py +123 -0
- flash-attention/csrc/cutlass/python/cutlass/emit/__init__.py +33 -0
- flash-attention/csrc/cutlass/python/cutlass/emit/common.py +267 -0
- flash-attention/csrc/cutlass/python/cutlass/emit/pytorch.py +936 -0
- flash-attention/csrc/cutlass/python/cutlass/epilogue/__init__.py +55 -0
- flash-attention/csrc/cutlass/python/cutlass/epilogue/epilogue.py +158 -0
- flash-attention/csrc/cutlass/python/cutlass/epilogue/evt_ops.py +92 -0
- flash-attention/csrc/cutlass/python/cutlass/library_defaults.py +580 -0
- flash-attention/csrc/cutlass/python/cutlass/op/__init__.py +36 -0
- flash-attention/csrc/cutlass/python/cutlass/op/conv.py +983 -0
- flash-attention/csrc/cutlass/python/cutlass/op/gemm.py +715 -0
- flash-attention/csrc/cutlass/python/cutlass/op/gemm_grouped.py +264 -0
- flash-attention/csrc/cutlass/python/cutlass/op/op.py +430 -0
- flash-attention/csrc/cutlass/python/cutlass/shape.py +184 -0
- flash-attention/csrc/cutlass/python/cutlass/swizzle.py +65 -0
- flash-attention/csrc/cutlass/python/cutlass/utils/__init__.py +41 -0
- flash-attention/csrc/cutlass/python/cutlass/utils/check.py +269 -0
- flash-attention/csrc/cutlass/python/cutlass/utils/datatypes.py +362 -0
- flash-attention/csrc/cutlass/python/cutlass/utils/profiler.py +185 -0
- flash-attention/csrc/cutlass/python/cutlass_library/__init__.py +63 -0
- flash-attention/csrc/cutlass/python/cutlass_library/conv2d_operation.py +621 -0
- flash-attention/csrc/cutlass/python/cutlass_library/conv3d_operation.py +482 -0
- flash-attention/csrc/cutlass/python/cutlass_library/conv3x_emitter.py +250 -0
- flash-attention/csrc/cutlass/python/cutlass_library/emit_kernel_listing.py +880 -0
- flash-attention/csrc/cutlass/python/cutlass_library/gemm_operation.py +1520 -0
- flash-attention/csrc/cutlass/python/cutlass_library/generator.py +10851 -0
- flash-attention/csrc/cutlass/python/cutlass_library/library.py +1317 -0
- flash-attention/csrc/cutlass/python/cutlass_library/manifest.py +870 -0
- flash-attention/csrc/cutlass/python/cutlass_library/rank_2k_operation.py +438 -0
- flash-attention/csrc/cutlass/python/cutlass_library/rank_k_operation.py +427 -0
- flash-attention/csrc/cutlass/python/cutlass_library/sm90_shapes.py +212 -0
- flash-attention/csrc/cutlass/python/cutlass_library/sm90_utils.py +703 -0
- flash-attention/csrc/cutlass/python/cutlass_library/symm_operation.py +440 -0
- flash-attention/csrc/cutlass/python/cutlass_library/trmm_operation.py +447 -0
- flash-attention/csrc/cutlass/python/docs_src/source/conf.py +132 -0
- flash-attention/csrc/cutlass/python/pycute/__init__.py +36 -0
- flash-attention/csrc/cutlass/python/pycute/int_tuple.py +225 -0
- flash-attention/csrc/cutlass/python/pycute/layout.py +367 -0
- flash-attention/csrc/cutlass/python/pycute/swizzle.py +129 -0
- flash-attention/csrc/cutlass/python/pycute/typing.py +42 -0
- flash-attention/csrc/cutlass/python/setup_cutlass.py +74 -0
- flash-attention/csrc/cutlass/python/setup_library.py +46 -0
- flash-attention/csrc/cutlass/python/setup_pycute.py +46 -0
- flash-attention/csrc/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py +661 -0
- flash-attention/csrc/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py +146 -0
- flash-attention/csrc/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py +428 -0
- flash-attention/csrc/cutlass/test/python/cutlass/conv2d/run_all_tests.py +44 -0
- flash-attention/csrc/cutlass/test/python/cutlass/emit/pytorch.py +309 -0
- flash-attention/csrc/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py +122 -0
- flash-attention/csrc/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py +173 -0
- flash-attention/csrc/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py +142 -0
- flash-attention/csrc/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py +274 -0
- flash-attention/csrc/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py +155 -0
- flash-attention/csrc/cutlass/test/python/cutlass/evt/run_all_tests.py +44 -0
- flash-attention/csrc/cutlass/test/python/cutlass/evt/utils/evt_testbed.py +230 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_batched.py +134 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py +128 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py +146 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py +104 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py +103 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py +71 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py +112 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py +75 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py +103 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py +98 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/gemm_testbed.py +423 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/run_all_tests.py +44 -0
- flash-attention/csrc/cutlass/test/python/cutlass/gemm/utils.py +260 -0
- flash-attention/csrc/cutlass/test/python/cutlass/installation.py +57 -0
- flash-attention/csrc/cutlass/test/python/cutlass/interface/conv2d_interface.py +284 -0
- flash-attention/csrc/cutlass/test/python/cutlass/interface/evt_interface.py +254 -0
- flash-attention/csrc/cutlass/test/python/cutlass/interface/gemm_interface.py +351 -0
- flash-attention/csrc/cutlass/test/python/cutlass/interface/utils.py +69 -0
- flash-attention/csrc/cutlass/test/python/pycute/run_all_tests.py +75 -0
- flash-attention/csrc/cutlass/test/python/pycute/test_coalesce.py +95 -0
- flash-attention/csrc/cutlass/test/python/pycute/test_complement.py +92 -0
- flash-attention/csrc/cutlass/test/python/pycute/test_composition.py +213 -0
- flash-attention/csrc/cutlass/test/python/pycute/test_int_tuple.py +80 -0
- flash-attention/csrc/cutlass/test/python/pycute/test_left_inverse.py +87 -0
- flash-attention/csrc/cutlass/test/python/pycute/test_right_inverse.py +96 -0
- flash-attention/csrc/cutlass/test/python/pycute/test_typing.py +59 -0
- flash-attention/csrc/cutlass/test/unit/gemm/device/simt_sm50.py +341 -0
- flash-attention/csrc/flash_attn/src/generate_kernels.py +110 -0
- flash-attention/csrc/ft_attention/setup.py +153 -0
- flash-attention/csrc/fused_dense_lib/setup.py +42 -0
- flash-attention/csrc/fused_softmax/setup.py +50 -0
- flash-attention/csrc/layer_norm/setup.py +205 -0
- flash-attention/csrc/rotary/setup.py +126 -0
- flash-attention/csrc/xentropy/setup.py +139 -0
- flash-attention/flash_attn/__init__.py +11 -0
- flash-attention/flash_attn/bert_padding.py +218 -0
- flash-attention/flash_attn/flash_attn_interface.py +1606 -0
- flash-attention/flash_attn/flash_attn_triton.py +1160 -0
- flash-attention/flash_attn/flash_attn_triton_amd/__init__.py +0 -0
- flash-attention/flash_attn/flash_attn_triton_amd/bench.py +1223 -0
- flash-attention/flash_attn/flash_attn_triton_amd/bwd_prefill.py +814 -0
- flash-attention/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py +3266 -0
- flash-attention/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py +1091 -0
- flash-attention/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py +1354 -0
- flash-attention/flash_attn/flash_attn_triton_amd/bwd_ref.py +478 -0
- flash-attention/flash_attn/flash_attn_triton_amd/fp8.py +716 -0
- flash-attention/flash_attn/flash_attn_triton_amd/fwd_decode.py +814 -0
- flash-attention/flash_attn/flash_attn_triton_amd/fwd_prefill.py +648 -0
- flash-attention/flash_attn/flash_attn_triton_amd/fwd_ref.py +387 -0
- flash-attention/flash_attn/flash_attn_triton_amd/interface_fa.py +798 -0
- flash-attention/flash_attn/flash_attn_triton_amd/test.py +932 -0
- flash-attention/flash_attn/flash_attn_triton_amd/train.py +403 -0
- flash-attention/flash_attn/flash_attn_triton_amd/utils.py +776 -0
- flash-attention/flash_attn/flash_attn_triton_og.py +365 -0
- flash-attention/flash_attn/flash_blocksparse_attention.py +197 -0
- flash-attention/flash_attn/flash_blocksparse_attn_interface.py +200 -0
- flash-attention/flash_attn/fused_softmax.py +201 -0
- flash-attention/flash_attn/layers/__init__.py +0 -0
- flash-attention/flash_attn/layers/patch_embed.py +67 -0
- flash-attention/flash_attn/layers/rotary.py +482 -0
- flash-attention/flash_attn/losses/__init__.py +0 -0
- flash-attention/flash_attn/losses/cross_entropy.py +85 -0
- flash-attention/flash_attn/models/__init__.py +0 -0
- flash-attention/flash_attn/models/baichuan.py +151 -0
- flash-attention/flash_attn/models/bert.py +764 -0
- flash-attention/flash_attn/models/bigcode.py +233 -0
- flash-attention/flash_attn/models/btlm.py +102 -0
- flash-attention/flash_attn/models/falcon.py +143 -0
- flash-attention/flash_attn/models/gpt.py +1080 -0
- flash-attention/flash_attn/models/gpt_neox.py +124 -0
- flash-attention/flash_attn/models/gptj.py +109 -0
- flash-attention/flash_attn/models/llama.py +422 -0
- flash-attention/flash_attn/models/opt.py +116 -0
- flash-attention/flash_attn/models/vit.py +373 -0
- flash-attention/flash_attn/modules/__init__.py +0 -0
- flash-attention/flash_attn/modules/block.py +397 -0
- flash-attention/flash_attn/modules/embedding.py +216 -0
- flash-attention/flash_attn/modules/mha.py +993 -0
- flash-attention/flash_attn/modules/mlp.py +191 -0
- flash-attention/flash_attn/ops/__init__.py +0 -0
- flash-attention/flash_attn/ops/activations.py +135 -0
- flash-attention/flash_attn/ops/fused_dense.py +688 -0
- flash-attention/flash_attn/ops/layer_norm.py +800 -0
- flash-attention/flash_attn/ops/rms_norm.py +174 -0
- flash-attention/flash_attn/ops/triton/__init__.py +1 -0
- flash-attention/flash_attn/ops/triton/cross_entropy.py +330 -0
- flash-attention/flash_attn/ops/triton/k_activations.py +162 -0
- flash-attention/flash_attn/ops/triton/layer_norm.py +1252 -0
- flash-attention/flash_attn/ops/triton/linear.py +594 -0
- flash-attention/flash_attn/ops/triton/mlp.py +149 -0
- flash-attention/flash_attn/ops/triton/rotary.py +185 -0
- flash-attention/flash_attn/utils/__init__.py +0 -0
- flash-attention/flash_attn/utils/benchmark.py +268 -0
- flash-attention/flash_attn/utils/distributed.py +144 -0
- flash-attention/flash_attn/utils/generation.py +740 -0
- flash-attention/flash_attn/utils/library.py +66 -0
- flash-attention/flash_attn/utils/pretrained.py +79 -0
- flash-attention/flash_attn/utils/torch.py +21 -0
- flash-attention/hopper/__init__.py +1 -0
- flash-attention/hopper/benchmark_attn.py +411 -0
- flash-attention/hopper/benchmark_flash_attention_fp8.py +353 -0
- flash-attention/hopper/benchmark_mla_decode.py +129 -0
- flash-attention/hopper/benchmark_split_kv.py +331 -0
- flash-attention/hopper/flash_attn_interface.py +834 -0
- flash-attention/hopper/generate_kernels.py +222 -0
- flash-attention/hopper/padding.py +53 -0
- flash-attention/hopper/setup.py +659 -0
- flash-attention/hopper/test_attn_kvcache.py +486 -0
- flash-attention/hopper/test_flash_attn.py +1130 -0
- flash-attention/hopper/test_kvcache.py +234 -0
- flash-attention/hopper/test_util.py +348 -0
- flash-attention/setup.py +561 -0
- flash-attention/tests/layers/test_rotary.py +134 -0
- flash-attention/tests/losses/test_cross_entropy.py +83 -0
- flash-attention/tests/losses/test_cross_entropy_parallel.py +104 -0
- flash-attention/tests/models/test_baichuan.py +460 -0
- flash-attention/tests/models/test_bert.py +324 -0
- flash-attention/tests/models/test_bigcode.py +204 -0
- flash-attention/tests/models/test_btlm.py +245 -0
- flash-attention/tests/models/test_falcon.py +408 -0
- flash-attention/tests/models/test_gpt.py +478 -0
- flash-attention/tests/models/test_gpt_generation_parallel.py +172 -0
- flash-attention/tests/models/test_gpt_neox.py +104 -0
- flash-attention/tests/models/test_gpt_parallel.py +236 -0
- flash-attention/tests/models/test_gptj.py +184 -0
- flash-attention/tests/models/test_llama.py +633 -0
- flash-attention/tests/models/test_opt.py +237 -0
- flash-attention/tests/models/test_vit.py +48 -0
- flash-attention/tests/modules/test_block_parallel.py +273 -0
- flash-attention/tests/modules/test_embedding_parallel.py +106 -0
- flash-attention/tests/modules/test_mha_parallel.py +160 -0
- flash-attention/tests/modules/test_mlp_parallel.py +143 -0
- flash-attention/tests/ops/test_dropout_layer_norm.py +1189 -0
- flash-attention/tests/ops/test_fused_dense.py +172 -0
- flash-attention/tests/ops/test_fused_dense_parallel.py +237 -0
- flash-attention/tests/ops/triton/test_layer_norm.py +374 -0
- flash-attention/tests/test_flash_attn.py +2525 -0
- flash-attention/tests/test_flash_attn_ck.py +1618 -0
- flash-attention/tests/test_flash_attn_triton_amd.py +2547 -0
- flash-attention/tests/test_rotary.py +321 -0
- flash-attention/tests/test_util.py +274 -0
- flash-attention/training/run.py +68 -0
- flash-attention/training/src/callbacks/__init__.py +0 -0
- flash-attention/training/src/callbacks/causality_monitor.py +61 -0
- flash-attention/training/src/callbacks/ema.py +82 -0
- flash-attention/training/src/callbacks/flop_count.py +43 -0
- flash-attention/training/src/callbacks/gpu_affinity.py +40 -0
- flash-attention/training/src/callbacks/loss_scale_monitor.py +32 -0
- flash-attention/training/src/callbacks/model_checkpoint.py +36 -0
- flash-attention/training/src/callbacks/norm_monitor.py +79 -0
- flash-attention/training/src/callbacks/params_log.py +34 -0
- flash-attention/training/src/callbacks/speed_monitor.py +95 -0
- flash-attention/training/src/callbacks/wandb_callbacks.py +289 -0
- flash-attention/training/src/datamodules/datasets/detokenizer.py +53 -0
- flash-attention/training/src/datamodules/datasets/lm_dataset.py +32 -0
- flash-attention/training/src/datamodules/fault_tolerant_sampler.py +123 -0
- flash-attention/training/src/datamodules/imagenet.py +283 -0
- flash-attention/training/src/datamodules/language_modeling_hf.py +299 -0
- flash-attention/training/src/datamodules/timm_mixup.py +20 -0
- flash-attention/training/src/distributed/ddp_comm_hooks.py +43 -0
- flash-attention/training/src/eval.py +129 -0
- flash-attention/training/src/metrics/accuracy.py +11 -0
- flash-attention/training/src/metrics/num_tokens.py +45 -0
- flash-attention/training/src/metrics/perplexity.py +70 -0
- flash-attention/training/src/models/modules/seq_common.py +342 -0
- flash-attention/training/src/optim/param_grouping.py +114 -0
- flash-attention/training/src/optim/timm_lr_scheduler.py +30 -0
- flash-attention/training/src/tasks/seq.py +192 -0
- flash-attention/training/src/train.py +136 -0
- flash-attention/training/src/utils/checkpoint.py +76 -0
- flash-attention/training/src/utils/ddp_zero1.py +106 -0
- flash-attention/training/src/utils/ddp_zero2.py +146 -0
- flash-attention/training/src/utils/distributed.py +111 -0
- flash-attention/training/src/utils/ema.py +280 -0
- flash-attention/training/src/utils/flops.py +45 -0
- flash-attention/training/src/utils/gpu_affinity.py +142 -0
- flash-attention/training/src/utils/utils.py +146 -0
- flash-attention/training/tests/datamodules/test_language_modeling_hf.py +218 -0
- foreblocks/__init__.py +42 -0
- foreblocks/att.py +299 -0
- foreblocks/aux.py +45 -0
- foreblocks/blocks/__init__.py +64 -0
- foreblocks/blocks/attention.py +287 -0
- foreblocks/blocks/famous.py +529 -0
- foreblocks/blocks/fourier.py +464 -0
- foreblocks/blocks/graph.py +1466 -0
- foreblocks/blocks/mamba.py +119 -0
- foreblocks/blocks/multiscale.py +124 -0
- foreblocks/blocks/nha.py +476 -0
- foreblocks/blocks/ode.py +184 -0
- foreblocks/blocks/simple.py +204 -0
- foreblocks/blocks/wavelets.py +439 -0
- foreblocks/blocks.py +0 -0
- foreblocks/core.py +698 -0
- foreblocks/darts/darts.py +557 -0
- foreblocks/darts/darts_run.py +2386 -0
- foreblocks/enc_dec.py +218 -0
- foreblocks/pipeline.py +389 -0
- foreblocks/pre/__init__.py +0 -0
- foreblocks/pre/ewt.py +104 -0
- foreblocks/pre/filters.py +147 -0
- foreblocks/pre/impute.py +428 -0
- foreblocks/pre/outlier.py +399 -0
- foreblocks/preprocessing.py +978 -0
- foreblocks/tf/embeddings.py +395 -0
- foreblocks/tf/fed.py +322 -0
- foreblocks/tf/transformer.py +690 -0
- foreblocks/tf/transformer_att.py +535 -0
- foreblocks/tf/transformer_aux.py +230 -0
- foreblocks/tf/transformer_moe.py +1137 -0
- foreblocks/third_party/flash_softpick_attn.py +796 -0
- foreblocks/third_party/vsgd.py +212 -0
- foreblocks/utils.py +576 -0
- foreblocks-0.1.0.dist-info/METADATA +484 -0
- foreblocks-0.1.0.dist-info/RECORD +371 -0
- foreblocks-0.1.0.dist-info/WHEEL +5 -0
- foreblocks-0.1.0.dist-info/top_level.txt +3 -0
examples/rodrigo.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pickle
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
|
|
10
|
+
# Get the current working directory of the notebook
|
|
11
|
+
notebook_dir = os.getcwd()
|
|
12
|
+
|
|
13
|
+
# Add the current working directory to sys.path
|
|
14
|
+
parent_dir = os.path.abspath(os.path.join(notebook_dir, "."))
|
|
15
|
+
if parent_dir not in sys.path:
|
|
16
|
+
sys.path.append(parent_dir)
|
|
17
|
+
|
|
18
|
+
from foreblocks.att import AttentionLayer
|
|
19
|
+
from foreblocks.blocks import GRU
|
|
20
|
+
from foreblocks.blocks.fourier import FNO1DLayer, FourierFeatures
|
|
21
|
+
from foreblocks.blocks.graph import LatentGraphNetwork
|
|
22
|
+
from foreblocks.tf.embeddings import LearnablePositionalEncoding
|
|
23
|
+
from foreblocks.tf.transformer import TransformerDecoder, TransformerEncoder
|
|
24
|
+
from torch.jit import script
|
|
25
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
26
|
+
|
|
27
|
+
from foreblocks import ForecastingModel, LSTMDecoder, LSTMEncoder, Trainer
|
|
28
|
+
|
|
29
|
+
# ---
|
|
30
|
+
# jupyter:
|
|
31
|
+
# jupytext:
|
|
32
|
+
# text_representation:
|
|
33
|
+
# extension: .py
|
|
34
|
+
# format_name: hydrogen
|
|
35
|
+
# format_version: '1.3'
|
|
36
|
+
# jupytext_version: 1.17.1
|
|
37
|
+
# kernelspec:
|
|
38
|
+
# display_name: .venv
|
|
39
|
+
# language: python
|
|
40
|
+
# name: python3
|
|
41
|
+
# ---
|
|
42
|
+
|
|
43
|
+
# %%
|
|
44
|
+
|
|
45
|
+
total_epochs = 500
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# create scheduled_sampling_fn for teacher forcing
|
|
49
|
+
def scheduled_sampling_fn(epoch, num_epochs=None, start_ratio=0.8):
    """Return the teacher-forcing ratio to use at ``epoch``.

    The ratio starts at ``start_ratio`` and decreases linearly by
    ``1 / num_epochs`` per epoch; it is clamped so it never drops below 0.0.

    Args:
        epoch: Current epoch index.
        num_epochs: Length of the schedule in epochs. When omitted, the
            module-level ``total_epochs`` is used, so existing one-argument
            callers keep the original behavior.
        start_ratio: Ratio returned at epoch 0 (0.8 in the original script).

    Returns:
        The teacher-forcing ratio, never less than 0.0.
    """
    # Fall back to the script-wide setting only when no explicit length is given.
    schedule_len = total_epochs if num_epochs is None else num_epochs
    return max(0.0, start_ratio - (epoch / schedule_len))
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Get the current working directory of the notebook
|
|
55
|
+
notebook_dir = os.getcwd()
|
|
56
|
+
|
|
57
|
+
# Add the parent directory to sys.path
|
|
58
|
+
parent_dir = os.path.abspath(os.path.join(notebook_dir, ".."))
|
|
59
|
+
if parent_dir not in sys.path:
|
|
60
|
+
sys.path.append(parent_dir)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# df = pd.read_csv('df_demmand_without_category_2025_05_13.csv')
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# %%
|
|
67
|
+
|
|
68
|
+
# # read df_demmand_without_category_2025_05_13.csv
|
|
69
|
+
# import pandas as pd
|
|
70
|
+
# df = pd.read_csv('df_demmand_without_category_2025_05_13.csv')
|
|
71
|
+
# from foreblocks import TimeSeriesPreprocessor
|
|
72
|
+
|
|
73
|
+
# import numpy as np
|
|
74
|
+
# import pandas as pd
|
|
75
|
+
# import matplotlib.pyplot as plt
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# # Generate synthetic time series data
|
|
79
|
+
# np.random.seed(42)
|
|
80
|
+
# n_samples = 200
|
|
81
|
+
# timestamps = df.date
|
|
82
|
+
|
|
83
|
+
# # convert df to numpy
|
|
84
|
+
# data = df.drop(columns=['date']).values
|
|
85
|
+
# # Create preprocessor with various techniques enabled
|
|
86
|
+
# preprocessor = TimeSeriesPreprocessor(
|
|
87
|
+
# normalize=True,
|
|
88
|
+
# differencing=False,
|
|
89
|
+
# detrend=True,
|
|
90
|
+
# apply_ewt=True,
|
|
91
|
+
# window_size=24,
|
|
92
|
+
# horizon=12,
|
|
93
|
+
# remove_outliers=True,
|
|
94
|
+
# outlier_threshold=2.5,
|
|
95
|
+
# outlier_method="iqr",
|
|
96
|
+
# impute_method="iterative",
|
|
97
|
+
# ewt_bands=5,
|
|
98
|
+
# trend_imf_idx=0,
|
|
99
|
+
# log_transform=False,
|
|
100
|
+
# filter_window=5,
|
|
101
|
+
# filter_polyorder=2,
|
|
102
|
+
# apply_filter=True,
|
|
103
|
+
# self_tune=True,
|
|
104
|
+
# apply_imputation=True,
|
|
105
|
+
# generate_time_features=False,
|
|
106
|
+
# )
|
|
107
|
+
|
|
108
|
+
# # Fit and transform the data
|
|
109
|
+
# X, y, processed_data = preprocessor.fit_transform(data, time_stamps=timestamps)
|
|
110
|
+
|
|
111
|
+
# # Visualize the results
|
|
112
|
+
# plt.figure(figsize=(15, 10))
|
|
113
|
+
|
|
114
|
+
# plt.subplot(3, 1, 1)
|
|
115
|
+
# plt.title('Original Data with Outliers and Missing Values')
|
|
116
|
+
# plt.plot(data)
|
|
117
|
+
|
|
118
|
+
# plt.subplot(3, 1, 2)
|
|
119
|
+
# plt.title('Processed Data')
|
|
120
|
+
# print("Processed data shape:", processed_data.shape)
|
|
121
|
+
# plt.plot(processed_data)
|
|
122
|
+
|
|
123
|
+
# plt.subplot(3, 1, 3)
|
|
124
|
+
# plt.title('EWT Components')
|
|
125
|
+
# ewt_components = preprocessor.get_ewt_components()
|
|
126
|
+
# if ewt_components:
|
|
127
|
+
# for i, imf in enumerate(ewt_components[0].T):
|
|
128
|
+
# plt.plot(imf, label=f'IMF {i}')
|
|
129
|
+
# plt.legend()
|
|
130
|
+
|
|
131
|
+
# plt.tight_layout()
|
|
132
|
+
# plt.show()
|
|
133
|
+
|
|
134
|
+
# print(f"Input sequence shape: {X.shape}")
|
|
135
|
+
# print(f"Target sequence shape: {y.shape}")
|
|
136
|
+
|
|
137
|
+
# %%
|
|
138
|
+
# # load the processed data
|
|
139
|
+
# # save X and y to pickle
|
|
140
|
+
# import pickle
|
|
141
|
+
# with open('X.pkl', 'wb') as f:
|
|
142
|
+
# pickle.dump(X, f)
|
|
143
|
+
# with open('y.pkl', 'wb') as f:
|
|
144
|
+
# pickle.dump(y, f)
|
|
145
|
+
|
|
146
|
+
# %%
|
|
147
|
+
|
|
148
|
+
# load X, y and the time features from pickle (only unpickle trusted files: pickle.load can run arbitrary code)
|
|
149
|
+
with open("examples/X.pkl", "rb") as f:
|
|
150
|
+
X = pickle.load(f)
|
|
151
|
+
with open("examples/y.pkl", "rb") as f:
|
|
152
|
+
y = pickle.load(f)
|
|
153
|
+
with open("examples/time.pkl", "rb") as f:
|
|
154
|
+
time_feat = pickle.load(f)
|
|
155
|
+
|
|
156
|
+
# %%
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# Parameters
|
|
160
|
+
input_size = X.shape[2] # Number of features
|
|
161
|
+
hidden_size = 64
|
|
162
|
+
num_layers = 2
|
|
163
|
+
output_size = X.shape[2] # Number of features
|
|
164
|
+
target_len = 12
|
|
165
|
+
seq_len = 24
|
|
166
|
+
total_len = 300 # Total synthetic time series length
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# fourier_preprocessor = LatentCorrelationGraphLayer(
|
|
170
|
+
# #conv_type='sgconv',
|
|
171
|
+
# input_size=input_size, # Input dimension
|
|
172
|
+
# output_size=hidden_size, # Output dimension (same as hidden_size)
|
|
173
|
+
# )
|
|
174
|
+
preprocessor = LatentGraphNetwork(
|
|
175
|
+
input_size=input_size,
|
|
176
|
+
output_size=input_size,
|
|
177
|
+
hidden_size=input_size,
|
|
178
|
+
strategy="vanilla",
|
|
179
|
+
aggregation="mean",
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
# preprocessor = FourierFeatures(
|
|
183
|
+
# input_size=input_size,
|
|
184
|
+
# output_size=input_size,
|
|
185
|
+
# num_frequencies=8,
|
|
186
|
+
# )
|
|
187
|
+
# 1. Create encoder and decoder
|
|
188
|
+
# encoder = LSTMEncoder(input_size, hidden_size, num_layers)
|
|
189
|
+
# decoder = LSTMDecoder(output_size, hidden_size, output_size, num_layers)
|
|
190
|
+
|
|
191
|
+
# Hyperparameters passed to the transformer encoder/decoder below.
# Values that already exist as script-level settings (hidden_size, seq_len,
# target_len) are reused instead of being repeated as literals, so the two
# cannot drift apart.
model_params = {
    "input_processor_output_size": input_size,
    "hidden_size": hidden_size,
    "nhead": 4,
    "num_encoder_layers": 1,
    "num_decoder_layers": 1,
    "dropout": 0.1,
    "dim_feedforward": 2048,
    "seq_len": seq_len,
    "target_len": target_len,
    # NOTE(review): differs from the module-level total_len = 300 — confirm
    # which value is intended before unifying them.
    "total_len": 1000,
    "input_size": input_size,
    "output_size": output_size,
}
|
|
205
|
+
|
|
206
|
+
from foreblocks.third_party.flash_softpick_attn import parallel_softpick_attn
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def warmup_softpick(device, d_model=512, n_heads=4, seq_len=16):
    """Run one throw-away call of ``parallel_softpick_attn`` on ``device``.

    A random float16 tensor of shape
    ``(1, seq_len, n_heads, d_model // n_heads)`` is created and used (via
    independent clones) as query, key and value. The attention output is
    discarded; nothing is returned.
    """
    head_dim = d_model // n_heads
    query = torch.randn(
        (1, seq_len, n_heads, head_dim), device=device, dtype=torch.float16
    )
    key, value = query.clone(), query.clone()
    # Only the side effect of executing the kernel once matters here.
    parallel_softpick_attn(query, key, value, head_first=False)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
from foreblocks.blocks.nha import NHA
|
|
219
|
+
|
|
220
|
+
embedding_size = 12  # Size of the output embeddings
# 1. Create the NHA input preprocessor
# NOTE(review): in the visible part of this script nha_preprocessor is not
# passed to the model; `times_wrapper` is used as the input preprocessor.
nha_preprocessor = NHA(
    input_dim=input_size,  # Input dimension
    embedding_dim=embedding_size,  # Output embedding dimension
    hidden_dim=12,  # Hidden dimension for processing
    num_blocks=2,  # Number of hierarchical blocks
    num_levels_per_block=3,  # Number of hierarchical levels per block
    kernel_size=3,  # Kernel size for convolutions
    attention_heads=4,  # Number of attention heads
    dropout=0.1,  # Dropout probability
)
|
|
232
|
+
|
|
233
|
+
from foreblocks.blocks.famous import TimesBlock, TimesBlockPreprocessor
|
|
234
|
+
|
|
235
|
+
# Input preprocessor handed to ForecastingModel further down.
times_wrapper = TimesBlockPreprocessor(d_model=input_size)
# warmup_softpick(device=torch.device("cuda"))

# Separate positional-encoding modules for the encoder and the decoder.
# NOTE(review): the meaning of the literal 512 depends on
# LearnablePositionalEncoding's signature, which is not visible here --
# confirm it is compatible with input_size.
pos_encoder = LearnablePositionalEncoding(512)
pos_decoder = LearnablePositionalEncoding(512)

# Create transformer encoder. Every model_params key looked up here is
# present in the dict, so the .get() fallbacks are not used.
encoder = TransformerEncoder(
    input_size=model_params.get("input_processor_output_size", 1),
    nhead=model_params.get("nhead", 4),
    num_layers=model_params.get("num_encoder_layers", 1),
    dropout=model_params.get("dropout", 0.1),
    dim_feedforward=model_params.get("dim_feedforward", 2048),
    use_moe=True,
    pos_encoder=pos_encoder,
    att_type="autocor",
)

# Create transformer decoder
decoder = TransformerDecoder(
    input_size=model_params.get("input_processor_output_size", 1),
    output_size=output_size,
    nhead=model_params.get("nhead", 4),
    num_layers=model_params.get("num_decoder_layers", 1),
    dropout=model_params.get("dropout", 0.1),
    dim_feedforward=model_params.get("dim_feedforward", 2048),
    informer_like=True,
    use_moe=True,
    # att_type="prob_sparse",
    pos_encoder=pos_decoder,
)
|
|
265
|
+
|
|
266
|
+
# from foreblocks.blocks.mamba import MambaDecoder, MambaEncoder
|
|
267
|
+
|
|
268
|
+
# encoder = MambaEncoder(
|
|
269
|
+
# input_size=input_size, hidden_size=hidden_size, num_layers=num_layers
|
|
270
|
+
# )
|
|
271
|
+
# decoder = MambaDecoder(
|
|
272
|
+
# input_size=output_size,
|
|
273
|
+
# hidden_size=hidden_size,
|
|
274
|
+
# num_layers=num_layers,
|
|
275
|
+
# output_size=output_size,
|
|
276
|
+
# )
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
# attention_module = AttentionLayer(
|
|
280
|
+
# method="mha",
|
|
281
|
+
# attention_backend="flash",
|
|
282
|
+
# encoder_hidden_size=hidden_size,
|
|
283
|
+
# decoder_hidden_size=hidden_size,
|
|
284
|
+
# nhead=16,
|
|
285
|
+
# )
|
|
286
|
+
|
|
287
|
+
total_epochs = 500


# Teacher-forcing schedule passed to the forecasting model.
def scheduled_sampling_fn(epoch):
    """Return the teacher-forcing ratio for ``epoch``.

    The ratio starts at 0.95 for epoch 0 and decreases linearly by
    ``1 / total_epochs`` per epoch. It is clamped at 0.0, which it reaches
    at ``0.95 * total_epochs`` (epoch 475 for 500 epochs), i.e. before the
    nominal end of training.
    """
    decayed = 0.95 - (epoch / total_epochs)
    return decayed if decayed > 0.0 else 0.0
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
# Output head applied through ForecastingModel's `output_block` argument.
outprocessor = nn.Sequential(
    GRU(input_size=output_size, hidden_size=32, output_size=output_size),
)

print("Using timewrapper")
# NOTE(review): outnorm is only referenced by the commented-out
# `output_normalization` argument below.
outnorm = nn.LayerNorm(output_size)
model = ForecastingModel(
    encoder=encoder,
    decoder=decoder,
    target_len=target_len,
    forecasting_strategy="seq2seq",
    model_type="informer-like",
    scheduled_sampling_fn=scheduled_sampling_fn,
    output_size=output_size,
    # attention_module=attention_module,
    input_preprocessor=times_wrapper,
    output_block=outprocessor,
    # output_normalization=outnorm,
    input_skip_connection=False,
)

# model = script(model)  # Convert to TorchScript for optimization
trainer = Trainer(
    model,
    optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
    criterion=nn.MSELoss(),
)


# Ordered 80/20 split: the first 80% of samples train, the remainder validate.
train_size = int(0.8 * len(X))
X_train, Y_train = X[:train_size], y[:train_size]
X_val, Y_val = X[train_size:], y[train_size:]

X_train = torch.tensor(X_train, dtype=torch.float32)
Y_train = torch.tensor(Y_train, dtype=torch.float32)
Y_val = torch.tensor(Y_val, dtype=torch.float32)
X_val = torch.tensor(X_val, dtype=torch.float32)
# Only the training slice of the time features is converted; no time tensor
# is built for the validation split.
time_train = torch.tensor(time_feat[:train_size], dtype=torch.float32)

# create DataLoader

# Each dataset item is an (X, y, time) triple.
train_dataset = TensorDataset(X_train, Y_train, time_train)
print(time_train.shape)
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
# NOTE(review): 500 matches total_epochs, which scheduled_sampling_fn uses for
# its decay; keep the two in sync if either changes.
data = trainer.train(train_loader, epochs=500)
metrics = trainer.metrics(X_val, Y_val)


# %%
# Convert the full series so it can be plotted alongside the validation data.
X = torch.tensor(X, dtype=torch.float32)
fig = trainer.plot_prediction(X_val, Y_val, full_series=X, offset=train_size)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
# %%
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
# Copyright (c) 2024, Sanghun Cho, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import pickle
|
|
4
|
+
import math
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
|
|
9
|
+
from einops import rearrange, repeat
|
|
10
|
+
from flash_attn.layers.rotary import apply_rotary_emb
|
|
11
|
+
|
|
12
|
+
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
|
|
13
|
+
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
|
|
14
|
+
|
|
15
|
+
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import xformers.ops as xops
|
|
19
|
+
except ImportError:
|
|
20
|
+
xops = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def generate_cos_sin(seqlen, rotary_dim, device, dtype):
    """
    Build random cos/sin tables for the rotary-embedding benchmark.

    Arguments:
        seqlen: int; the tables get ``2 * seqlen`` rows
        rotary_dim: int, must be even; the tables get ``rotary_dim // 2`` columns
        device: device on which the random angles are sampled
        dtype: dtype the returned tables are cast to
    Output:
        (cos, sin): two tensors of shape (2 * seqlen, rotary_dim // 2)
    """
    assert rotary_dim % 2 == 0
    half_dim = rotary_dim // 2
    # Angles are drawn uniformly from [0, 2*pi); they are random on purpose,
    # the tables only need the right shape and dtype for timing.
    phase = torch.rand(seqlen * 2, half_dim, device=device) * 2 * math.pi
    return phase.cos().to(dtype=dtype), phase.sin().to(dtype=dtype)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def flash_rotary(q, k, v, cos, sin, causal=False):
    """
    Apply rotary embeddings to q and k, then run ``flash_attn_func``.

    Arguments:
        q, k, v: attention inputs forwarded to ``flash_attn_func``
        cos, sin: rotary tables forwarded to ``apply_rotary_emb``
        causal: bool, forwarded to ``flash_attn_func``
    Output:
        whatever ``flash_attn_func`` returns for the rotated q/k and v
    """
    # corrected by @tridao comments
    # Both tensors use the same settings; inplace=True asks the rotary
    # helper to work in place on its input.
    rotary_kwargs = dict(seqlen_offsets=0, interleaved=False, inplace=True)
    q = apply_rotary_emb(q, cos, sin, **rotary_kwargs)
    k = apply_rotary_emb(k, cos, sin, **rotary_kwargs)
    return flash_attn_func(q, k, v, causal=causal)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def attn_bias_from_alibi_slopes(
    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
):
    """
    Materialise an additive ALiBi attention bias from per-head slopes.

    Arguments:
        slopes: 2-D tensor (batch, nheads)
        seqlen_q, seqlen_k: int query / key lengths
        query_padding_mask, key_padding_mask: optional masks; only their sum
            over the last dim is used, to shift the relative positions.
            Ignored when ``causal`` is True.
        causal: bool
    Output:
        causal: ``arange(-seqlen_k + 1, 1) * slopes``, which broadcasts to
            (batch, nheads, 1, seqlen_k); ``seqlen_q`` is not used.
        otherwise: ``-slopes * |i + sk - sq - j|``, which broadcasts to
            (batch, nheads, seqlen_q, seqlen_k).
    """
    # The unpacking doubles as a check that `slopes` is exactly 2-D.
    _batch, _nheads = slopes.shape
    device = slopes.device
    slopes = rearrange(slopes, "b h -> b h 1 1")

    if causal:
        key_offsets = torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32)
        return key_offsets * slopes

    def _effective_len(padding_mask, full_len):
        # Without a mask every sequence has the full length; with one, the
        # per-batch count of set entries is used instead.
        if padding_mask is None:
            return full_len
        return rearrange(padding_mask.sum(-1), "b -> b 1 1 1")

    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    sk = _effective_len(key_padding_mask, seqlen_k)
    sq = _effective_len(query_padding_mask, seqlen_q)
    relative_pos = torch.abs(row_idx + sk - sq - col_idx)
    return -slopes * relative_pos.to(dtype=slopes.dtype)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    """
    Estimate the FLOP count of one attention call.

    Arguments:
        batch, seqlen, headdim, nheads: problem size
        causal: bool; halves the count (integer division by 2)
        mode: "fwd", "bwd" or "fwd_bwd"
    Output:
        forward count for "fwd"; 2.5x that for "bwd"; 3.5x for "fwd_bwd"
    """
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    divisor = 2 if causal else 1
    forward = 4 * batch * seqlen**2 * nheads * headdim // divisor
    if mode == "fwd":
        return forward
    if mode == "bwd":
        return 2.5 * forward
    return 3.5 * forward
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def efficiency(flop, time):
    """
    Convert a FLOP count and a duration into units of 10**12 FLOPs per time unit.

    Arguments:
        flop: number of floating point operations
        time: elapsed time; NaN marks a measurement that was skipped
    Output:
        ``flop / time / 10**12``, or 0.0 when ``time`` is NaN
    """
    if math.isnan(time):
        return 0.0
    return flop / time / 10**12
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None):
    """
    Reference attention implemented with plain PyTorch ops.

    Arguments:
        q, k, v: (batch_size, seqlen, nheads, head_dim)
        dropout_p: float
        attn_bias: optional additive bias of shape (batch_size, nheads, t, seqlen).
            The leading two dims are flattened to (batch_size * nheads) before
            being handed to ``baddbmm``, so the batch dim must equal
            batch_size (a size-1 batch dim only works when batch_size == 1).
            ``t`` may be seqlen or 1, since it is broadcast.
            When None, no bias is added.
        causal: bool; adds -10000.0 above the diagonal before the softmax
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, nheads, d = q.shape
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    if attn_bias is not None:
        scores = rearrange(attn_bias, 'b h t s -> (b h) t s')
        beta = 1.0
    else:
        scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device)
        # The buffer is uninitialised: beta must be 0 so baddbmm ignores its
        # contents instead of adding garbage to the logits.
        beta = 0.0
    scores = rearrange(torch.baddbmm(scores, q, k, beta=beta, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    # F.dropout defaults to training=True, so dropout is active whenever
    # dropout_p > 0.
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=q.dtype)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def time_fwd_bwd(func, *args, **kwargs):
    """
    Time ``func`` with ``benchmark_fwd_bwd`` and return the two mean timings.

    Arguments:
        func: callable to benchmark
        *args, **kwargs: forwarded unchanged to ``benchmark_fwd_bwd``
    Output:
        (forward_mean, backward_mean), read from ``.mean`` of element [1] of
        each of the two results returned by ``benchmark_fwd_bwd``
    """
    fwd_result, bwd_result = benchmark_fwd_bwd(func, *args, **kwargs)
    fwd_measurement = fwd_result[1]
    bwd_measurement = bwd_result[1]
    return fwd_measurement.mean, bwd_measurement.mean
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# Benchmark driver: for every (causal, headdim, batch_size, seqlen)
# configuration, time the forward and backward pass of each attention
# implementation and print the achieved throughput.
repeats = 30
device = 'cuda'
dtype = torch.float16

# (batch_size, seqlen) pairs; every pair has batch_size * seqlen == 16384.
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048  # nheads is derived as dim // headdim
dropout_p = 0.0

# Order in which results are reported; xformers is only listed when its
# import succeeded.
methods = (["fa2_alibi", "torch"]
           + (["xformers"] if xops is not None else [])
           + ["sdpa"]
           + ["fa2_baseline"]
           + ["fa2_rotary"])

# All result dicts are keyed by (config, method_name).
time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            config = (causal, headdim, batch_size, seqlen)
            nheads = dim // headdim
            q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
                                   requires_grad=True) for _ in range(3)]
            # alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
            alibi_slopes = torch.rand(1, nheads, device=device, dtype=torch.float32) * 0.3
            # Dense bias for the implementations that cannot take slopes
            # directly; repeated along the batch dim to batch_size.
            attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal).to(dtype)
            attn_bias = repeat(attn_bias, "1 ... -> b ...", b=batch_size)

            # --- FlashAttention-2 without ALiBi (baseline) ---
            f, b = time_fwd_bwd(
                flash_attn_func,
                q, k, v,
                dropout_p,
                causal=causal,
                # alibi_slopes=alibi_slopes,
                alibi_slopes=None,
                repeats=repeats,
                verbose=False
            )
            time_f[config, "fa2_baseline"] = f
            time_b[config, "fa2_baseline"] = b

            # --- FlashAttention-2 with ALiBi slopes ---
            # q/k/v are re-created as fresh leaf tensors before each method.
            q = q.detach().requires_grad_(True)
            k = k.detach().requires_grad_(True)
            v = v.detach().requires_grad_(True)
            f, b = time_fwd_bwd(
                flash_attn_func,
                q, k, v,
                dropout_p,
                causal=causal,
                alibi_slopes=rearrange(alibi_slopes, "1 h -> h"),
                # alibi_slopes=None,
                repeats=repeats,
                verbose=False
            )
            time_f[config, "fa2_alibi"] = f
            time_b[config, "fa2_alibi"] = b

            # --- Plain PyTorch reference ---
            try:
                q = q.detach().requires_grad_(True)
                k = k.detach().requires_grad_(True)
                v = v.detach().requires_grad_(True)
                f, b = time_fwd_bwd(
                    attention_pytorch,
                    q, k, v,
                    dropout_p,
                    causal=causal,
                    attn_bias=attn_bias,
                    repeats=repeats,
                    verbose=False
                )
            except Exception:  # Skip if OOM
                # `Exception` (not a bare except) so Ctrl-C / SystemExit still
                # abort the benchmark. NaN timings are reported as 0 TFLOPs/s.
                f, b = float('nan'), float('nan')
            time_f[config, "torch"] = f
            time_b[config, "torch"] = b

            # --- torch scaled_dot_product_attention ---
            # F.sdpa doesn't currently (torch 2.1) dispatch to flash-attn but just to be safe
            with torch.backends.cuda.sdp_kernel(enable_flash=False):
                # The call below is given (batch, nheads, seqlen, headdim).
                q_pt = q.detach().requires_grad_(True).transpose(1, 2)
                k_pt = k.detach().requires_grad_(True).transpose(1, 2)
                v_pt = v.detach().requires_grad_(True).transpose(1, 2)
                f, b = time_fwd_bwd(
                    F.scaled_dot_product_attention,
                    q_pt, k_pt, v_pt,
                    attn_mask=attn_bias,
                    dropout_p=dropout_p,
                    is_causal=causal,
                    repeats=repeats,
                    verbose=False
                )
                time_f[config, "sdpa"] = f
                time_b[config, "sdpa"] = b

            # --- xformers memory-efficient attention (optional) ---
            if xops is not None:
                q = q.detach().requires_grad_(True)
                k = k.detach().requires_grad_(True)
                v = v.detach().requires_grad_(True)
                if causal:
                    attn_bias_xops = xops.LowerTriangularMask().add_bias(attn_bias.expand(-1, -1, seqlen, -1).to(dtype=q.dtype))
                    # NotImplementedError: No operator found for `memory_efficient_attention_backward` with inputs:
                    #     `flshattB@v2.3.6` is not supported because:
                    #         attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
                    #     `cutlassB` is not supported because:
                    #         attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
                    # Hence the mask-with-bias object is materialised into a dense tensor.
                    attn_bias_xops = attn_bias_xops.materialize((batch_size, nheads, seqlen, seqlen), dtype=q.dtype, device=device)
                else:
                    attn_bias_xops = attn_bias.to(dtype=q.dtype)
                f, b = time_fwd_bwd(
                    xops.memory_efficient_attention,
                    q, k, v,
                    attn_bias_xops,
                    dropout_p,
                    repeats=repeats,
                    verbose=False
                )
                time_f[config, "xformers"] = f
                time_b[config, "xformers"] = b

            # --- FlashAttention-2 with rotary embeddings ---
            q = q.detach().requires_grad_(True)
            k = k.detach().requires_grad_(True)
            v = v.detach().requires_grad_(True)
            cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
            f, b = time_fwd_bwd(
                flash_rotary,
                q, k, v,
                cos, sin,
                causal,
                repeats=repeats,
                verbose=False
            )
            time_f[config, "fa2_rotary"] = f
            time_b[config, "fa2_rotary"] = b

            # --- Report: one human-readable line per method plus one CSV row ---
            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
            csv_output = ""
            csv_output += f"{causal},{headdim},{batch_size},{seqlen},"
            for method in methods:
                time_f_b[config, method] = time_f[config, method] + time_b[config, method]
                speed_f[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
                    time_f[config, method]
                )
                speed_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
                    time_b[config, method]
                )
                speed_f_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
                    time_f_b[config, method]
                )
                print(
                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
                    f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
                    f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
                )
                csv_output += f"{speed_f[config, method]:.2f},{speed_b[config, method]:.2f},{speed_f_b[config, method]:.2f},"
            print(csv_output)
|