warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +15 -7
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +22 -443
- warp/build_dll.py +384 -0
- warp/builtins.py +998 -488
- warp/codegen.py +1307 -739
- warp/config.py +5 -3
- warp/constants.py +6 -0
- warp/context.py +1291 -548
- warp/dlpack.py +31 -31
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +164 -55
- warp/native/builtin.h +150 -174
- warp/native/bvh.cpp +75 -328
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +37 -45
- warp/native/clang/clang.cpp +136 -24
- warp/native/crt.cpp +1 -76
- warp/native/crt.h +111 -104
- warp/native/cuda_crt.h +1049 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -949
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/initializer_array.h +2 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +13 -16
- warp/native/marching.cu +157 -161
- warp/native/mat.h +119 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +108 -83
- warp/native/mesh.cu +243 -6
- warp/native/mesh.h +1547 -458
- warp/native/nanovdb/NanoVDB.h +1 -1
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +157 -0
- warp/native/reduce.cu +348 -0
- warp/native/runlength_encode.cpp +62 -0
- warp/native/runlength_encode.cu +46 -0
- warp/native/scan.cu +11 -13
- warp/native/scan.h +1 -0
- warp/native/solid_angle.h +442 -0
- warp/native/sort.cpp +13 -0
- warp/native/sort.cu +9 -1
- warp/native/sparse.cpp +338 -0
- warp/native/sparse.cu +545 -0
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +30 -0
- warp/native/vec.h +126 -24
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +658 -53
- warp/native/warp.cu +660 -68
- warp/native/warp.h +112 -12
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +392 -152
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +21 -8
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +161 -19
- warp/sim/model.py +795 -291
- warp/sim/optimizer.py +2 -6
- warp/sim/render.py +65 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +1227 -0
- warp/stubs.py +665 -223
- warp/tape.py +66 -15
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +1497 -211
- warp/tests/test_array_reduce.py +150 -0
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +75 -43
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +233 -128
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +136 -108
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -74
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +180 -116
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +577 -24
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +251 -15
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +508 -2778
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +325 -34
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +190 -0
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +460 -0
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +331 -85
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -1987
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +72 -30
- warp/types.py +1744 -713
- warp/utils.py +360 -350
- warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/bin/warp-clang.exp +0 -0
- warp/bin/warp-clang.lib +0 -0
- warp/bin/warp.exp +0 -0
- warp/bin/warp.lib +0 -0
- warp/tests/test_all.py +0 -215
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.9.0.dist-info/METADATA +0 -20
- warp_lang-0.9.0.dist-info/RECORD +0 -177
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1026 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved
|
|
4
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
#
|
|
6
|
+
# Redistribution and use in source and binary forms, with or without
|
|
7
|
+
# modification, are permitted provided that the following conditions are met:
|
|
8
|
+
#
|
|
9
|
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
10
|
+
# list of conditions and the following disclaimer.
|
|
11
|
+
#
|
|
12
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
13
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
14
|
+
# and/or other materials provided with the distribution.
|
|
15
|
+
#
|
|
16
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
17
|
+
# contributors may be used to endorse or promote products derived from
|
|
18
|
+
# this software without specific prior written permission.
|
|
19
|
+
#
|
|
20
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
21
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
22
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
23
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
24
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
25
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
26
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
27
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
28
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
29
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
30
|
+
#
|
|
31
|
+
################################################################################
|
|
32
|
+
|
|
33
|
+
from ast import Num
|
|
34
|
+
from audioop import mul
|
|
35
|
+
from pipes import Template
|
|
36
|
+
import struct
|
|
37
|
+
from pycutlass.library import DataTypeTag
|
|
38
|
+
from pycutlass import *
|
|
39
|
+
import cutlass
|
|
40
|
+
from scipy.special import erf
|
|
41
|
+
|
|
42
|
+
from pycutlass.c_types import MatrixCoord_
|
|
43
|
+
from pycutlass.frontend import NumpyFrontend
|
|
44
|
+
|
|
45
|
+
from cuda import cuda
|
|
46
|
+
from cuda import cudart
|
|
47
|
+
|
|
48
|
+
# Mapping from cutlass scalar data types to the ctypes type used when
# marshalling epilogue scalars (alpha/beta) across the C ABI.
# float16 has no native ctypes equivalent, so its raw bits are carried
# in a c_uint16.
dtype2ctype = {
    cutlass.float16: ctypes.c_uint16,
    cutlass.float32: ctypes.c_float,
    cutlass.float64: ctypes.c_double,
    cutlass.int32: ctypes.c_int32,
}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
#################################################################################################
|
|
57
|
+
#
|
|
58
|
+
# Epilogue Functors
|
|
59
|
+
#
|
|
60
|
+
#################################################################################################
|
|
61
|
+
|
|
62
|
+
class EpilogueFunctorBase:
    """Base class for thread-level epilogue functors.

    Subclasses declare a C++ template name (``tag``) and a list of
    pre-rendered template arguments; :meth:`emit` renders the corresponding
    C++ type string via the project's ``SubstituteTemplate`` helper.
    """

    def __init__(self) -> None:
        pass

    def emit(self, tag, template_argument):
        """Render ``tag<arg0, arg1, ...>`` as a C++ type string.

        :param tag: fully qualified C++ template name
        :param template_argument: sequence of pre-rendered template argument
            strings
        :return: the substituted template string
        """
        template = """${tag}<${arguments}>"""
        values = {
            "tag": tag,
            # str.join replaces the original manual comma-accumulation loop;
            # it produces the identical "a, b, c" string.
            "arguments": ", ".join(template_argument),
        }
        return SubstituteTemplate(template, values)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class LinearCombination(EpilogueFunctorBase):
    """Apply a linear combination operator to an array of elements.

        D = alpha * accumulator + beta * source

    :param element_output: data type used to load and store tensors

    :param epilogue_vector_length: number of elements computed per operation.
        Usually it is 128/sizeof_bits<ElementOutput_>, but we use 64 and 32
        sometimes when there are not enough data to store

    :param element_accumulator: accumulator data type

    :param element_epilogue: data type used to compute the linear combination
    """

    tag = "cutlass::epilogue::thread::LinearCombination"

    def __init__(
            self, element_output, epilogue_vector_length,
            element_accumulator=None, element_epilogue=None) -> None:  # TODO bind ScaleType
        super().__init__()

        # Unspecified accumulator/epilogue types default to the output type.
        element_accumulator = (
            element_output if element_accumulator is None else element_accumulator)
        element_epilogue = (
            element_output if element_epilogue is None else element_epilogue)

        self.element_output = element_output
        self.element_accumulator = element_accumulator
        self.element_epilogue = element_epilogue

        self.template_arguments = [
            DataTypeTag[element_output],
            str(epilogue_vector_length),
            DataTypeTag[element_accumulator],
            DataTypeTag[element_epilogue],
        ]

        # Build the ctypes argument struct matching the C++ epilogue params.
        scalar_ctype = dtype2ctype[self.element_epilogue]
        epilogue_dtype = self.element_epilogue

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha_data", ctypes.c_longlong * 2),
                ("beta_data", ctypes.c_longlong * 2),
                ("alpha", scalar_ctype),
                ("beta", scalar_ctype),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                # .storage presumably exposes the scalar's raw bit pattern —
                # project convention from the cutlass scalar wrappers.
                self.alpha = epilogue_dtype(alpha).storage
                self.beta = epilogue_dtype(beta).storage

        self.epilogue_type = _EpilogueOutputOpParams

    def emit(self):
        """Render the C++ type string for this epilogue functor."""
        return super().emit(self.tag, self.template_arguments)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class LinearCombinationClamp(LinearCombination):
    """Applies a linear combination operator to an array of elements, then
    clamps the output before converting to the output element type.

        D = alpha * accumulator + beta * source + uniform

    :param element_output: data type used to load and store tensors

    :param epilogue_vector_length: number of elements computed per operation.
        Usually it is 128/sizeof_bits<ElementOutput_>, but we use 64 and 32
        sometimes when there are not enough data to store

    :param element_accumulator: accumulator data type

    :param element_epilogue: data type used to compute the linear combination
    """

    tag = "cutlass::epilogue::thread::LinearCombinationClamp"

    def __init__(
            self, element_output, epilogue_vector_length,
            element_accumulator=None, element_epilogue=None) -> None:
        # The base constructor resolves defaulted types and the template
        # argument list; only the argument struct differs here.
        super().__init__(
            element_output, epilogue_vector_length,
            element_accumulator, element_epilogue)

        scalar_ctype = dtype2ctype[self.element_epilogue]
        epilogue_dtype = self.element_epilogue

        # The clamp variant carries no alpha_data/beta_data vectors, so the
        # struct installed by the base constructor is replaced.
        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", scalar_ctype),
                ("beta", scalar_ctype),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = epilogue_dtype(alpha).storage
                self.beta = epilogue_dtype(beta).storage

        self.epilogue_type = _EpilogueOutputOpParams
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class FastLinearCombinationClamp(EpilogueFunctorBase):
    """Applies a linear combination operator to an array of elements, then
    clamps the output before converting to the output element type.

        D = alpha * accumulator + beta * source

    Note: The below method only when problem_size_K <= 256 for signed int8
    gemm or problem_size_K <= 128 for unsigned int8 gemm. The default
    approach is above.

    :param element_output: data type used to load and store tensors

    :param epilogue_vector_length: number of elements computed per operation.
        Usually it is 128/sizeof_bits<ElementOutput_>, but we use 64 and 32
        sometimes when there are not enough data to store
    """

    tag = "cutlass::epilogue::thread::FastLinearCombinationClamp"

    def __init__(self, element_output, epilogue_vector_length, *args) -> None:
        super().__init__()

        self.template_arguments = [
            DataTypeTag[element_output],
            str(epilogue_vector_length),
        ]

        # The fast clamp kernel fixes its accumulator/epilogue types:
        # int32 accumulation with float32 scaling.
        self.element_accumulator = cutlass.int32
        self.element_epilogue = cutlass.float32

        scalar_ctype = dtype2ctype[self.element_epilogue]
        epilogue_dtype = self.element_epilogue

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", scalar_ctype),
                ("beta", scalar_ctype),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = epilogue_dtype(alpha).storage
                self.beta = epilogue_dtype(beta).storage

        self.epilogue_type = _EpilogueOutputOpParams

    def emit(self):
        """Render the C++ type string for this epilogue functor."""
        return super().emit(self.tag, self.template_arguments)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class LinearCombinationGeneric(LinearCombination):
    """Applies a linear combination operator followed by an activation
    function to an array of elements.

        D = activation(alpha * accumulator + beta * source)

    :param activation_functor: input activation functor

    :param element_output: data type used to load and store tensors

    :param epilogue_vector_length: number of elements computed per operation.
        Usually it is 128/sizeof_bits<ElementOutput_>, but we use 64 and 32
        sometimes when there are not enough data to store

    :param element_accumulator: accumulator data type

    :param element_epilogue: data type used to compute the linear combination
    """

    tag = "cutlass::epilogue::thread::LinearCombinationGeneric"

    def __init__(
            self, activation_functor,
            element_output, epilogue_vector_length,
            element_accumulator=None, element_epilogue=None) -> None:
        super().__init__(
            element_output, epilogue_vector_length,
            element_accumulator, element_epilogue)

        # The activation's C++ name is prepended to the base template args.
        self.template_arguments = [activation_functor.emit()] + self.template_arguments

        self.activation_functor = activation_functor
        # NOTE(review): this re-assigns the raw constructor argument and thus
        # can overwrite the defaulted value resolved by the base constructor
        # with None when element_epilogue was omitted — confirm intended.
        self.element_epilogue = element_epilogue

        # The activation supplies the ctypes argument struct.
        self.epilogue_type = self.activation_functor.epilogue_output_op(self.element_epilogue)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class ActivationFunctor:
    """Base class for frequently used activation functions."""

    def __init__(self, element_compute) -> None:
        # element_compute is accepted for subclass constructors; the base
        # class keeps no state of its own.
        pass

    @staticmethod
    def numpy(x: np.ndarray):
        """Numpy reference implementation; subclasses must override."""
        raise NotImplementedError()

    def emit(self):
        """Return the C++ functor name declared by the subclass as ``tag``."""
        return self.tag

    @staticmethod
    def epilogue_output_op(element_epilogue):
        """Build the default ctypes argument struct (alpha/beta scalars and
        pointers) for an epilogue computed in ``element_epilogue``."""
        scalar_ctype = dtype2ctype[element_epilogue]

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", scalar_ctype),
                ("beta", scalar_ctype),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                # .storage presumably exposes the scalar's raw bit pattern —
                # project convention from the cutlass scalar wrappers.
                self.alpha = element_epilogue(alpha).storage
                self.beta = element_epilogue(beta).storage

        return _EpilogueOutputOpParams
