warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic; see the package's registry page for more details.
- warp/__init__.py +15 -7
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +22 -443
- warp/build_dll.py +384 -0
- warp/builtins.py +998 -488
- warp/codegen.py +1307 -739
- warp/config.py +5 -3
- warp/constants.py +6 -0
- warp/context.py +1291 -548
- warp/dlpack.py +31 -31
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +164 -55
- warp/native/builtin.h +150 -174
- warp/native/bvh.cpp +75 -328
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +37 -45
- warp/native/clang/clang.cpp +136 -24
- warp/native/crt.cpp +1 -76
- warp/native/crt.h +111 -104
- warp/native/cuda_crt.h +1049 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -949
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/initializer_array.h +2 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +13 -16
- warp/native/marching.cu +157 -161
- warp/native/mat.h +119 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +108 -83
- warp/native/mesh.cu +243 -6
- warp/native/mesh.h +1547 -458
- warp/native/nanovdb/NanoVDB.h +1 -1
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +157 -0
- warp/native/reduce.cu +348 -0
- warp/native/runlength_encode.cpp +62 -0
- warp/native/runlength_encode.cu +46 -0
- warp/native/scan.cu +11 -13
- warp/native/scan.h +1 -0
- warp/native/solid_angle.h +442 -0
- warp/native/sort.cpp +13 -0
- warp/native/sort.cu +9 -1
- warp/native/sparse.cpp +338 -0
- warp/native/sparse.cu +545 -0
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +30 -0
- warp/native/vec.h +126 -24
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +658 -53
- warp/native/warp.cu +660 -68
- warp/native/warp.h +112 -12
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +392 -152
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +21 -8
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +161 -19
- warp/sim/model.py +795 -291
- warp/sim/optimizer.py +2 -6
- warp/sim/render.py +65 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +1227 -0
- warp/stubs.py +665 -223
- warp/tape.py +66 -15
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +1497 -211
- warp/tests/test_array_reduce.py +150 -0
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +75 -43
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +233 -128
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +136 -108
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -74
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +180 -116
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +577 -24
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +251 -15
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +508 -2778
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +325 -34
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +190 -0
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +460 -0
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +331 -85
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -1987
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +72 -30
- warp/types.py +1744 -713
- warp/utils.py +360 -350
- warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/bin/warp-clang.exp +0 -0
- warp/bin/warp-clang.lib +0 -0
- warp/bin/warp.exp +0 -0
- warp/bin/warp.lib +0 -0
- warp/tests/test_all.py +0 -215
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.9.0.dist-info/METADATA +0 -20
- warp_lang-0.9.0.dist-info/RECORD +0 -177
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,631 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved
|
|
4
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
#
|
|
6
|
+
# Redistribution and use in source and binary forms, with or without
|
|
7
|
+
# modification, are permitted provided that the following conditions are met:
|
|
8
|
+
#
|
|
9
|
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
10
|
+
# list of conditions and the following disclaimer.
|
|
11
|
+
#
|
|
12
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
13
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
14
|
+
# and/or other materials provided with the distribution.
|
|
15
|
+
#
|
|
16
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
17
|
+
# contributors may be used to endorse or promote products derived from
|
|
18
|
+
# this software without specific prior written permission.
|
|
19
|
+
#
|
|
20
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
21
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
22
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
23
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
24
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
25
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
26
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
27
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
28
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
29
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
30
|
+
#
|
|
31
|
+
################################################################################
|
|
32
|
+
from typeguard import typechecked
|
|
33
|
+
from cuda import cuda
|
|
34
|
+
from typing import Union
|
|
35
|
+
import numpy as np
|
|
36
|
+
|
|
37
|
+
from typeguard import typechecked
|
|
38
|
+
|
|
39
|
+
from pycutlass import *
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# @typechecked
|
|
43
|
+
class Conv2dArguments(ArgumentBase):
    """
    Argument wrapper for Conv2d. It encodes problem information and
    user-provided tensors into the kernel's argument.

    :param operation: the Conv2d operation to take the argument
    :type operation: :class:`pycutlass.Conv2dOperation`

    :param problem_size: the Conv2d problem size
    :type problem_size: :class:`cutlass.conv.Conv2dProblemSize`

    :param A: tensor A
    :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param B: tensor B
    :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param C: tensor C
    :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param D: tensor D
    :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param split_k_mode: conv2d split K mode, defaults to
        cutlass.conv.SplitKMode.Serial. It only takes effect when
        ``split_k_slices`` is also passed; otherwise serial mode with a
        single slice is used.
    :type split_k_mode: cutlass.conv.SplitKMode, optional

    :param output_op: output operator, optional. It is ignored when
        ``split_k_mode`` is ``Parallel``; the operation's default
        ``epilogue_type(1.0, 0.0)`` is used instead.
    :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments`

    :param split_k_slices: number of split-K slices (keyword argument),
        defaults to 1
    :type split_k_slices: int, optional

    Remaining keyword arguments are forwarded to ``ArgumentBase``.
    """

    def __init__(self, operation: 'Conv2dOperation',
                 problem_size: 'cutlass.conv.Conv2dProblemSize',
                 A: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]',
                 B: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]',
                 C: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]',
                 D: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]',
                 split_k_mode: 'cutlass.conv.SplitKMode'
                 = cutlass.conv.SplitKMode.Serial, **kwargs) -> None:

        self.operation = operation
        #: convolution kind
        self.conv_kind: cutlass.conv.Operator = operation.conv_kind
        self.layout_A: cutlass.layout = operation.A.layout
        self.layout_B: cutlass.layout = operation.B.layout
        self.layout_C: cutlass.layout = operation.C.layout

        self.element_A = operation.A.element
        self.element_B = operation.B.element
        self.element_C = operation.C.element

        # For the interleaved NC32HW32 output layout, tensor B is reordered on
        # the host before being handed to the base class.
        # NOTE(review): reorder_tensor_B uses np.empty_like, so this path
        # presumably expects B as a numpy.ndarray — confirm for other inputs.
        if self.layout_C == cutlass.TensorNC32HW32:
            B = self.reorder_tensor_B(B, problem_size)

        super().__init__(A, B, C, D, **kwargs)

        # preprocessing output ops: a user-supplied epilogue is honored only
        # outside parallel split-K mode.
        if 'output_op' in kwargs and \
                split_k_mode != cutlass.conv.SplitKMode.Parallel:
            self.output_op = kwargs['output_op']
        else:
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        # The requested split-K mode is applied only together with an explicit
        # slice count; otherwise fall back to serial mode with one slice.
        if "split_k_slices" in kwargs:
            self.split_k_mode = split_k_mode
            self.split_k_slices = kwargs["split_k_slices"]
        else:
            self.split_k_mode = cutlass.conv.SplitKMode.Serial
            self.split_k_slices = 1

        #: problem_size
        self.problem_size: cutlass.conv.Conv2dProblemSize = problem_size
        self.problem_size.split_k_slices = self.split_k_slices

        # When the base class recorded the element count of C, a C whose size
        # equals the last extent of the implicit-GEMM C tensor (and is smaller
        # than the full extent) is treated as a bias vector.
        if hasattr(self, "tensor_c_numel"):
            c_coord = cutlass.conv.implicit_gemm_tensor_c_extent(
                self.conv_kind, problem_size)
            if (self.tensor_c_numel == c_coord.at(3) and
                    self.tensor_c_numel < c_coord.size()):
                self.bias = True

        #
        # initialize the argument
        #
        self.initialize()

    # @typechecked
    def reorder_tensor_B(self, tensor_B: 'np.ndarray',
                         problem_size: 'cutlass.conv.Conv2dProblemSize'):
        """
        Reorder tensor_B for interleaved layout

        :param tensor_B: input tensor B
        :type tensor_B: numpy.ndarray
        :param problem_size: Conv2d problem size
        :type problem_size: :class:`cutlass.conv.Conv2dProblemSize`

        :return: reordered tensor B
        :rtype: numpy.ndarray
        """
        reordered_tensor_B = np.empty_like(tensor_B)
        tensor_ref_B = self.get_tensor_ref(
            tensor_B, self.element_B, self.layout_B, problem_size, "b")
        reordered_tensor_ref_B = self.get_tensor_ref(
            reordered_tensor_B, self.element_B,
            self.layout_B, problem_size, "b")
        cutlass.conv.host.reorder_convK(
            reordered_tensor_ref_B, tensor_ref_B, self.conv_kind, problem_size)

        return reordered_tensor_B

    def get_tensor_ref(
            self, tensor, dtype, tensor_layout, problem_size, operand):
        """
        Build a CUTLASS tensor reference for one convolution operand.

        The layout is ``tensor_layout.packed`` over the implicit-GEMM extent of
        the operand ("a", "b", or "c"/"d").

        :param operand: one of "a", "b", "c", "d"
        :raises ValueError: if ``operand`` is not one of the above
        :return: the ``tensor_ref`` of a ``TensorRef`` wrapping ``tensor``
        """
        if operand == "a":
            tensor_coord = cutlass.conv.implicit_gemm_tensor_a_extent(
                self.conv_kind, problem_size)
        elif operand == "b":
            tensor_coord = cutlass.conv.implicit_gemm_tensor_b_extent(
                self.conv_kind, problem_size)
        elif operand in ["c", "d"]:
            tensor_coord = cutlass.conv.implicit_gemm_tensor_c_extent(
                self.conv_kind, problem_size)
        else:
            raise ValueError("unknown operand: " + operand)
        # Zero stride trick: for a bias C, pack over a zero extent, which
        # presumably yields zero strides so every output location reads the
        # same bias values.
        if operand == "c" and self.bias:
            tensor_coord = cutlass.Tensor4DCoord(0, 0, 0, 0)

        layout = tensor_layout.packed(tensor_coord)

        return TensorRef(tensor, dtype, layout).tensor_ref

    def get_arguments(self, semaphore):
        """
        Build the ctypes kernel argument struct into ``self.c_arguments``.

        D is described with C's element type and layout. ``semaphore`` is
        stored on ``self.semaphore`` for ``initialize``.
        """
        ref_A = TensorRef_(self.get_tensor_ref(
            self.ptr_A, self.element_A, self.layout_A, self.problem_size, "a"))
        ref_B = TensorRef_(self.get_tensor_ref(
            self.ptr_B, self.element_B, self.layout_B, self.problem_size, "b"))
        ref_C = TensorRef_(self.get_tensor_ref(
            self.ptr_C, self.element_C, self.layout_C, self.problem_size, "c"))
        ref_D = TensorRef_(self.get_tensor_ref(
            self.ptr_D, self.element_C, self.layout_C, self.problem_size, "d"))

        self.c_arguments = self.operation.argument_type(
            Conv2DProblemSize(self.problem_size),
            ref_A, ref_B, ref_C, ref_D, self.output_op, self.split_k_mode
        )

        self.semaphore = semaphore

    def initialize(self):
        """
        Initialize the kernel arguments handling following stuffs
        1. get kernel launch configuration including grid, cta size,
           and dynamic shared memory capacity
        2. allocate and initialize device workspace
        3. get kernel params as bytearray for NVRTC input

        :raises RuntimeError: if zeroing the device workspace fails
        """
        # get launch configuration
        self.launch_config = self.operation.rt_module.plan(self)

        # allocate and initialize device workspace
        device_workspace_size = \
            self.operation.rt_module.get_device_workspace_size(self)

        if device_workspace_size > 0:
            self.workspace_buffer = device_mem_alloc(device_workspace_size)
            workspace_ptr = self.workspace_buffer.ptr
            err, = cuda.cuMemsetD32(
                workspace_ptr, 0, device_workspace_size // 4)
            # Previously the result was discarded; surface failures the same
            # way Conv2dRT.initialize does.
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))
        else:
            workspace_ptr = None

        # get kernel params as bytearray
        # Parallel split-K writes partial results into the workspace (it
        # replaces D); serial split-K uses the workspace as the semaphore.
        semaphore = 0
        if workspace_ptr is not None and \
                self.split_k_mode == cutlass.conv.SplitKMode.Parallel:
            self.ptr_D = workspace_ptr
        elif workspace_ptr is not None and \
                self.split_k_mode == cutlass.conv.SplitKMode.Serial:
            semaphore = workspace_ptr

        self.get_arguments(semaphore)

        params_ = self.operation.rt_module.get_args(ctypes.byref(
            self.c_arguments), ctypes.c_void_p(int(self.semaphore)))
        self.host_workspace = bytearray(params_.contents)
        self.device_workspace = None

    def sync(self):
        """
        Synchronize the arguments by delegating to ``ArgumentBase.sync``.

        Presumably this copies device results back into host-side output
        tensors when the user passed host arrays — confirm in ArgumentBase.
        """
        return super().sync()
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
# @typechecked
|
|
241
|
+
class Conv2dRT(ExecutableOperation):
|
|
242
|
+
"""
|
|
243
|
+
Conv2dRT manages the CUTLASS runtime components
|
|
244
|
+
"""
|
|
245
|
+
KernelTemplate = r'''
|
|
246
|
+
extern "C"
|
|
247
|
+
__global__ void
|
|
248
|
+
${operation_name}(${operation_name}${operation_suffix}::Params params) {
|
|
249
|
+
|
|
250
|
+
// Dynamic shared memory base pointer
|
|
251
|
+
extern __shared__ int SharedStorageBase[];
|
|
252
|
+
|
|
253
|
+
// Declare pointer to dynamic shared memory.
|
|
254
|
+
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
|
|
255
|
+
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
|
|
256
|
+
|
|
257
|
+
${operation_name}${operation_suffix} op;
|
|
258
|
+
|
|
259
|
+
op(params, *shared_storage);
|
|
260
|
+
}
|
|
261
|
+
'''
|
|
262
|
+
|
|
263
|
+
HostTemplate = r'''
|
|
264
|
+
extern "C" {
|
|
265
|
+
// Get the size of params in bytes
|
|
266
|
+
int ${operation_name}_get_param_size(){
|
|
267
|
+
return sizeof(${operation_name}${operation_suffix}::Params);
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
// Get the size of dynamic shared memory in bytes
|
|
271
|
+
int ${operation_name}_shared_memory_size() {
|
|
272
|
+
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
// Get the params as byte array
|
|
276
|
+
char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Arguments* arguments, int *semaphore=nullptr){
|
|
277
|
+
typename ${operation_name}${operation_suffix}::Params* params;
|
|
278
|
+
params = new ${operation_name}${operation_suffix}::Params(*arguments, semaphore);
|
|
279
|
+
|
|
280
|
+
char *bytes = ((char*)(params));
|
|
281
|
+
char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
|
|
282
|
+
for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
|
|
283
|
+
output[i] = bytes[i];
|
|
284
|
+
|
|
285
|
+
return output;
|
|
286
|
+
}
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
'''
|
|
290
|
+
|
|
291
|
+
def __init__(self, operation: 'Conv2dOperation'):
|
|
292
|
+
super().__init__(operation)
|
|
293
|
+
self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor)
|
|
294
|
+
self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p]
|
|
295
|
+
self.conv_kind = operation.conv_kind
|
|
296
|
+
|
|
297
|
+
self.operation: Conv2dOperation = operation
|
|
298
|
+
|
|
299
|
+
self.emitter = EmitConv2dInstance('_type')
|
|
300
|
+
|
|
301
|
+
self.threads: int = operation.tile_description.num_threads
|
|
302
|
+
|
|
303
|
+
self.swizzle_functor = operation.swizzling_functor
|
|
304
|
+
|
|
305
|
+
def emit(self):
|
|
306
|
+
return self.emitter.emit(self.operation)
|
|
307
|
+
|
|
308
|
+
# @typechecked
|
|
309
|
+
def get_device_workspace_size(self, arguments: Conv2dArguments):
|
|
310
|
+
workspace_bytes = 0
|
|
311
|
+
|
|
312
|
+
launch_config = arguments.launch_config
|
|
313
|
+
|
|
314
|
+
self.conv_kind = self.operation.conv_kind
|
|
315
|
+
|
|
316
|
+
if arguments.split_k_mode == cutlass.conv.SplitKMode.Parallel:
|
|
317
|
+
problem_size = arguments.problem_size
|
|
318
|
+
workspace_bytes = DataTypeSize[self.operation.C.element] \
|
|
319
|
+
* launch_config.grid[2] * cutlass.conv.implicit_gemm_tensor_c_size(
|
|
320
|
+
self.conv_kind, problem_size
|
|
321
|
+
) // 8
|
|
322
|
+
elif arguments.split_k_mode == cutlass.conv.SplitKMode.Serial and \
|
|
323
|
+
arguments.split_k_slices > 1:
|
|
324
|
+
workspace_bytes = launch_config.grid[0] * launch_config.grid[1] * 4
|
|
325
|
+
|
|
326
|
+
return workspace_bytes
|
|
327
|
+
|
|
328
|
+
# @typechecked
|
|
329
|
+
def plan(self, arguments: Conv2dArguments):
|
|
330
|
+
tile_size = cutlass.gemm.GemmCoord(
|
|
331
|
+
self.operation.tile_description.threadblock_shape[0],
|
|
332
|
+
self.operation.tile_description.threadblock_shape[1],
|
|
333
|
+
self.operation.tile_description.threadblock_shape[2]
|
|
334
|
+
)
|
|
335
|
+
|
|
336
|
+
grid = self.swizzle_functor.get_grid_shape(
|
|
337
|
+
self.swizzle_functor.get_tiled_shape(
|
|
338
|
+
self.conv_kind, arguments.problem_size,
|
|
339
|
+
tile_size, arguments.split_k_slices
|
|
340
|
+
)
|
|
341
|
+
)
|
|
342
|
+
return LaunchConfiguration(
|
|
343
|
+
[grid.x, grid.y, grid.z], [self.threads, 1, 1],
|
|
344
|
+
self.shared_memory_capacity)
|
|
345
|
+
|
|
346
|
+
def initialize(self):
    """Raise the kernel's dynamic shared-memory limit to ``self.shared_memory_capacity``.

    Sets ``CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES`` on ``self.kernel``
    through the CUDA driver API.

    :raises RuntimeError: if ``cuFuncSetAttribute`` does not return ``CUDA_SUCCESS``
    """
    max_dyn_smem_attr = cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
    # The driver binding returns a 1-tuple holding the CUresult.
    (status,) = cuda.cuFuncSetAttribute(
        self.kernel, attrib=max_dyn_smem_attr, value=self.shared_memory_capacity)
    if status != cuda.CUresult.CUDA_SUCCESS:
        raise RuntimeError('Cuda Error: {}'.format(status))
