warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,557 @@
|
|
|
1
|
+
#################################################################################################
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
4
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
#
|
|
6
|
+
# Redistribution and use in source and binary forms, with or without
|
|
7
|
+
# modification, are permitted provided that the following conditions are met:
|
|
8
|
+
#
|
|
9
|
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
10
|
+
# list of conditions and the following disclaimer.
|
|
11
|
+
#
|
|
12
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
13
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
14
|
+
# and/or other materials provided with the distribution.
|
|
15
|
+
#
|
|
16
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
17
|
+
# contributors may be used to endorse or promote products derived from
|
|
18
|
+
# this software without specific prior written permission.
|
|
19
|
+
#
|
|
20
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
21
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
22
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
23
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
24
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
25
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
26
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
27
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
28
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
29
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
30
|
+
#
|
|
31
|
+
#################################################################################################
|
|
32
|
+
|
|
33
|
+
from time import sleep
|
|
34
|
+
import pycutlass
|
|
35
|
+
from pycutlass import *
|
|
36
|
+
import cutlass
|
|
37
|
+
from cuda import cudart
|
|
38
|
+
from cuda import cuda
|
|
39
|
+
from bfloat16 import bfloat16
|
|
40
|
+
from .profiler import GpuTimer
|
|
41
|
+
import subprocess
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def transpose(layout):
    """Return the transposed counterpart of a CUTLASS layout class.

    RowMajor <-> ColumnMajor and RowMajorInterleaved32 <->
    ColumnMajorInterleaved32.

    :param layout: one of the four CUTLASS layout classes listed above
    :raises ValueError: if ``layout`` is not one of the supported layouts
    """
    if layout == cutlass.RowMajor:
        return cutlass.ColumnMajor
    elif layout == cutlass.ColumnMajor:
        return cutlass.RowMajor
    elif layout == cutlass.ColumnMajorInterleaved32:
        return cutlass.RowMajorInterleaved32
    elif layout == cutlass.RowMajorInterleaved32:
        return cutlass.ColumnMajorInterleaved32
    else:
        # Previously fell through and returned None, which only surfaced
        # later as a confusing downstream error; fail fast instead, matching
        # the error handling of getTensorRef/getTensorView.
        raise ValueError("unsupported layout")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def getTensorRef(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, operand: str, layout: cutlass.layout):
    """Wrap a host numpy buffer in the matching ``cutlass.TensorRef*`` object.

    The concrete TensorRef class is chosen from the tensor dtype and the
    layout tag, e.g. ``cutlass.TensorRefF32RowMajor``.

    :param tensor: flat host buffer holding the operand data
    :param problem_size: GEMM problem dimensions (m, n, k)
    :param operand: one of ``"a"``, ``"b"``, ``"c"`` or ``"d"``
    :param layout: CUTLASS layout class of the operand
    :raises ValueError: for an unknown operand, layout, or dtype
    """
    # Raw device/host pointer of the numpy buffer.
    ptr = tensor.__array_interface__['data'][0]
    if operand == "a":
        tensor_coord = problem_size.mk()
    elif operand == "b":
        tensor_coord = problem_size.kn()
    elif operand in ["c", "d"]:
        tensor_coord = problem_size.mn()
    else:
        # Fixed typo: error message previously read "unknonw operand".
        raise ValueError("unknown operand: " + operand)

    if layout == cutlass.RowMajor:
        layout = cutlass.RowMajor.packed(tensor_coord)
        layout_tag = "RowMajor"
    elif layout == cutlass.ColumnMajor:
        layout = cutlass.ColumnMajor.packed(tensor_coord)
        layout_tag = "ColumnMajor"
    elif layout == cutlass.ColumnMajorInterleaved32:
        layout = cutlass.ColumnMajorInterleaved32.packed(tensor_coord)
        layout_tag = "ColumnMajorInterleaved32"
    elif layout == cutlass.RowMajorInterleaved32:
        layout = cutlass.RowMajorInterleaved32.packed(tensor_coord)
        layout_tag = "RowMajorInterleaved32"
    else:
        raise ValueError("unsupported layout")

    if tensor.dtype == np.float32:
        ref_name = "TensorRefF32" + layout_tag
    elif tensor.dtype == np.float64:
        ref_name = "TensorRefF64" + layout_tag
    elif tensor.dtype == np.float16:
        ref_name = "TensorRefF16" + layout_tag
    elif tensor.dtype == bfloat16:
        ref_name = "TensorRefBF16" + layout_tag
    elif tensor.dtype == np.int8:
        ref_name = "TensorRefS8" + layout_tag
    elif tensor.dtype == np.int32:
        ref_name = "TensorRefS32" + layout_tag
    else:
        # BUG FIX: the previous message did
        # ``ShortDataTypeNames[tensor.dtype]`` — that table is keyed by
        # cutlass element types, not numpy dtypes, so it raised a KeyError
        # that masked the intended ValueError.
        raise ValueError("unsupported datatype %s" % str(tensor.dtype))

    return getattr(cutlass, ref_name)(ptr, layout)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def getTensorView(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, operand: str, layout: cutlass.layout):
    """Wrap a host numpy buffer in the matching ``cutlass.TensorView*`` object.

    Builds a TensorRef first (via :func:`getTensorRef`) and then pairs it with
    the operand's extent to produce a view, e.g. ``cutlass.TensorViewF32RowMajor``.

    :param tensor: flat host buffer holding the operand data
    :param problem_size: GEMM problem dimensions (m, n, k)
    :param operand: one of ``"a"``, ``"b"``, ``"c"`` or ``"d"``
    :param layout: CUTLASS layout class of the operand
        (annotation fixed: was incorrectly declared as ``str``)
    :raises ValueError: for an unknown operand, layout, or dtype
    """
    tensor_ref = getTensorRef(tensor, problem_size, operand, layout)

    if operand == "a":
        tensor_coord = problem_size.mk()
    elif operand == "b":
        tensor_coord = problem_size.kn()
    elif operand in ["c", "d"]:
        tensor_coord = problem_size.mn()
    else:
        # Fixed typo: error message previously read "unknonw operand".
        raise ValueError("unknown operand: " + operand)

    if layout == cutlass.RowMajor:
        layout_tag = "RowMajor"
    elif layout == cutlass.ColumnMajor:
        layout_tag = "ColumnMajor"
    elif layout == cutlass.ColumnMajorInterleaved32:
        layout_tag = "ColumnMajorInterleaved32"
    elif layout == cutlass.RowMajorInterleaved32:
        layout_tag = "RowMajorInterleaved32"
    else:
        raise ValueError("unsupported layout")

    if tensor.dtype == np.float32:
        ref_name = "TensorViewF32" + layout_tag
    elif tensor.dtype == np.float64:
        ref_name = "TensorViewF64" + layout_tag
    elif tensor.dtype == np.float16:
        ref_name = "TensorViewF16" + layout_tag
    elif tensor.dtype == bfloat16:
        ref_name = "TensorViewBF16" + layout_tag
    elif tensor.dtype == np.int32:
        ref_name = "TensorViewS32" + layout_tag
    elif tensor.dtype == np.int8:
        ref_name = "TensorViewS8" + layout_tag
    else:
        raise ValueError("unsupported datatype")

    return getattr(cutlass, ref_name)(tensor_ref, tensor_coord)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class GemmUniversalLauncher:
|
|
140
|
+
    def __init__(self, operation: 'GemmOperationUniversal', seed: int = 2080, interleaved=False,
                 verification=True, profiling=False, warmup_iterations=500, iterations=500, **kwargs) -> None:
        """Set up a launcher that runs, verifies and optionally profiles a GEMM.

        :param operation: the universal GEMM operation under test
        :param seed: numpy RNG seed used when generating input tensors
        :param interleaved: whether the operand layout is interleaved
        :param verification: compare device output against a host reference
        :param profiling: measure kernel runtime with a GPU timer
        :param warmup_iterations: profiler warm-up launches
        :param iterations: profiled launches
        :param kwargs: optional ``sleep`` — seconds to pause between runs
        """
        # create the reduction kernel
        self.reduction_operation: ReductionOperation = ReductionOperation(
            shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment),
            C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator,
            element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor,
            count=operation.C.alignment
        )

        self.math_operation = operation.tile_description.math_instruction.math_operation

        #: verify the output result
        self.verification = verification
        #: profile the kernel's runtime
        self.profiling = profiling

        self.timer = GpuTimer()

        self.warmup_iterations = warmup_iterations
        self.iterations = iterations

        # Optional pause between runs (e.g. to avoid thermal throttling).
        if "sleep" in kwargs.keys():
            self.sleep_time = kwargs["sleep"]
        else:
            self.sleep_time = 0

        #
        # Compile the operator
        #

        # JIT-compile both the GEMM and its split-K reduction companion.
        pycutlass.compiler.add_module([operation, self.reduction_operation])

        self.operation = operation

        # numpy dtypes matching the cutlass element types of each operand.
        # D deliberately mirrors C (same element type in this testbed).
        self.dtype_A = GemmUniversalLauncher.numpy_type(operation.A.element)
        self.dtype_B = GemmUniversalLauncher.numpy_type(operation.B.element)
        self.dtype_C = GemmUniversalLauncher.numpy_type(operation.C.element)
        self.dtype_D = GemmUniversalLauncher.numpy_type(operation.C.element)

        accumulator_size = DataTypeSize[operation.tile_description.math_instruction.element_accumulator]
        element_size = DataTypeSize[operation.A.element]

        # Random-input value range, narrowed for low-precision element types
        # so accumulation stays exact enough for bitwise verification.
        if element_size == 1:
            self.scope_max = 1
            self.scope_min = 0
        elif element_size <= 8:
            self.scope_max = 1
            self.scope_min = -1
        elif element_size == 16:
            self.scope_max = 4
            self.scope_min = -4
        else:
            self.scope_max = 8
            self.scope_min = -8

        #: seed
        self.seed: int = seed

        #: whether the layout is interleaved
        self.interleaved = interleaved

        #: compute type
        self.compute_type = operation.epilogue_functor.element_epilogue
        self.accumulator_type = operation.tile_description.math_instruction.element_accumulator