|
|
298
|
+
|
|
299
|
+
# identity operator
|
|
300
|
+
# identity operator
class identity(ActivationFunctor):
    """Identity activation: passes the input through unchanged."""

    # Fix: declared @staticmethod for consistency with the base class and the
    # sibling functors (relu, tanh, ...). Without it, an instance call
    # inst.numpy(arr) would bind the instance to `x` and return the instance
    # instead of the array; class-level calls keep working either way.
    @staticmethod
    def numpy(x: np.ndarray):
        return x
|
|
303
|
+
|
|
304
|
+
# ReLu operator,
|
|
305
|
+
# ReLu operator,
class relu(ActivationFunctor):
    tag = "cutlass::epilogue::thread::ReLu"

    def __init__(self, element_compute):
        """Create the functor and its ctypes argument struct.

        :param element_compute: scalar type of the clamp threshold
        """
        super().__init__(element_compute)

        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("threshold", dtype2ctype[element_compute]),
            ]

            def __init__(self, threshold=0.) -> None:
                # raw bit pattern of the threshold scalar
                self.threshold = element_compute(threshold).storage

        self.argument_type = _Arguments

    def emit_visitor(self):
        """C++ visitor name for the epilogue-visitor code path."""
        return "cutlass::ReLUVisitor"

    @staticmethod
    def numpy(x: np.ndarray):
        """Reference implementation: elementwise max(x, 0)."""
        return np.maximum(x, 0)