|
|
353
|
+
|
|
354
|
+
#
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
class Conv2dOperation:
    """
    CUTLASS Conv2d operation description.

    :param conv_kind: convolution operator
    :type conv_kind: :class:`cutlass.conv.Operator`

    :param iterator_algorithm: selects among several implementation
        variants trading off performance with simplicity
    :type iterator_algorithm: :class:`cutlass.conv.IteratorAlgorithm`

    :param arch: GPU compute capability (sm_xx)
    :type arch: int

    :param tile_description: tile description
    :type tile_description: :class:`pycutlass.TileDescription`

    :param A: tensor A description
    :type A: :class:`pycutlass.TensorDescription`

    :param B: tensor B description
    :type B: :class:`pycutlass.TensorDescription`

    :param C: tensor C description
    :type C: :class:`pycutlass.TensorDescription`

    :param stride_support: distinguishes among partial specializations that
        accelerate certain problems where the convolution stride is unit
    :type stride_support: :class:`cutlass.conv.StrideSupport`

    :param epilogue_functor: convolution epilogue functor
    :type epilogue_functor: :class:`EpilogueFunctor`

    :param swizzling_functor: threadblock swizzling functor *class*. It is
        instantiated with no arguments in ``__init__`` and the instance is
        stored as ``self.swizzling_functor``. Defaults to
        ``cutlass.IdentitySwizzle1``.
    """

    def __init__(self,
                 conv_kind: cutlass.conv.Operator,
                 iterator_algorithm: cutlass.conv.IteratorAlgorithm,
                 arch: int, tile_description: TileDescription,
                 A: TensorDescription, B: TensorDescription, C: TensorDescription,
                 stride_support, epilogue_functor,
                 swizzling_functor=cutlass.IdentitySwizzle1):

        self.operation_kind: OperationKind = OperationKind.Conv2d
        self.arch: int = arch
        self.tile_description: TileDescription = tile_description
        self.conv_kind = conv_kind
        self.A: TensorDescription = A
        self.B: TensorDescription = B
        self.C: TensorDescription = C
        self.epilogue_functor = epilogue_functor
        self.iterator_algorithm = iterator_algorithm
        self.stride_support = stride_support
        # Stored as an instance: the runtime calls get_tiled_shape/get_grid_shape
        # on it, and the emitter calls its tag().
        self.swizzling_functor = swizzling_functor()

        # Conv2dRT reads the attributes assigned above, so it must be built last.
        self.rt_module: Conv2dRT = Conv2dRT(self)
        self.argument_type = self.rt_module.argument_type
        self.epilogue_type = self.rt_module.epilogue_type

    def run(self, arguments: Conv2dArguments) -> cuda.CUresult:
        """
        Launch the cuda kernel with input arguments.

        :param arguments: conv2d arguments
        :type arguments: :class:`pycutlass.Conv2dArguments`

        :return: the status from the runtime module. Any status other than
            ``CUDA_SUCCESS`` raises instead of returning.
        :raises RuntimeError: if the launch does not return ``CUDA_SUCCESS``
        """
        err = self.rt_module.run(
            arguments.host_workspace,
            arguments.device_workspace,
            arguments.launch_config)

        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError('CUDA Error %s' % str(err))

        return err

    def procedural_name(self):
        ''' Return the kernel's procedural name (identical to ``configuration_name()``). '''
        return self.configuration_name()

    def configuration_name(self):
        ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
        opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]

        # <tile M>x<tile N>_<tile K>x<pipeline stages>
        threadblock = "%dx%d_%dx%d" % (
            self.tile_description.threadblock_shape[0],
            self.tile_description.threadblock_shape[1],
            self.tile_description.threadblock_shape[2],
            self.tile_description.stages
        )

        if self.stride_support == StrideSupport.Unity:
            configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}"
        else:
            configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"

        return SubstituteTemplate(
            configuration_name,
            {
                'opcode_class': opcode_class_name,
                'extended_name': self.extended_name(),
                'threadblock': threadblock,
                'layout': self.layout_name(),
                'alignment': "%d" % self.A.alignment,
            }
        )

    def extended_name(self):
        ''' Append data types if they differ from compute type. '''
        accum = self.tile_description.math_instruction.element_accumulator

        if self.C.element != accum and self.A.element != accum:
            extended_name = "${element_c}_${core_name}_${element_a}"
        elif self.C.element == accum and self.A.element != accum:
            extended_name = "${core_name}_${element_a}"
        else:
            extended_name = "${core_name}"

        extended_name = SubstituteTemplate(extended_name, {
            'element_a': DataTypeNames[self.A.element],
            'element_c': DataTypeNames[self.C.element],
            'core_name': self.core_name()
        })

        return extended_name

    def layout_name(self):
        ''' Short name of tensor A's layout. '''
        return "%s" % (ShortLayoutTypeNames[self.A.layout])

    def core_name(self):
        ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
        math_inst = self.tile_description.math_instruction
        intermediate_type = ''

        if math_inst.opcode_class == cutlass.OpClass.TensorOp:
            inst_shape = "%d%d%d" % tuple(math_inst.instruction_shape)
            # Name the instruction's A type only when it differs from both the
            # tensor's A element type and the accumulator type.
            if math_inst.element_a != self.A.element and \
                    math_inst.element_a != self.accumulator_type():
                intermediate_type = DataTypeNames[math_inst.element_a]
        else:
            inst_shape = ''

        return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()],
                                inst_shape, intermediate_type,
                                ConvKindNames[self.conv_kind],
                                IteratorAlgorithmNames[self.iterator_algorithm])

    def is_complex(self):
        ''' True if the math instruction uses a complex multiply-add operation. '''
        return self.tile_description.math_instruction.math_operation in (
            MathOperation.multiply_add_complex,
            MathOperation.multiply_add_complex_gaussian,
        )

    def accumulator_type(self):
        ''' Accumulator element type; mapped to its complex counterpart for complex math. '''
        accum = self.tile_description.math_instruction.element_accumulator

        if self.is_complex():
            return get_complex_from_real(accum)

        return accum
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
###################################################################################################
|
|
542
|
+
#
|
|
543
|
+
# Emits single instances of a CUTLASS device-wide operator
|
|
544
|
+
#
|
|
545
|
+
###################################################################################################
|
|
546
|
+
|
|
547
|
+
class EmitConv2dInstance:
    """Emit the C++ source of a single CUTLASS Conv2d device-wide kernel instance.

    :param operation_suffix: appended to the operation's procedural name to
        form the name of the emitted ``struct``, which derives from
        ``<procedural_name>_base``
    """

    def __init__(self, operation_suffix=''):
        self.operation_suffix = operation_suffix
        # Headers declaring the DefaultConv2d{Fprop,Dgrad,Wgrad} kernel templates.
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/conv/kernel/default_conv2d_fprop.h",
            "cutlass/conv/kernel/default_conv2d_dgrad.h",
            "cutlass/conv/kernel/default_conv2d_wgrad.h"
        ]
        self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name}_base =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor},
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
${stages},
${math_operator},
${iterator_algorithm},
${stride_support},
${align_a},
${align_b}
>::Kernel;