|
|
205
|
+
|
|
206
|
+
def print_problem_size(self, p, mode, batch_count):
|
|
207
|
+
if mode == cutlass.gemm.Mode.Gemm:
|
|
208
|
+
mode = "Gemm"
|
|
209
|
+
elif mode == cutlass.gemm.Mode.GemmSplitKParallel:
|
|
210
|
+
mode = "GemmSplitKParalel"
|
|
211
|
+
problem_size = "problem: %d, %d, %d\n batch_count: %d\n mode: %s" % (
|
|
212
|
+
p.m(), p.n(), p.k(), batch_count, mode)
|
|
213
|
+
print(problem_size)
|
|
214
|
+
|
|
215
|
+
@staticmethod
|
|
216
|
+
def numpy_type(type):
|
|
217
|
+
if type == cutlass.float64:
|
|
218
|
+
return np.float64
|
|
219
|
+
elif type == cutlass.float32:
|
|
220
|
+
return np.float32
|
|
221
|
+
elif type == cutlass.float16:
|
|
222
|
+
return np.float16
|
|
223
|
+
elif type == cutlass.bfloat16:
|
|
224
|
+
return bfloat16
|
|
225
|
+
elif type == cutlass.int32:
|
|
226
|
+
return np.int32
|
|
227
|
+
elif type == cutlass.int8:
|
|
228
|
+
return np.int8
|
|
229
|
+
else:
|
|
230
|
+
raise ValueError("unsupported type: %s" % ShortDataTypeNames[type])
|
|
231
|
+
|
|
232
|
+
def uniform_init(self, size, dtype):
|
|
233
|
+
if dtype in [np.float32, np.float16, bfloat16, np.float64]:
|
|
234
|
+
return np.ceil(
|
|
235
|
+
np.random.uniform(
|
|
236
|
+
low=self.scope_min - 0.5, high=self.scope_max - 0.5,
|
|
237
|
+
size=size).astype(dtype)
|
|
238
|
+
)
|
|
239
|
+
else:
|
|
240
|
+
return np.random.uniform(
|
|
241
|
+
low=self.scope_min - 1, high=self.scope_max + 1,
|
|
242
|
+
size=size).astype(dtype)
|
|
243
|
+
|
|
244
|
+
def reorder_tensor_B(self, tensor_B, problem_size):
|
|
245
|
+
reordered_tensor_B = np.empty_like(tensor_B)
|
|
246
|
+
tensor_ref_B = getTensorRef(
|
|
247
|
+
tensor_B, problem_size, "b", self.operation.B.layout)
|
|
248
|
+
reordered_tensor_ref_B = getTensorRef(
|
|
249
|
+
reordered_tensor_B, problem_size, "b", self.operation.B.layout)
|
|
250
|
+
cutlass.gemm.host.reorder_column(
|
|
251
|
+
tensor_ref_B, reordered_tensor_ref_B, problem_size)
|
|
252
|
+
return reordered_tensor_B
|
|
253
|
+
|
|
254
|
+
    def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
        """Compute D = alpha * A @ B + beta * C on the host as a reference.

        Uses the cutlass host-side GEMM (saturating variant when the math
        operation requires it) and returns the reference D tensor.

        :param problem_size: GEMM problem dimensions (m, n, k)
        :param tensor_A, tensor_B, tensor_C: flat host operand buffers
        :param alpha, beta: epilogue scalars
        :returns: numpy buffer holding the host-computed D
        """
        # TODO
        tensor_D_ref = np.ones_like(tensor_C)
        # Convert the Python scalars first to the numpy compute dtype, then
        # wrap them in the cutlass compute/accumulator types; the order of
        # these two conversion steps is deliberate.
        alpha = self.numpy_type(self.compute_type)(alpha)
        beta = self.numpy_type(self.compute_type)(beta)
        init_acc = 0

        alpha = self.compute_type(alpha).value()
        beta = self.compute_type(beta).value()
        init_acc = self.accumulator_type(init_acc).value()

        # ``switched`` presumably means A/B were swapped when the kernel was
        # emitted, so the reference views each operand transposed —
        # NOTE(review): confirm against GemmOperationUniversal.
        if self.operation.switched:
            tensor_ref_A = getTensorRef(
                tensor_A, problem_size, "a", transpose(self.operation.B.layout))
            tensor_ref_B = getTensorRef(
                tensor_B, problem_size, "b", transpose(self.operation.A.layout))
            tensor_ref_C = getTensorRef(
                tensor_C, problem_size, "c", transpose(self.operation.C.layout))
            tensor_ref_D_ref = getTensorRef(
                tensor_D_ref, problem_size, "d", transpose(self.operation.C.layout))
        else:
            tensor_ref_A = getTensorRef(
                tensor_A, problem_size, "a", self.operation.A.layout)
            tensor_ref_B = getTensorRef(
                tensor_B, problem_size, "b", self.operation.B.layout)
            tensor_ref_C = getTensorRef(
                tensor_C, problem_size, "c", self.operation.C.layout)
            tensor_ref_D_ref = getTensorRef(
                tensor_D_ref, problem_size, "d", self.operation.C.layout)

        # Saturating integer math needs the saturating host reference.
        if self.math_operation in [MathOperation.multiply_add_saturate]:
            cutlass.test.gemm.host.gemm_saturate(
                problem_size, alpha, tensor_ref_A, tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc)
        else:
            cutlass.test.gemm.host.gemm(problem_size, alpha, tensor_ref_A,
                                        tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc)

        return tensor_D_ref