|
|
324
|
+
|
|
325
|
+
# Leaky ReLu operator
|
|
326
|
+
# Leaky ReLu operator
class leaky_relu(ActivationFunctor):
    tag = "cutlass::epilogue::thread::LeakyReLU"

    def __init__(self, element_compute) -> None:
        """Create the functor and its ctypes argument struct.

        :param element_compute: scalar type of the leaky_alpha coefficient
        """
        super().__init__(element_compute)

        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("leaky_alpha", dtype2ctype[element_compute]),
            ]

            def __init__(self, leaky_alpha) -> None:
                # raw bit pattern of the slope coefficient
                self.leaky_alpha = element_compute(leaky_alpha).storage

        self.argument_type = _Arguments

    def emit_visitor(self):
        """C++ visitor name for the epilogue-visitor code path."""
        return "cutlass::LeakyReLUVisitor"

    @staticmethod
    def numpy(x: np.ndarray, leaky_alpha):
        """Reference implementation: x for x > 0, leaky_alpha * x otherwise."""
        return np.maximum(x, 0) + np.minimum(x, 0) * leaky_alpha

    def epilogue_output_op(self, element_epilogue):
        """Build the ctypes argument struct; unlike the base version it also
        carries the leaky_alpha scalar."""
        scalar_ctype = dtype2ctype[element_epilogue]

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", scalar_ctype),
                ("beta", scalar_ctype),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
                ("leaky_alpha", scalar_ctype),
            ]

            def __init__(self, alpha, beta, leaky_alpha=0.2, *args) -> None:
                self.alpha = element_epilogue(alpha).storage
                self.beta = element_epilogue(beta).storage
                self.alpha_ptr = 0
                self.beta_ptr = 0
                self.leaky_alpha = element_epilogue(leaky_alpha).storage

        return _EpilogueOutputOpParams