struct ${operation_name}${operation_suffix}:
public ${operation_name}_base { };

"""

    def emit(self, operation):
        """Return the C++ kernel instantiation text for ``operation``.

        :param operation: a ``Conv2dOperation``-like description providing
            ``tile_description``, ``A``/``B``/``C`` tensor descriptions,
            ``conv_kind``, ``iterator_algorithm``, ``stride_support``,
            ``epilogue_functor`` and ``swizzling_functor``
        :return: ``self.template`` with every ``${...}`` placeholder substituted
        """
        # Per-warp tile = threadblock tile / warp count, per dimension. Integer
        # division keeps the arithmetic exact (no float round-trip).
        warp_shape = [operation.tile_description.threadblock_shape[idx] //
                      operation.tile_description.warp_count[idx] for idx in range(3)]

        # C.alignment elements, capped at 128 (bits, judging by the cap), expressed
        # back in elements. Not referenced by self.template; kept in `values`
        # for templates that may use it.
        c_element_size = DataTypeSize[operation.C.element]
        epilogue_vector_length = min(
            operation.C.alignment * c_element_size, 128) // c_element_size

        values = {
            'operation_name': operation.procedural_name(),
            'operation_suffix': self.operation_suffix,
            'conv_kind': ConvKindTag[operation.conv_kind],
            'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[operation.B.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'epilogue_functor': operation.epilogue_functor.emit(),
            'swizzling_functor': operation.swizzling_functor.tag(),
            'stages': str(operation.tile_description.stages),
            'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
            'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
            'stride_support': StrideSupportTag[operation.stride_support],
            # Complex math always uses the complex multiply-add operator tag.
            'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else
            MathOperationTag[operation.tile_description.math_instruction.math_operation],
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
        }

        return SubstituteTemplate(self.template, values)
|