|
|
292
|
+
|
|
293
|
+
def equal(self, tensor_D, tensor_D_ref, problem_size):
|
|
294
|
+
|
|
295
|
+
tensor_view_D = getTensorView(
|
|
296
|
+
tensor_D, problem_size, "d", self.operation.C.layout)
|
|
297
|
+
tensor_view_D_ref = getTensorView(
|
|
298
|
+
tensor_D_ref, problem_size, "d", self.operation.C.layout)
|
|
299
|
+
|
|
300
|
+
return cutlass.test.gemm.host.equals(tensor_view_D, tensor_view_D_ref)
|
|
301
|
+
|
|
302
|
+
def bytes(self, problem_size, batch_count=1, alpha=1.0, beta=0.0):
|
|
303
|
+
m = problem_size.m()
|
|
304
|
+
n = problem_size.n()
|
|
305
|
+
k = problem_size.k()
|
|
306
|
+
|
|
307
|
+
bytes = \
|
|
308
|
+
(DataTypeSize[self.operation.A.element] * m // 8) * k + \
|
|
309
|
+
(DataTypeSize[self.operation.B.element] * n // 8) * k + \
|
|
310
|
+
(DataTypeSize[self.operation.C.element] * m // 8) * n
|
|
311
|
+
|
|
312
|
+
if beta != 0:
|
|
313
|
+
bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n
|
|
314
|
+
|
|
315
|
+
bytes *= batch_count
|
|
316
|
+
|
|
317
|
+
return bytes
|
|
318
|
+
|
|
319
|
+
def flops(self, problem_size, batch_count=1):
|
|
320
|
+
m = problem_size.m()
|
|
321
|
+
n = problem_size.n()
|
|
322
|
+
k = problem_size.k()
|
|
323
|
+
|
|
324
|
+
flops_ = (m * n * k + m * n) * 2 * batch_count
|
|
325
|
+
|
|
326
|
+
# TODO: complex
|
|
327
|
+
return flops_
|
|
328
|
+
|
|
329
|
+
def run_cutlass_profiler(self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0):
|
|
330
|
+
|
|
331
|
+
cutlass_path = os.getenv('CUTLASS_PATH')
|
|
332
|
+
assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined."
|
|
333
|
+
|
|
334
|
+
values = {
|
|
335
|
+
"profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler",
|
|
336
|
+
"kernel_name": self.operation.procedural_name(),
|
|
337
|
+
"verification_providers": "device",
|
|
338
|
+
"provider": "cutlass",
|
|
339
|
+
"m": str(problem_size.m()),
|
|
340
|
+
"n": str(problem_size.n()),
|
|
341
|
+
"k": str(problem_size.k()),
|
|
342
|
+
'split_k_slices': str(batch_count),
|
|
343
|
+
'alpha': str(alpha),
|
|
344
|
+
'beta': str(beta),
|
|
345
|
+
'warmup': str(self.warmup_iterations),
|
|
346
|
+
'profile': str(self.iterations)
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
cmd_template = \
|
|
350
|
+
"${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}" \
|
|
351
|
+
" --providers=${provider} --m=${m} --n=${n} --k=${k}"
|
|
352
|
+
|
|
353
|
+
cmd = SubstituteTemplate(cmd_template, values)
|
|
354
|
+
result = subprocess.getoutput(cmd)
|
|
355
|
+
|
|
356
|
+
m = re.search(r"Runtime:\s+(?P<runtime>\d+.\d+)", result)
|
|
357
|
+
runtime = float(m.group('runtime'))
|
|
358
|
+
|
|
359
|
+
m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
|
|
360
|
+
bytes = int(m.group('bytes'))
|
|
361
|
+
|
|
362
|
+
m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
|
|
363
|
+
flops = int(m.group('flops'))
|
|
364
|
+
|
|
365
|
+
# check if the problem size matches
|
|
366
|
+
assert bytes == self.bytes(problem_size, alpha, beta)
|
|
367
|
+
assert flops == self.flops(problem_size)
|
|
368
|
+
|
|
369
|
+
return runtime
|
|
370
|
+
|
|
371
|
+
def run(self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0):
    """Run one GEMM (optionally split-K parallel), verify against a host
    reference, and optionally profile it.

    :param mode: ``cutlass.gemm.Mode`` value; ``GemmSplitKParallel`` adds a
        reduction pass over the split-K partial results.
    :param problem_size: GEMM coordinate exposing ``m()``/``n()``/``k()``.
    :param batch_count: number of split-K slices / reduction partitions.
    :param alpha: epilogue scale for ``A @ B``.
    :param beta: epilogue scale for ``C``.
    :return: measured runtime when ``self.profiling`` is set, otherwise the
        boolean verification result.
    :raises AssertionError: if pool memory is not empty before or after the run.
    """
    # Guard against leaks from a previous invocation before allocating anything.
    assert get_allocated_size(
    ) == 0, "%d byte of pool memory is not released in previous run" % get_allocated_size()

    # Fixed seed so A/B/C contents are reproducible across runs.
    np.random.seed(self.seed)

    # Flat 1-D host buffers sized m*k, n*k, m*n; layout handling is presumably
    # done by the operation itself — confirm against GemmArguments.
    tensor_A = self.uniform_init(
        size=(problem_size.m() * problem_size.k(),), dtype=self.dtype_A)
    tensor_B = self.uniform_init(
        size=(problem_size.n() * problem_size.k(),), dtype=self.dtype_B)
    tensor_C = self.uniform_init(
        size=(problem_size.m() * problem_size.n(),), dtype=self.dtype_C)
    tensor_D = np.zeros(
        shape=(problem_size.m() * problem_size.n(),), dtype=self.dtype_D)

    #
    # Launch kernel
    #

    arguments = GemmArguments(
        operation=self.operation, problem_size=problem_size,
        A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
        output_op=self.operation.epilogue_type(alpha, beta),
        gemm_mode=mode, split_k_slices=batch_count
    )

    # Split-K parallel needs a follow-up reduction: partial results land in
    # arguments.ptr_D (the workspace) and are reduced into tensor_D.
    if mode == cutlass.gemm.Mode.GemmSplitKParallel:
        reduction_arguments = ReductionArguments(
            self.reduction_operation, problem_size=[
                problem_size.m(), problem_size.n()],
            partitions=batch_count,
            workspace=arguments.ptr_D,
            destination=tensor_D,
            source=tensor_C,
            output_op=self.reduction_operation.epilogue_type(alpha, beta)
        )

    self.operation.run(arguments)

    if mode == cutlass.gemm.Mode.GemmSplitKParallel:
        self.reduction_operation.run(reduction_arguments)

    passed = True

    if self.verification:
        # Sync whichever arguments object owns the final D buffer so tensor_D
        # holds device results before comparison.
        if mode == cutlass.gemm.Mode.GemmSplitKParallel:
            reduction_arguments.sync()
        else:
            arguments.sync()
        tensor_D_ref = self.host_reference(
            problem_size, tensor_A, tensor_B, tensor_C, alpha, beta)
        passed = self.equal(tensor_D, tensor_D_ref, problem_size)

        # On mismatch, print the failing configuration but keep going so the
        # caller still receives `passed = False` (and profiling can proceed).
        try:
            assert passed
        except AssertionError:
            self.print_problem_size(problem_size, mode, batch_count)

    if self.profiling:
        # Cool-down pause, then untimed warmup iterations.
        sleep(self.sleep_time)
        for _ in range(self.warmup_iterations):
            self.operation.run(arguments)
            if mode == cutlass.gemm.Mode.GemmSplitKParallel:
                self.reduction_operation.run(reduction_arguments)

        # Timed section: average over self.iterations launches.
        self.timer.start()
        for _ in range(self.iterations):
            self.operation.run(arguments)
            if mode == cutlass.gemm.Mode.GemmSplitKParallel:
                self.reduction_operation.run(reduction_arguments)
        self.timer.stop_and_wait()

        runtime = self.timer.duration(self.iterations)

    # free memory and clear buffers
    del arguments
    if mode == cutlass.gemm.Mode.GemmSplitKParallel:
        del reduction_arguments

    # Deleting the argument objects must have returned all pool memory.
    assert get_allocated_size(
    ) == 0, "%d byte of pool memory is not released after current run" % get_allocated_size()

    if self.profiling:
        return runtime
    return passed
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def test_all_gemm(operation: 'GemmOperationUniversal', testcase="universal"):
    """Sweep a GEMM operation over a grid of problem sizes and verify each run.

    Derives alignment constraints from the operation's element types / opcode
    class, builds a testcase-specific grid of (mode, m, n, k, batch, alpha,
    beta), and launches every combination through ``GemmUniversalLauncher``.

    :param operation: the universal GEMM operation under test.
    :param testcase: one of ``"universal"``, ``"interleaved"``, ``"multistage"``.
    :return: ``True`` if every configuration verified; ``False`` on the first
        failure (remaining configurations are skipped).
    :raises ValueError: for an interleaved testcase with a non-interleaved layout.
    :raises RuntimeError: on a CUDA error during device synchronization.
    """
    passed = True

    minimum_operand_element_size = min(
        DataTypeSize[operation.A.element], DataTypeSize[operation.B.element])
    opcode_class = operation.tile_description.math_instruction.opcode_class

    # SIMT kernels have no vectorized-access requirement; tensor-op kernels
    # need 128-bit aligned accesses on the smaller operand element.
    if opcode_class == cutlass.OpClass.Simt:
        alignment = 1
    else:
        alignment = 128 // minimum_operand_element_size

    # int8_t gemm alignment constraints
    if opcode_class == cutlass.OpClass.Simt and operation.A.element == cutlass.int8 and operation.A.layout == cutlass.ColumnMajor:
        alignment_m = 4
    else:
        alignment_m = alignment

    # NOTE(review): this branch constrains B's element but tests A's layout;
    # by symmetry with the alignment_m rule it looks like it should read
    # `operation.B.layout == cutlass.RowMajor` — confirm before changing.
    if opcode_class == cutlass.OpClass.Simt and operation.B.element == cutlass.int8 and operation.A.layout == cutlass.RowMajor:
        alignment_n = 4
    else:
        alignment_n = alignment

    # int8 SIMT with a K-major operand also constrains the K extent.
    if opcode_class == cutlass.OpClass.Simt and operation.A.element == cutlass.int8 \
            and operation.B.element == cutlass.int8 \
            and (operation.A.layout == cutlass.RowMajor or operation.B.layout == cutlass.ColumnMajor):
        alignment_k = 4
    else:
        alignment_k = alignment

    threadblock_k = operation.tile_description.threadblock_shape[2]

    if testcase == "interleaved":
        if operation.A.layout in [cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32]:
            interleavedk = 32
        else:
            # fixed typo: was "unknonw layout"
            raise ValueError("unknown layout")

    # Build the sweep grid for the requested testcase.
    if testcase == "interleaved":
        modes = [cutlass.gemm.Mode.Gemm, ]
        problem_size_m = [interleavedk, 512 + interleavedk]
        problem_size_n = [interleavedk, 512 + interleavedk]
        problem_size_k = [interleavedk, threadblock_k *
                          operation.tile_description.stages + interleavedk]
        problem_alpha = [1.0]
        problem_beta = [0.0]
        batch_counts = [1, ]
    elif testcase == "multistage":
        modes = [cutlass.gemm.Mode.Gemm, ]
        problem_size_m = [16, 528]
        problem_size_n = [16, 528]
        problem_size_k = [threadblock_k, threadblock_k * operation.tile_description.stages +
                          operation.tile_description.math_instruction.instruction_shape[2]]
        problem_alpha = [1.0]
        problem_beta = [0.0]
        batch_counts = [1, ]
    else:  # universal
        modes = [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel]
        problem_size_m = [alignment_m, 512 - 3 * alignment_m]
        problem_size_n = [alignment_n, 512 - 2 * alignment_n]
        problem_size_k = [
            alignment_k,
            threadblock_k * operation.tile_description.stages - alignment_k,
            threadblock_k * operation.tile_description.stages * 3 - alignment_k]
        batch_counts = [1, 2, 3, 5, 7]
        problem_alpha = [1.0]
        problem_beta = [2.0]

    testbed = GemmUniversalLauncher(
        operation, interleaved=(testcase == "interleaved"))

    for mode in modes:
        for m in problem_size_m:
            for n in problem_size_n:
                for k in problem_size_k:
                    for batch_count in batch_counts:
                        for alpha in problem_alpha:
                            for beta in problem_beta:
                                # skip very small K problems
                                if testcase == "universal":
                                    if (k // batch_count < 2 * threadblock_k):
                                        continue

                                problem_size = cutlass.gemm.GemmCoord(m, n, k)

                                passed = testbed.run(
                                    mode, problem_size, batch_count, alpha, beta)

                                err, = cudart.cudaDeviceSynchronize()
                                if err != cuda.CUresult.CUDA_SUCCESS:
                                    raise RuntimeError(
                                        "CUDA Error %s" % str(err))

                                # Fail fast on the first mismatching config.
                                if not passed:
                                    return False

    return passed
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
#################################################################################################
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
4
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
#
|
|
6
|
+
# Redistribution and use in source and binary forms, with or without
|
|
7
|
+
# modification, are permitted provided that the following conditions are met:
|
|
8
|
+
#
|
|
9
|
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
10
|
+
# list of conditions and the following disclaimer.
|
|
11
|
+
#
|
|
12
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
13
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
14
|
+
# and/or other materials provided with the distribution.
|
|
15
|
+
#
|
|
16
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
17
|
+
# contributors may be used to endorse or promote products derived from
|
|
18
|
+
# this software without specific prior written permission.
|
|
19
|
+
#
|
|
20
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
21
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
22
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
23
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
24
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
25
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
26
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
27
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
28
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
29
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
30
|
+
#
|
|
31
|
+
#################################################################################################
|
|
32
|
+
|
|
33
|
+
from cuda import cuda
|
|
34
|
+
from cuda import cudart
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class GpuTimer:
    """Wall-clock GPU timer built on a pair of CUDA driver events.

    ``start()``/``stop()`` record events on a stream; ``duration()`` returns
    the elapsed time between them (as reported by ``cuEventElapsedTime``,
    i.e. milliseconds), optionally averaged over a number of iterations.
    """

    def __init__(self) -> None:
        # events[0] brackets the start of the timed region, events[1] the end.
        # cuEventCreate returns (err, event); index [1] takes the event handle.
        # NOTE(review): the returned err is not checked here, and the events
        # are never destroyed — acceptable for a test-lifetime helper.
        self.events = [
            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
        ]

    def start(self, stream=cuda.CUstream(0)):
        """Record the start event on *stream* (default stream if omitted)."""
        err, = cuda.cuEventRecord(self.events[0], stream)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("CUDA Error %s" % str(err))

    def stop(self, stream=cuda.CUstream(0)):
        """Record the stop event on *stream*; does NOT synchronize."""
        err, = cuda.cuEventRecord(self.events[1], stream)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("CUDA Error %s" % str(err))
        # removed a stray trailing `pass` (dead code)

    def stop_and_wait(self, stream=cuda.CUstream(0)):
        """Record the stop event, then block until the work is complete.

        A non-default stream is synchronized directly; for the default
        (null) stream the whole device is synchronized instead.
        """
        self.stop(stream)
        if stream:
            err, = cuda.cuStreamSynchronize(stream)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))
        else:
            err, = cudart.cudaDeviceSynchronize()
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))

    def duration(self, iterations=1):
        """Return elapsed time between start and stop events / *iterations*."""
        err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1])
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("CUDA Error %s" % str(err))
        return duration / float(iterations)
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
4
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
#
|
|
6
|
+
# Redistribution and use in source and binary forms, with or without
|
|
7
|
+
# modification, are permitted provided that the following conditions are met:
|
|
8
|
+
#
|
|
9
|
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
10
|
+
# list of conditions and the following disclaimer.
|
|
11
|
+
#
|
|
12
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
13
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
14
|
+
# and/or other materials provided with the distribution.
|
|
15
|
+
#
|
|
16
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
17
|
+
# contributors may be used to endorse or promote products derived from
|
|
18
|
+
# this software without specific prior written permission.
|
|
19
|
+
#
|
|
20
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
21
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
22
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
23
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
24
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
25
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
26
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
27
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
28
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
29
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
30
|
+
#
|
|
31
|
+
################################################################################
|
|
32
|
+
|
|
33
|
+
from typing import Union
|
|
34
|
+
from typeguard import typechecked
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Type aliases kept as forward-reference STRINGS so that importing this module
# does not require the modules that actually define these names.

# Either flavor of GEMM operation wrapper — presumably resolved lazily by
# typeguard against GemmOperationUniversal / GemmOperationGrouped; confirm.
GemmOperation = 'Union[GemmOperationUniversal, GemmOperationGrouped]'

# Any tensor-like argument: a raw CUDA device pointer or a NumPy / PyTorch /
# CuPy array — the cuda/np/torch/cp names are not imported here by design.
Tensor = 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]'
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from pycutlass.utils.reference_model import *
|