|
|
363
|
+
|
|
364
|
+
# Tanh operator
class tanh(ActivationFunctor):
    """Hyperbolic-tangent activation; takes no runtime parameters."""

    tag = "cutlass::epilogue::thread::Tanh"

    def __init__(self, element_compute) -> None:
        super().__init__(element_compute)

        # The visitor has no real parameters; a dummy int keeps the struct non-empty.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("tmp", ctypes.c_int)
            ]

            def __init__(self, *args) -> None:
                self.tmp = 0

        self.argument_type = _Arguments

    def emit_visitor(self):
        """Return the C++ visitor type implementing tanh on device."""
        return "cutlass::TanhVisitor"

    @staticmethod
    def numpy(x: np.ndarray):
        """Host-side NumPy reference implementation."""
        return np.tanh(x)
|
|
384
|
+
|
|
385
|
+
def sigmoid_op(x: np.ndarray):
    """Logistic sigmoid, 1 / (1 + exp(-x)), evaluated elementwise."""
    denominator = 1. + np.exp(-x)
    return 1. / denominator
|
|
387
|
+
|
|
388
|
+
# Sigmoid operator
class sigmoid(ActivationFunctor):
    """Logistic-sigmoid activation."""

    tag = "cutlass::epilogue::thread::Sigmoid"

    @staticmethod
    def numpy(x: np.ndarray):
        """Host-side NumPy reference implementation."""
        return sigmoid_op(x)
|
|
395
|
+
|
|
396
|
+
# SiLU operator
class silu(ActivationFunctor):
    """Sigmoid-weighted linear unit: x * sigmoid(x)."""

    tag = "cutlass::epilogue::thread::SiLu"

    @staticmethod
    def numpy(x: np.ndarray):
        """Host-side NumPy reference implementation."""
        return x * sigmoid_op(x)
|
|
403
|
+
|
|
404
|
+
# HardSwish operator
class hardswish(ActivationFunctor):
    """Piecewise-linear swish approximation: x * relu6(x + 3) / 6."""

    tag = "cutlass::epilogue::thread::HardSwish"

    @staticmethod
    def numpy(x: np.ndarray):
        """Host-side NumPy reference implementation."""
        clipped = np.minimum(np.maximum(x + 3., 0), 6.)
        return x * clipped / 6.
|
|
412
|
+
|
|
413
|
+
# GELU operator
class gelu(ActivationFunctor):
    """Gaussian error linear unit using the exact erf formulation."""

    tag = "cutlass::epilogue::thread::GELU"

    @staticmethod
    def numpy(x: np.ndarray):
        """Host-side NumPy reference: 0.5 * x * (1 + erf(x / sqrt(2)))."""
        return 0.5 * x * (1 + erf(x / np.sqrt(2.)))
|
|
420
|
+
|
|
421
|
+
# reduction operator
def reduction_op(tensor, direction, math, factor):
    """Host-side reference for a per-CTA-tile reduction of a (batch, m, n) tensor.

    :param tensor: numpy array of shape (batch, m, n)
    :param direction: "row" or "column" -- which axis is reduced within each tile
    :param math: reduction kind; only "Add" is implemented
    :param factor: tile extent along the reduced axis (must divide the axis length)
    :return: flattened numpy array of per-tile partial sums
    :raises NotImplementedError: for any unsupported `math` or `direction`
    """
    batch, m, n = tensor.shape
    if math != "Add":
        raise NotImplementedError
    if direction == "row":
        tiles_n = (n + factor - 1) // factor
        partial = np.sum(tensor.reshape(batch, m, tiles_n, factor), axis=-1)
        # transpose so tiles are contiguous in the flattened result
        return np.transpose(partial, axes=[0, 2, 1]).flatten()
    if direction == "column":
        tiles_m = (m + factor - 1) // factor
        return np.sum(tensor.reshape(batch, tiles_m, factor, n), axis=-2).flatten()
    raise NotImplementedError
|
|
439
|
+
|
|
440
|
+
# # GELU operator implemented using the taylor series approximation
|
|
441
|
+
# class GELU_taylor(ActivationFunctor):
|
|
442
|
+
# tag = "cutlass::epilogue::thread::GELU_taylor"
|
|
443
|
+
|
|
444
|
+
# # Computes backwards pass for GELU operator
|
|
445
|
+
# class dGELU(ActivationFunctor):
|
|
446
|
+
# tag = "cutlass::epilogue::thread::dGELU"
|
|
447
|
+
|
|
448
|
+
################################################################################
|
|
449
|
+
# Epilogue Visitor
|
|
450
|
+
################################################################################
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
class LayerNorm(EpilogueFunctorBase):
    """
    Epilogue visitor that applies the elementwise functor's linear combination
    D = alpha * accumulator + beta * source
    and additionally emits the per-row variance/mean partial reductions used by
    a subsequent layer-norm pass.

    :param elementwise_functor: epilogue functor providing the linear combination
        (its `element_epilogue` / `element_output` types are reused here)
    :param element_variance: dtype of the variance output; defaults to the
        functor's output type
    :param element_mean: dtype of the mean output; defaults to the functor's
        output type
    :param element_layer_norm_compute: dtype used for layer-norm arithmetic;
        defaults to the functor's compute type
    :param shifted_k: whether the shifted-K variance formulation is emitted
    """
    KernelTemplate = """
cutlass::epilogue::threadblock::EpilogueVisitorLayerNorm<
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    ${operation_name}_default::kThreadCount,
    ${operation_name}_default::Epilogue::OutputTileIterator,
    ${operation_name}_default::Epilogue::AccumulatorFragmentIterator::AccumulatorTile,
    ${element_compute}, // element_compute
    ${element_variance}, // element_variance
    ${element_mean}, // element_mean
    ${element_layer_norm_compute}, // element_layer_norm_compute
    ${epilogue_functor},
    ${shifted_k}>;
"""
    headers = ["gemm/gemm_universal_with_visitor.h",
               "epilogue/epilogue_visitor_with_layernorm.h"]

    def __init__(
            self, elementwise_functor,
            element_variance=None, element_mean=None,
            element_layer_norm_compute=None, shifted_k=True) -> None:  # TODO bind ScaleType
        super().__init__()

        self.elementwise_functor = elementwise_functor
        self.element_compute = elementwise_functor.element_epilogue
        self.element_output = elementwise_functor.element_output

        # BUGFIX: previously these attributes were only assigned in the
        # `is None` branch, so an explicitly supplied dtype was silently
        # dropped and the attribute was left undefined (AttributeError in
        # emit()). Caller overrides are now honored.
        self.element_variance = element_variance if element_variance is not None else self.element_output
        self.element_mean = element_mean if element_mean is not None else self.element_output
        self.element_layer_norm_compute = (
            element_layer_norm_compute if element_layer_norm_compute is not None
            else self.element_compute)
        self.shifted_k = "true" if shifted_k else "false"

        # get epilogue output op
        elementwise_params_type = self.elementwise_functor.epilogue_type

        class _EpilogueVisitorParams(ctypes.Structure):
            """Host-side mirror of the kernel's visitor parameter struct."""

            _fields_ = [
                ("element_wise", elementwise_params_type),
                ("ptr_Variance", ctypes.c_void_p),
                ("ptr_Mean_", ctypes.c_void_p),
                ("ptr_Shifted_K_", ctypes.c_void_p),
                ("extent", MatrixCoord_)
            ]

            def __init__(self, elementwise_params, variance, mean, shift_k, extent) -> None:
                self.element_wise = elementwise_params
                if isinstance(variance, np.ndarray):
                    # allocate device buffers and record their pointers
                    self.buffer_variance = NumpyFrontend.argument(variance, False)
                    self.buffer_mean = NumpyFrontend.argument(mean, False)
                    self.buffer_shift_k = NumpyFrontend.argument(shift_k, False)
                    self.ptr_Variance = int(self.buffer_variance.ptr)
                    self.ptr_Mean_ = int(self.buffer_mean.ptr)
                    self.ptr_Shifted_K_ = int(self.buffer_shift_k.ptr)
                    self.extent = MatrixCoord_(extent[0], extent[1])

                    # keep host references so sync() can copy results back
                    self.host_variance = variance
                    self.host_mean = mean
                    self.host_shift_k = shift_k

            def sync(self, stream_sync=True):
                """Copy variance/mean/shifted-K device results back into the host arrays."""
                if stream_sync:
                    err, = cudart.cudaDeviceSynchronize()
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))

                # BUGFIX: every copy's status is now checked; previously only
                # the last cuMemcpyDtoH result was inspected, so failures of
                # the first two copies went unnoticed.
                for host, dev_ptr in (
                        (self.host_variance, self.ptr_Variance),
                        (self.host_mean, self.ptr_Mean_),
                        (self.host_shift_k, self.ptr_Shifted_K_)):
                    err, = cuda.cuMemcpyDtoH(
                        host, cuda.CUdeviceptr(dev_ptr),
                        host.size * host.itemsize)
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))

        self.epilogue_type = _EpilogueVisitorParams

    def emit(self, operation):
        """Instantiate KernelTemplate for the given GEMM operation."""
        values = {
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            'operation_name': operation.procedural_name(),
            'element_compute': DataTypeTag[self.element_compute],
            'element_variance': DataTypeTag[self.element_variance],
            'element_mean': DataTypeTag[self.element_mean],
            'element_layer_norm_compute': DataTypeTag[self.element_layer_norm_compute],
            'epilogue_functor': self.elementwise_functor.emit(),
            'shifted_k': self.shifted_k
        }
        return SubstituteTemplate(self.KernelTemplate, values)
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
class AccumulatorOp:
    """Leaf visitor that forwards the GEMM accumulator fragment unchanged."""

    Template = """
using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpAccumulator<${element_accumulator}, ${elements_per_access}>;
"""
    counter = 0  # monotonically increasing suffix for unique C++ instance names

    def __init__(self, element_accumulator, elements_per_access) -> None:
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access

        self.instance_name = "AccumulatorOp%d" % AccumulatorOp.counter
        AccumulatorOp.counter += 1

        # No runtime parameters; a dummy int keeps the struct non-empty.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("tmp", ctypes.c_int)
            ]

            def __init__(self):
                self.tmp = 0

        self.argument_type = _Arguments

    def emit(self, *args):
        """Render the C++ type alias for this visitor."""
        mapping = {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "elements_per_access": str(self.elements_per_access)
        }
        return SubstituteTemplate(self.Template, mapping)
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
class LinearCombinationOp:
    """Visitor computing alpha * visitor_a + beta * visitor_b."""

    Template = """
${visitor_a}

${visitor_b}

using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpLinearCombination<
    ${element_accumulator}, ${element_compute},
    ${elements_per_access}, ${visitor_a_name}, ${visitor_b_name}>;
"""
    counter = 0  # suffix source for unique C++ instance names

    def __init__(self, element_accumulator, element_compute,
                 elements_per_access, visitor_a, visitor_b) -> None:
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access
        self.visitor_a = visitor_a
        self.visitor_b = visitor_b

        self.instance_name = "LinearCombinationOp%d" % LinearCombinationOp.counter
        LinearCombinationOp.counter += 1

        # Argument struct nests both child visitors' argument structs.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("alpha", dtype2ctype[self.element_compute]),
                ("beta", dtype2ctype[self.element_compute]),
                ("visitor_a", self.visitor_a.argument_type),
                ("visitor_b", self.visitor_b.argument_type)
            ]

            def __init__(self, alpha, beta, visitor_a_arg, visitor_b_arg) -> None:
                self.alpha = element_compute(alpha).storage
                self.beta = element_compute(beta).storage
                self.visitor_a = visitor_a_arg
                self.visitor_b = visitor_b_arg

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render this visitor and both children as C++ type aliases."""
        mapping = {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_compute": DataTypeTag[self.element_compute],
            "elements_per_access": str(self.elements_per_access),
            "visitor_a_name": self.visitor_a.instance_name,
            "visitor_b_name": self.visitor_b.instance_name,
            "visitor_a": self.visitor_a.emit(operation),
            "visitor_b": self.visitor_b.emit(operation)
        }
        return SubstituteTemplate(self.Template, mapping)
|
|
650
|
+
|
|
651
|
+
class VectorAdd:
    """Parameterless binary functor: elementwise addition."""

    def __init__(self, *args) -> None:
        # No runtime parameters; a dummy int keeps the struct non-empty.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("tmp", ctypes.c_int)
            ]

            def __init__(self, *args) -> None:
                self.tmp = 0

        self.argument_type = _Arguments

    def emit(self):
        """Return the C++ functor type."""
        return "cutlass::VectorAdd"
|
|
663
|
+
|
|
664
|
+
class VectorMult:
    """Parameterless binary functor: elementwise multiplication."""

    def __init__(self, *args) -> None:
        # No runtime parameters; a dummy int keeps the struct non-empty.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("tmp", ctypes.c_int)
            ]

            def __init__(self, *args) -> None:
                self.tmp = 0

        self.argument_type = _Arguments

    def emit(self):
        """Return the C++ functor type."""
        return "cutlass::VectorMult"
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
class BinaryOp:
    """Visitor applying a binary functor to the outputs of two child visitors."""

    Template = """
${visitor_a}

${visitor_b}

using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpBinary<
    ${element_accumulator}, ${element_compute},
    ${elements_per_access}, ${visitor_a_name}, ${visitor_b_name}, ${binary_op}>;
"""
    counter = 0  # suffix source for unique C++ instance names

    def __init__(self, element_accumulator, element_compute,
                 elements_per_access, visitor_a, visitor_b, binary_op) -> None:
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access
        self.visitor_a = visitor_a
        self.visitor_b = visitor_b
        self.binary_op = binary_op

        self.instance_name = "BinaryOp%d" % BinaryOp.counter
        BinaryOp.counter += 1

        # Argument struct nests the functor parameters and both children.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("binary_param", binary_op.argument_type),
                ("visitor_a", self.visitor_a.argument_type),
                ("visitor_b", self.visitor_b.argument_type)
            ]

            def __init__(self, binary_param, visitor_a_arg, visitor_b_arg) -> None:
                self.binary_param = binary_param
                self.visitor_a = visitor_a_arg
                self.visitor_b = visitor_b_arg

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render this visitor, both children, and the functor as C++ aliases."""
        mapping = {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_compute": DataTypeTag[self.element_compute],
            "elements_per_access": str(self.elements_per_access),
            "visitor_a_name": self.visitor_a.instance_name,
            "visitor_b_name": self.visitor_b.instance_name,
            "visitor_a": self.visitor_a.emit(operation),
            "visitor_b": self.visitor_b.emit(operation),
            "binary_op": self.binary_op.emit()
        }
        return SubstituteTemplate(self.Template, mapping)
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
class Mult:
    """Unary functor scaling its input by a constant `alpha`."""

    def __init__(self, element_compute) -> None:
        # ctypes payload carrying the single scale parameter.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("alpha", dtype2ctype[element_compute])
            ]

            def __init__(self, alpha) -> None:
                self.alpha = element_compute(alpha).storage

        self.argument_type = _Arguments

    def emit_visitor(self):
        """Return the C++ functor type."""
        return "cutlass::Mult"
|
|
742
|
+
|
|
743
|
+
class UnaryOp:
    """Visitor applying a unary functor to the output of a child visitor."""

    Template = """
${visitor}

using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpUnary<
    ${element_accumulator}, ${element_compute},
    ${elements_per_access}, ${visitor_name}, ${unary_op}>;
"""
    counter = 0  # suffix source for unique C++ instance names

    def __init__(self, element_accumulator, element_compute,
                 elements_per_access, visitor, unary_op) -> None:
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access
        self.visitor = visitor
        self.unary_op = unary_op

        self.instance_name = "UnaryOp%d" % UnaryOp.counter
        UnaryOp.counter += 1

        # Argument struct nests the functor parameters and the child visitor.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("unary_param", unary_op.argument_type),
                ("visitor_arg", self.visitor.argument_type)
            ]

            def __init__(self, unary_param, visitor_arg) -> None:
                self.unary_param = unary_param
                self.visitor_arg = visitor_arg

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render this visitor, its child, and the functor as C++ aliases."""
        mapping = {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_compute": DataTypeTag[self.element_compute],
            "elements_per_access": str(self.elements_per_access),
            "visitor_name": self.visitor.instance_name,
            "unary_op": self.unary_op.emit_visitor(),
            "visitor": self.visitor.emit(operation)
        }
        return SubstituteTemplate(self.Template, mapping)
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
class RowBroadcastOp:
    """Leaf visitor broadcasting a row vector across the output tile."""

    Template = """
using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpRowBroadcast<
    ${element_accumulator}, ${element_fragment}, ${input_tile_iterator}>;
"""
    counter = 0  # suffix source for unique C++ instance names

    def __init__(self, element_accumulator, element_fragment) -> None:
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

        self.instance_name = "RowBroadcastOp%d" % RowBroadcastOp.counter
        RowBroadcastOp.counter += 1

        # Runtime arguments: device pointer of the vector and its batch stride.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("broadcast_ptr", ctypes.c_void_p),
                ("batch_stride", ctypes.c_longlong)
            ]

            def __init__(self, broadcast_ptr, batch_stride=0):
                self.broadcast_ptr = int(broadcast_ptr)
                self.batch_stride = batch_stride

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render the C++ type alias for this visitor."""
        mapping = {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_fragment": DataTypeTag[self.element_fragment],
            "input_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator"
        }
        return SubstituteTemplate(self.Template, mapping)
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
class ColumnBroadcastOp:
    """Leaf visitor broadcasting a column vector across the output tile."""

    Template = """
using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpColumnBroadcast<
    ${element_accumulator}, ${element_fragment}, ${input_tile_iterator}>;
"""
    counter = 0  # suffix source for unique C++ instance names

    def __init__(self, element_accumulator, element_fragment) -> None:
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

        self.instance_name = "ColumnBroadcastOp%d" % ColumnBroadcastOp.counter
        ColumnBroadcastOp.counter += 1

        # Runtime arguments: device pointer of the vector and its batch stride.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("broadcast_ptr", ctypes.c_void_p),
                ("batch_stride", ctypes.c_longlong)
            ]

            def __init__(self, broadcast_ptr, batch_stride=0):
                self.broadcast_ptr = int(broadcast_ptr)
                self.batch_stride = batch_stride

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render the C++ type alias for this visitor."""
        mapping = {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_fragment": DataTypeTag[self.element_fragment],
            "input_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator"
        }
        return SubstituteTemplate(self.Template, mapping)
|
|
855
|
+
|
|
856
|
+
|
|
857
|
+
class TensorInputOp:
    """Leaf visitor streaming an auxiliary input tensor through the epilogue."""

    Template = """
using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpTensorInput<
    ${element_accumulator}, ${input_tile_iterator}>;
"""
    counter = 0  # suffix source for unique C++ instance names

    def __init__(self, element_accumulator) -> None:
        self.element_accumulator = element_accumulator

        self.instance_name = "TensorInputOp%d" % TensorInputOp.counter
        TensorInputOp.counter += 1

        # Runtime arguments: device pointer, leading dimension, batch stride.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("input_ptr", ctypes.c_void_p),
                ("ldt", ctypes.c_int),
                ("batch_stride", ctypes.c_longlong)
            ]

            def __init__(self, input_ptr, ldt, batch_stride=0) -> None:
                self.input_ptr = int(input_ptr)
                self.ldt = ldt
                self.batch_stride = batch_stride

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render the C++ type alias for this visitor."""
        mapping = {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "input_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator"
        }
        return SubstituteTemplate(self.Template, mapping)
|
|
889
|
+
|
|
890
|
+
class TensorOutputOp:
    """Visitor storing a child visitor's result to an auxiliary output tensor."""

    Template = """
${visitor}

using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpTensorOutput<
    ${element_accumulator}, ${output_tile_iterator}, ${visitor_name}>;
"""
    counter = 0  # suffix source for unique C++ instance names

    def __init__(self, element_accumulator, visitor) -> None:
        self.element_accumulator = element_accumulator
        self.visitor = visitor

        self.instance_name = "TensorOutputOp%d" % TensorOutputOp.counter
        TensorOutputOp.counter += 1

        # Runtime arguments: destination pointer, leading dimension, batch
        # stride, plus the nested child-visitor arguments.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("output_ptr", ctypes.c_void_p),
                ("ldt", ctypes.c_int),
                ("batch_stride", ctypes.c_longlong),
                ("visitor_arg", self.visitor.argument_type)
            ]

            def __init__(self, output_ptr, ldt, visitor_arg, batch_stride=0) -> None:
                self.output_ptr = int(output_ptr)
                self.ldt = int(ldt)
                self.visitor_arg = visitor_arg
                self.batch_stride = batch_stride

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render this visitor and its child as C++ type aliases."""
        mapping = {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "output_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator",
            "visitor_name": self.visitor.instance_name,
            "visitor": self.visitor.emit(operation)
        }
        return SubstituteTemplate(self.Template, mapping)
|
|
929
|
+
|
|
930
|
+
|
|
931
|
+
class ColumnReductionOp:
    """Visitor reducing a child visitor's result along columns of each CTA tile."""

    Template = """
${visitor}

using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpColumnReduction<
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    ${element_accumulator}, ${element_reduction}, ${element_reduction_accumulator},
    ${output_tile_iterator}, ${visitor_name}>;
"""
    counter = 0  # suffix source for unique C++ instance names

    def __init__(self, element_accumulator, element_reduction,
                 element_reduction_accumulator, visitor) -> None:
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator
        self.visitor = visitor

        self.instance_name = "ColumnReductionOp%d" % ColumnReductionOp.counter
        ColumnReductionOp.counter += 1

        # Runtime arguments: reduction output pointer, batch stride, and the
        # nested child-visitor arguments.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("reduction_ptr", ctypes.c_void_p),
                ("batch_stride", ctypes.c_longlong),
                ("visitor_arg", self.visitor.argument_type)
            ]

            def __init__(self, reduction_ptr, visitor_arg, batch_stride=0) -> None:
                self.reduction_ptr = reduction_ptr
                self.batch_stride = batch_stride
                self.visitor_arg = visitor_arg

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render this visitor and its child as C++ type aliases."""
        mapping = {
            "instance_name": self.instance_name,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_reduction": DataTypeTag[self.element_reduction],
            "element_reduction_accumulator": DataTypeTag[self.element_reduction_accumulator],
            "output_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator",
            "visitor_name": self.visitor.instance_name,
            "visitor": self.visitor.emit(operation)
        }
        return SubstituteTemplate(self.Template, mapping)
|
|
978
|
+
|
|
979
|
+
|
|
980
|
+
class RowReductionOp:
    """Visitor reducing a child visitor's result along rows of each CTA tile."""

    Template = """
${visitor}

using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpRowReduction<
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    ${element_accumulator}, ${element_reduction}, ${element_reduction_accumulator},
    ${output_tile_iterator}, ${visitor_name}>;
"""
    counter = 0  # suffix source for unique C++ instance names

    def __init__(self, element_accumulator, element_reduction,
                 element_reduction_accumulator, visitor) -> None:
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator
        self.visitor = visitor

        self.instance_name = "RowReductionOp%d" % RowReductionOp.counter
        RowReductionOp.counter += 1

        # Runtime arguments: reduction output pointer, batch stride, and the
        # nested child-visitor arguments.
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("reduction_ptr", ctypes.c_void_p),
                ("batch_stride", ctypes.c_longlong),
                ("visitor_arg", self.visitor.argument_type)
            ]

            def __init__(self, reduction_ptr, visitor_arg, batch_stride=0) -> None:
                self.reduction_ptr = reduction_ptr
                self.visitor_arg = visitor_arg
                self.batch_stride = batch_stride

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render this visitor and its child as C++ type aliases."""
        mapping = {
            "instance_name": self.instance_name,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_reduction": DataTypeTag[self.element_reduction],
            "element_reduction_accumulator": DataTypeTag[self.element_reduction_accumulator],
            "output_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator",
            "visitor_name": self.visitor.instance_name,
            "visitor": self.visitor.emit(operation)
        }
        return SubstituteTemplate(self.Template, mapping)
|