warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,796 @@
|
|
|
1
|
+
#################################################################################################
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
4
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
#
|
|
6
|
+
# Redistribution and use in source and binary forms, with or without
|
|
7
|
+
# modification, are permitted provided that the following conditions are met:
|
|
8
|
+
#
|
|
9
|
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
10
|
+
# list of conditions and the following disclaimer.
|
|
11
|
+
#
|
|
12
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
13
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
14
|
+
# and/or other materials provided with the distribution.
|
|
15
|
+
#
|
|
16
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
17
|
+
# contributors may be used to endorse or promote products derived from
|
|
18
|
+
# this software without specific prior written permission.
|
|
19
|
+
#
|
|
20
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
21
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
22
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
23
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
24
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
25
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
26
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
27
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
28
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
29
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
30
|
+
#
|
|
31
|
+
#################################################################################################
|
|
32
|
+
|
|
33
|
+
# System imports
|
|
34
|
+
import struct
|
|
35
|
+
import io
|
|
36
|
+
import ctypes
|
|
37
|
+
|
|
38
|
+
# CUDA Python import
|
|
39
|
+
from cuda import cuda
|
|
40
|
+
from cuda import nvrtc
|
|
41
|
+
|
|
42
|
+
# CUTLASS imports
|
|
43
|
+
from library import *
|
|
44
|
+
from gemm_operation import EmitGemmUniversalInstance
|
|
45
|
+
|
|
46
|
+
#################################################################################################
|
|
47
|
+
#
|
|
48
|
+
# CUTLASS Py Runtime Components
|
|
49
|
+
#
|
|
50
|
+
#################################################################################################
|
|
51
|
+
|
|
52
|
+
#
|
|
53
|
+
def MaxAlignment(fmt):
    '''
    Return the largest size, in bytes, of any single format character in ``fmt``.

    Every character of ``fmt`` is measured on its own with ``struct.calcsize``.
    The result is never smaller than 1, which is also the value for an empty
    format string.
    '''
    sizes = [struct.calcsize(code) for code in fmt]
    return max([1] + sizes)
|
|
58
|
+
|
|
59
|
+
#
|
|
60
|
+
def AlignedOffset(offset, align):
    '''
    Round ``offset`` up to the next multiple of ``align``.

    An offset that is already a multiple of ``align`` is returned unchanged.
    '''
    excess = offset % align
    if not excess:
        return offset
    return offset + (align - excess)
|
|
65
|
+
|
|
66
|
+
#
|
|
67
|
+
def PackInteger(host_workspace, offset, value):
    '''
    Pack ``value`` as a native ``int`` (struct format ``"i"``) into ``host_workspace``.

    ``offset`` is first rounded up to a 4-byte boundary and the value is
    written at that rounded position.

    Returns the offset immediately after the written integer.
    '''
    fmt = "i"
    # Write at the aligned position. The returned offset is derived from it,
    # so packing at the raw ``offset`` would put the value somewhere other
    # than where the following field layout assumes it is.
    offset = AlignedOffset(offset, 4)
    struct.pack_into(fmt, host_workspace, offset, value)
    return offset + struct.calcsize(fmt)
|
|
72
|
+
|
|
73
|
+
#
|
|
74
|
+
def PackDevicePointer(host_workspace, offset, value):
    '''
    Pack ``value`` with struct format ``"P"`` into ``host_workspace``.

    ``offset`` is rounded up to an 8-byte boundary before writing.

    Returns the offset immediately after the written pointer.
    '''
    pointer_fmt = "P"
    aligned = AlignedOffset(offset, 8)
    struct.pack_into(pointer_fmt, host_workspace, aligned, value)
    return aligned + struct.calcsize(pointer_fmt)
|
|
79
|
+
|
|
80
|
+
#
|
|
81
|
+
def ceil_div(a, b):
    '''
    Return ``a`` divided by ``b``, rounded towards positive infinity.

    Uses floor division on the negated divisor, then negates the quotient.
    '''
    negated_quotient = a // -b
    return -negated_quotient
|
|
83
|
+
|
|
84
|
+
#################################################################################################
|
|
85
|
+
|
|
86
|
+
#
|
|
87
|
+
class PitchLinearCoord:
    '''
    Two-component coordinate with ``contiguous`` and ``strided`` members.
    '''

    def __init__(self, contiguous, strided):
        self.contiguous, self.strided = contiguous, strided
|
|
91
|
+
|
|
92
|
+
#
|
|
93
|
+
class GemmCoord:
    '''
    Three-component ``(m, n, k)`` coordinate that is packed as three native ints.
    '''

    def __init__(self, m = 1, n = 1, k = 1):
        self.m, self.n, self.k = m, n, k
        # struct format shared by size(), alignment() and pack_into()
        self.fmt = "iii"

    #
    def ceil_div(self, rhs):
        '''
        Return a new GemmCoord with each component ceiling-divided by the
        matching component of ``rhs``.
        '''
        # Inside a method the bare name resolves to the module-level
        # ceil_div function, not to this method.
        tiles_m = ceil_div(self.m, rhs.m)
        tiles_n = ceil_div(self.n, rhs.n)
        tiles_k = ceil_div(self.k, rhs.k)
        return GemmCoord(tiles_m, tiles_n, tiles_k)

    #
    def size(self):
        '''Return the packed size in bytes, as given by ``struct.calcsize``.'''
        return struct.calcsize(self.fmt)

    #
    def alignment(self):
        '''Return ``MaxAlignment`` of this coordinate's struct format.'''
        return MaxAlignment(self.fmt)

    #
    def pack_into(self, host_workspace, offset):
        '''
        Write ``m``, ``n`` and ``k`` into ``host_workspace`` at ``offset``
        rounded up to a 4-byte boundary.

        Returns the offset immediately after the packed coordinate.
        '''
        start = AlignedOffset(offset, 4)
        struct.pack_into(self.fmt, host_workspace, start, self.m, self.n, self.k)
        return start + self.size()
|
|
124
|
+
|
|
125
|
+
#
|
|
126
|
+
class TensorRef:
    '''
    Pairs a pointer object with a layout value.
    '''

    def __init__(self, pointer = None, layout = 0):
        self.pointer, self.layout = pointer, layout

    def __str__(self):
        '''Format as ``(<pointer._ptr in hex>, <layout as decimal>)``.'''
        fields = (self.pointer._ptr, self.layout)
        return "(%x, %d)" % fields
|
|
133
|
+
|
|
134
|
+
#################################################################################################
|
|
135
|
+
|
|
136
|
+
#
|
|
137
|
+
class PredicatedTileAccessIteratorDesc:
    '''
    Plain record of the values a PredicatedTileAccessIteratorParams reads:
    element size in bits, advance rank, threadblock shape, and the thread
    map's iterations and delta.
    '''

    def __init__(
            self,
            element_size_bits,
            advance_rank,
            threadblock_shape,
            threadmap_iterations,
            threadmap_delta):

        self.element_size_bits, self.advance_rank = element_size_bits, advance_rank
        self.threadblock_shape = threadblock_shape
        self.threadmap_iterations, self.threadmap_delta = threadmap_iterations, threadmap_delta
|
|
154
|
+
|
|
155
|
+
#
|
|
156
|
+
class PredicatedTileAccessIteratorParams:
    '''
    Computes four increments from a descriptor and a stride and packs them
    with struct format ``"qqqq"``.
    '''
    #
    def __init__(self, desc, label):
        self.desc, self.label = desc, label
        self.fmt = "qqqq"
    #
    def size(self):
        '''Return the packed size in bytes, as given by ``struct.calcsize``.'''
        return struct.calcsize(self.fmt)

    #
    def alignment(self):
        '''Return ``MaxAlignment`` of this object's struct format.'''
        return MaxAlignment(self.fmt)

    #
    def initialize(self, host_workspace, offset, stride):
        '''
        Pack ``stride, inc_strided, inc_next, inc_advance`` into
        ``host_workspace`` at ``offset`` rounded up to ``self.alignment()``.

        Each increment is a product of descriptor fields and ``stride``,
        multiplied by ``element_size_bits`` and floor-divided by 8.

        Returns the offset immediately after the packed values.
        '''
        desc = self.desc
        bits = desc.element_size_bits
        start = AlignedOffset(offset, self.alignment())

        inc_strided = (stride * desc.threadmap_delta.strided * bits) // 8

        # A truthy advance_rank uses the strided extent times the stride;
        # otherwise the contiguous extent alone is used.
        if desc.advance_rank:
            tile_span = desc.threadblock_shape.strided * stride
        else:
            tile_span = desc.threadblock_shape.contiguous
        inc_advance = (tile_span * bits) // 8

        # The floor division applies to the whole product, as in the
        # single-expression form, so it must not be factored out.
        rewind = ((desc.threadmap_iterations.strided - 1)
                  * desc.threadmap_delta.strided
                  * stride
                  * bits) // 8
        inc_next = inc_advance - rewind

        struct.pack_into(
            self.fmt, host_workspace, start,
            stride, inc_strided, inc_next, inc_advance)

        return start + self.size()
|
|
201
|
+
#
|
|
202
|
+
|
|
203
|
+
#################################################################################################
|
|
204
|
+
|
|
205
|
+
#
|
|
206
|
+
class EpilogueTileDesc:
    '''
    Five-component record with ``column``, ``row``, ``group``, ``cluster``
    and ``tile`` members.
    '''
    def __init__(self, column, row, group, cluster, tile):
        self.column, self.row = column, row
        self.group, self.cluster, self.tile = group, cluster, tile
|
|
215
|
+
|
|
216
|
+
#
|
|
217
|
+
class EpilogueThreadMap:
    '''
    Plain record of thread-map values: thread count, elements per access,
    element size in bits, and the ``shape``, ``iterations``, ``delta`` and
    ``count`` descriptors.
    '''
    def __init__(self, threads, elements_per_access, element_size_bits, shape, iterations, delta, count):
        self.threads, self.elements_per_access = threads, elements_per_access
        self.element_size_bits = element_size_bits
        self.shape, self.iterations = shape, iterations
        self.delta, self.count = delta, count
|
|
229
|
+
|
|
230
|
+
#
|
|
231
|
+
class EpilogueTileIteratorParams:
    '''
    Computes a scaled stride plus seven increment/advance values from a
    descriptor and packs them with struct format ``"qqqqqqqq"``.
    '''
    #
    def __init__(self, desc, label):
        self.desc, self.label = desc, label
        self.fmt = "qqqqqqqq"

    #
    def size(self):
        '''Return the packed size in bytes, as given by ``struct.calcsize``.'''
        return struct.calcsize(self.fmt)

    #
    def alignment(self):
        '''Return ``MaxAlignment`` of this object's struct format.'''
        return MaxAlignment(self.fmt)

    #
    def initialize(self, host_workspace, offset, stride):
        '''
        Pack the scaled stride, the row/group/cluster increments and the
        row/group/cluster/tile advances into ``host_workspace`` at ``offset``
        rounded up to ``self.alignment()``.

        ``stride`` is first multiplied by ``desc.element_size_bits`` and
        floor-divided by 8; every other value is derived from that result.

        Returns the offset immediately after the packed values.
        '''
        desc = self.desc
        delta, iterations = desc.delta, desc.iterations
        shape, count = desc.shape, desc.count

        stride = (stride * desc.element_size_bits) // 8
        start = AlignedOffset(offset, self.alignment())

        # Amounts subtracted from the group and cluster increments.
        row_rewind = stride * delta.row * (iterations.row - 1)
        group_rewind = stride * delta.group * (iterations.group - 1)

        increment_row = stride * delta.row
        increment_group = stride * delta.group - row_rewind
        increment_cluster = stride * delta.cluster - group_rewind - row_rewind

        advance_row = stride * shape.row
        advance_group = stride * (shape.group - 1) * shape.row * count.row
        advance_cluster = stride * count.group * shape.group * count.row * shape.row
        advance_tile = stride * shape.group * shape.row * shape.cluster * shape.tile

        fields = (
            stride,
            increment_row, increment_group, increment_cluster,
            advance_row, advance_group, advance_cluster, advance_tile,
        )
        struct.pack_into(self.fmt, host_workspace, start, *fields)

        return start + self.size()
|
|
292
|
+
#
|
|
293
|
+
|
|
294
|
+
#################################################################################################
|
|
295
|
+
#
|
|
296
|
+
# Launch configuration
|
|
297
|
+
#
|
|
298
|
+
#################################################################################################
|
|
299
|
+
|
|
300
|
+
class LaunchConfiguration:
    '''
    Holds the grid dimensions, block dimensions and shared-memory size of a
    kernel launch.

    ``grid`` and ``block`` default to a fresh ``[1, 1, 1]`` list for each
    instance, so changing one default-constructed configuration never
    affects another. Lists passed in explicitly are stored as given, not
    copied.
    '''
    def __init__(self, grid = None, block = None, smem = 0):
        # None stands in for the [1, 1, 1] default: a list literal in the
        # signature would be one object shared by every call.
        self.grid = [1, 1, 1] if grid is None else grid
        self.block = [1, 1, 1] if block is None else block
        self.shared_memory_capacity = smem
|
|
305
|
+
|
|
306
|
+
#################################################################################################
|
|
307
|
+
#
|
|
308
|
+
# Functors
|
|
309
|
+
#
|
|
310
|
+
#################################################################################################
|
|
311
|
+
|
|
312
|
+
#
|
|
313
|
+
class Functor:
    '''
    Base functor: holds a declaration string, a definition string, a struct
    format and an identifier, all empty by default.
    '''
    def __init__(self):
        self.decl = self.definition = ''
        self.fmt = self.identifier = ''

    #
    def emit_declaration(self):
        '''Return the stored declaration string.'''
        return self.decl

    #
    def emit_definition(self):
        '''Return the stored definition string.'''
        return self.definition

    #
    def size(self):
        '''
        Size of the packed Params structure
        '''
        packed_size = struct.calcsize(self.fmt)
        return packed_size

    #
    def alignment(self):
        '''Return ``MaxAlignment`` of this functor's struct format.'''
        return MaxAlignment(self.fmt)

    #
    def initialize(self, host_workspace, offset, arguments):
        '''
        Write nothing; return ``offset`` advanced by ``self.size()``.
        '''
        end = offset + self.size()
        return end
|
|
342
|
+
|
|
343
|
+
#################################################################################################
|
|
344
|
+
|
|
345
|
+
#
|
|
346
|
+
class LinearCombinationFunctorArguments:
    '''
    Arguments read by LinearCombinationFunctor.initialize: ``alpha`` and
    ``beta`` values plus ``alpha_ptr`` and ``beta_ptr``, which start at 0.
    '''
    def __init__(self, alpha = 1.0, beta = 0.0):
        self.alpha, self.beta = alpha, beta
        self.alpha_ptr = self.beta_ptr = 0
|
|
352
|
+
|
|
353
|
+
#
|
|
354
|
+
class LinearCombinationFunctor(Functor):
    '''
    Functor whose declaration names
    ``cutlass::epilogue::thread::LinearCombination`` and whose parameters
    are packed with struct format ``"ffPP"``.
    '''
    def __init__(self):
        super().__init__()

        # NOTE(review): the interior whitespace of this literal is kept
        # exactly as it appears in this view of the file.
        self.decl = """
cutlass::epilogue::thread::LinearCombination<
float,
1,
float,
float
>"""
        self.identifier = 'linear_combination'
        self.fmt = "ffPP"

    #
    def size(self):
        '''
        Size of the packed Params structure
        '''
        packed_size = struct.calcsize(self.fmt)
        return packed_size

    #
    def alignment(self):
        '''Return ``MaxAlignment`` of this functor's struct format.'''
        return MaxAlignment(self.fmt)

    #
    def initialize(self, host_workspace, offset, arguments):
        '''
        Pack ``alpha``, ``beta``, ``alpha_ptr`` and ``beta_ptr`` from
        ``arguments`` into ``host_workspace`` at ``offset`` rounded up to
        ``self.alignment()``.

        Returns the offset immediately after the packed values.
        '''
        start = AlignedOffset(offset, self.alignment())
        fields = (arguments.alpha, arguments.beta, arguments.alpha_ptr, arguments.beta_ptr)

        struct.pack_into(self.fmt, host_workspace, start, *fields)

        return start + self.size()
|
|
390
|
+
|
|
391
|
+
#################################################################################################
|
|
392
|
+
#
|
|
393
|
+
# Base class for an executable operation
|
|
394
|
+
#
|
|
395
|
+
#################################################################################################
|
|
396
|
+
|
|
397
|
+
#
|
|
398
|
+
class ExecutableOperation:
    '''
    Base class for an executable operation: wraps an operation object and
    holds ``module`` and ``kernel`` handles, both ``None`` until set.
    '''
    def __init__(self, operation):
        self.operation = operation
        self.module = self.kernel = None

    #
    def name(self):
        '''Return the wrapped operation's ``procedural_name()``.'''
        return self.operation.procedural_name()

    #
    def emit(self):
        '''Return an empty string in this base class.'''
        return ''

    #
    def can_implement(self, configuration, arguments):
        '''Return False in this base class.'''
        return False

    #
    def get_host_workspace_size(self, arguments):
        '''Return 0 in this base class.'''
        return 0

    #
    def get_device_workspace_size(self, arguments):
        '''Return 0 in this base class.'''
        return 0

    #
    def plan(self, arguments):
        '''Return a default-constructed LaunchConfiguration.'''
        return LaunchConfiguration()

    #
    def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream = cuda.CUstream(0)):
        '''Not implemented in this base class; always raises NotImplementedError.'''
        raise NotImplementedError()

    #
    def run(self, host_workspace, device_workspace, launch_config, stream = cuda.CUstream(0)):
        '''
        Launch ``self.kernel`` through ``cuda.cuLaunchKernel`` and return the
        error value it reports.

        ``host_workspace`` must support the writable buffer protocol
        (``from_buffer`` is used); its address is handed to the launch as a
        one-element array of ``void*``. ``device_workspace`` is not used here.
        '''
        workspace_type = ctypes.c_char * len(host_workspace)
        cArg = workspace_type.from_buffer(host_workspace)

        # One-element void* array holding the address of the host workspace.
        packed = (ctypes.c_void_p * 1)(ctypes.addressof(cArg))

        grid, block = launch_config.grid, launch_config.block

        (err,) = cuda.cuLaunchKernel(
            self.kernel,
            grid[0], grid[1], grid[2],
            block[0], block[1], block[2],
            launch_config.shared_memory_capacity,
            stream,
            packed,
            0)

        return err
|
|
451
|
+
|
|
452
|
+
#################################################################################################
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
#
|
|
456
|
+
class GemmArguments:
    '''
    Argument bundle with a ``problem_size``, four tensor references ``A``,
    ``B``, ``C``, ``D`` and an ``output_op``, all set to default values.
    '''
    def __init__(self):
        self.problem_size = GemmCoord(0, 0, 0)
        # Each operand gets its own default-constructed TensorRef.
        for operand in ("A", "B", "C", "D"):
            setattr(self, operand, TensorRef())
        self.output_op = LinearCombinationFunctorArguments()
|
|
466
|
+
|
|
467
|
+
#
|
|
468
|
+
class ThreadblockSwizzle:
    '''
    Maps a GEMM problem size onto a grid of threadblock tiles.
    '''

    def __init__(self, threadblock_shape, log_threadblock_cohort = 0):
        self.threadblock_shape = threadblock_shape
        self.log_threadblock_cohort = log_threadblock_cohort

    def grid_tiled_shape(self, problem_size):
        '''
        Return a GemmCoord with the number of tiles needed along m and n.

        Each extent is the problem extent divided by the tile extent, rounded
        up by `ceil_div`. The k component is always 1.
        '''
        tile = self.threadblock_shape
        tiles_m = ceil_div(problem_size.m, tile.m)
        tiles_n = ceil_div(problem_size.n, tile.n)
        return GemmCoord(tiles_m, tiles_n, 1)
|
|
478
|
+
|
|
479
|
+
#
|
|
480
|
+
class Gemm(ExecutableOperation):
    '''
    GEMM manages the CUTLASS runtime components
    '''

    def __init__(self, operation):
        super().__init__(operation)

        self.emitter = EmitGemmUniversalInstance('_type')
        self.threadblock_swizzle = ThreadblockSwizzle(GemmCoord(128, 128, 8))

        self.threads = 256
        self.shared_memory_capacity = 32 * 1024

        # NOTE(review): the numeric descriptor values below are fixed; they
        # presumably mirror the kernel configuration the emitter generates.
        # Confirm before changing either side.

        def make_operand_desc():
            # A fresh descriptor per operand so A and B never share an instance.
            return PredicatedTileAccessIteratorDesc(
                32,
                1,
                PitchLinearCoord(128, 8),
                PitchLinearCoord(1, 4),
                PitchLinearCoord(1, 2))

        def make_epilogue_thread_map():
            # A fresh thread map per output tensor so C and D never share one.
            return EpilogueThreadMap(
                256,
                1,
                32,
                EpilogueTileDesc(128, 1, 4, 4, 1),
                EpilogueTileDesc(4, 1, 2, 1, 1),
                EpilogueTileDesc(32, 1, 8, 1, 1),
                EpilogueTileDesc(1, 4, 2, 1, 8))

        self.params_A = PredicatedTileAccessIteratorParams(make_operand_desc(), 'A')
        self.params_B = PredicatedTileAccessIteratorParams(make_operand_desc(), 'B')

        self.params_C = EpilogueTileIteratorParams(make_epilogue_thread_map(), 'C')
        self.params_D = EpilogueTileIteratorParams(make_epilogue_thread_map(), 'D')

        self.output_op = LinearCombinationFunctor()

    def emit(self):
        '''Return the source the emitter generates for the wrapped operation.'''
        return self.emitter.emit(self.operation)

    def can_implement(self, configuration, arguments):
        '''Not implemented for GEMM; returns None.'''
        return None

    def get_host_workspace_size(self, arguments):
        '''
        Return the host workspace size as a fixed constant.

        NOTE(review): presumably the byte size of the packed kernel parameters
        written by initialize(). Confirm.
        '''
        return 336

    def get_device_workspace_size(self, arguments):
        '''Return the device workspace size (none is needed).'''
        return 0

    def plan(self, arguments):
        '''
        Return the LaunchConfiguration for `arguments.problem_size`.

        The grid is the tiled shape from the threadblock swizzle, the block is
        `self.threads` x 1 x 1, and the shared-memory size is
        `self.shared_memory_capacity`.
        '''
        tiles = self.threadblock_swizzle.grid_tiled_shape(arguments.problem_size)
        grid_dims = [tiles.m, tiles.n, tiles.k]
        block_dims = [self.threads, 1, 1]
        return LaunchConfiguration(grid_dims, block_dims, self.shared_memory_capacity)

    def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream = cuda.CUstream(0)):
        '''
        Pack the kernel parameters into `host_workspace`.

        Fields are written back to back starting at offset 0, in the fixed
        order below. Returns the offset just past the last field written.
        `device_workspace`, `launch_config` and `stream` are not used here.
        '''
        # Fixed settings for this launch path.
        swizzle_log_tile = 0
        gemm_mode = 0
        batch_count = 1
        gemm_k_size = arguments.problem_size.k

        # Problem size first, then the tiled grid shape derived from it.
        offset = arguments.problem_size.pack_into(host_workspace, 0)

        grid_tiled_shape = self.threadblock_swizzle.grid_tiled_shape(arguments.problem_size)
        offset = grid_tiled_shape.pack_into(host_workspace, offset)

        offset = PackInteger(host_workspace, offset, swizzle_log_tile)

        # Iterator parameters for each tensor, in A, B, C, D order.
        iterator_params = (
            (self.params_A, arguments.A),
            (self.params_B, arguments.B),
            (self.params_C, arguments.C),
            (self.params_D, arguments.D),
        )
        for params, tensor in iterator_params:
            offset = params.initialize(host_workspace, offset, tensor.layout)

        offset = self.output_op.initialize(host_workspace, offset, arguments.output_op)

        for value in (gemm_mode, batch_count, gemm_k_size):
            offset = PackInteger(host_workspace, offset, value)

        # Device pointers go last, again in A, B, C, D order.
        for tensor in (arguments.A, arguments.B, arguments.C, arguments.D):
            offset = PackDevicePointer(host_workspace, offset, int(tensor.pointer))

        return offset
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
#################################################################################################
|
|
591
|
+
#
|
|
592
|
+
# Module represents a compilation unit
|
|
593
|
+
#
|
|
594
|
+
#################################################################################################
|
|
595
|
+
|
|
596
|
+
#
|
|
597
|
+
class CompilationOptions:
    '''
    Compilation options.

    Collects compiler flags, include paths and target architectures, and
    renders them as the list of byte strings returned by get().
    '''

    def __init__(self, architectures = None, include_paths = None):
        '''
        architectures: iterable of integer SM versions; defaults to [80].
        include_paths: iterable of include directories; defaults to [].

        The defaults are created per instance. A shared list literal in the
        signature would let a mutation on one instance leak into every later
        instance built with the defaults.
        '''
        self.includes = []
        self.include_paths = [] if include_paths is None else include_paths
        self.flags = ['-std=c++11', '-default-device']
        self.architectures = [80] if architectures is None else architectures

    def get(self):
        '''
        Return the options as a list of bytes objects.

        Order: every flag, then one `--include-path=` entry per include path,
        then a single `-arch=` entry listing all architectures comma-separated.
        '''
        options = [flag.encode() for flag in self.flags]

        options.extend(
            ('--include-path=%s' % incl).encode() for incl in self.include_paths)

        arch_list = "-arch=" + ",".join("sm_%d" % arch for arch in self.architectures)
        options.append(arch_list.encode())

        return options
|
|
628
|
+
|
|
629
|
+
# Template for one include directive in the emitted source. Module.emit_()
# substitutes `${include}` with each distinct header name.
IncludeTemplate = r'''#include "${include}"
'''

# Template for the extern "C" kernel entry point emitted once per operation.
# Module.emit_() substitutes `${operation_name}` and `${operation_suffix}`.
# The kernel reinterprets the dynamic shared memory as the operation's
# SharedStorage and invokes the operation with the given Params.
KernelTemplate = r'''
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {

  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];

  // Declare pointer to dynamic shared memory.
  ${operation_name}${operation_suffix}::SharedStorage *shared_storage =
      reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);

  ${operation_name}${operation_suffix} op;

  op(params, *shared_storage);
}

'''
|
|
650
|
+
|
|
651
|
+
#
|
|
652
|
+
class Module:
    '''
    A compilation unit built from a list of operations.

    Construction runs three steps in order: emit the source for every
    operation, compile it with NVRTC, then load the resulting image and bind
    each operation to its kernel.
    '''

    def __init__(self, name, operations, compilation_options):
        '''
        name: program name handed to NVRTC.
        operations: operations to emit; each must provide `emitter`, `emit()`
            and `name()`.
        compilation_options: object whose `get()` returns the NVRTC options.

        Raises RuntimeError if compilation or loading fails.
        '''
        self.name = name
        self.operations = operations
        self.module = None
        self.log = None
        self.cubin_image = None
        self.source_buffer = ''

        self.emit_()
        self.compile_(compilation_options)
        self.load_()

    # Emit a source buffer
    def emit_(self):
        '''Append the include directives and per-operation source to `source_buffer`.'''

        # 1. Includes: every distinct header, in first-seen order.
        includes = []
        for operation in self.operations:
            for incl in operation.emitter.includes:
                if incl not in includes:
                    includes.append(incl)

        parts = [
            SubstituteTemplate(IncludeTemplate, {'include': incl})
            for incl in includes
        ]

        # 2. Operations: the operation's own source, then its kernel entry point.
        for operation in self.operations:
            parts.append(operation.emit())
            values = {
                'operation_name': operation.name(),
                'operation_suffix': operation.emitter.operation_suffix
            }
            parts.append(SubstituteTemplate(KernelTemplate, values))

        # Join once instead of growing the string piece by piece.
        self.source_buffer += ''.join(parts)

    # Compile with NVRTC
    def compile_(self, compilation_options):
        '''
        Compile `source_buffer` with NVRTC and store the image in `cubin_image`.

        On a compile failure the NVRTC log is stored in `self.log`, and the
        RuntimeError message carries the status, the log and the full source.
        The NVRTC program handle is destroyed on every path.
        '''
        err, program = nvrtc.nvrtcCreateProgram(
            str.encode(self.source_buffer),
            bytes(str.encode(self.name)),
            0, [], [])

        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('NVRTC Error: {}'.format(err))

        try:
            # Compile program
            options = compilation_options.get()

            err, = nvrtc.nvrtcCompileProgram(program, len(options), options)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:

                error_string = 'NVRTC Error: {}\n'.format(err)

                # Get log from compilation
                err, logSize = nvrtc.nvrtcGetProgramLogSize(program)
                if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                    raise RuntimeError('NVRTC Error: {}'.format(err))

                # Pre-sized buffer that NVRTC fills in place.
                self.log = b' ' * logSize
                err, = nvrtc.nvrtcGetProgramLog(program, self.log)
                if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                    raise RuntimeError('NVRTC Error: {}'.format(err))

                raise RuntimeError(error_string + self.log.decode() + self.source_buffer)

            # Get data from compilation
            err, dataSize = nvrtc.nvrtcGetCUBINSize(program)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))

            # Pre-sized buffer that NVRTC fills in place.
            self.cubin_image = b' ' * dataSize
            err, = nvrtc.nvrtcGetCUBIN(program, self.cubin_image)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))
        finally:
            # Release the program handle whether or not compilation succeeded.
            # The status is deliberately ignored so cleanup never masks the
            # error being propagated.
            nvrtc.nvrtcDestroyProgram(program)

    def load_(self):
        '''
        Load `cubin_image` as a module and bind every operation to its kernel.

        Sets `operation.kernel` and `operation.module` on each operation.
        Raises RuntimeError if the load or a function lookup fails.
        '''

        # Load data as module data
        err, self.module = cuda.cuModuleLoadData(self.cubin_image)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError('Cuda Error: {}'.format(err))

        # Get functions
        for operation in self.operations:
            err, operation.kernel = cuda.cuModuleGetFunction(
                self.module,
                bytes(str.encode(operation.name())))

            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))

            operation.module = self
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
#################################################################################################
|
|
770
|
+
#
|
|
771
|
+
# Manifest represents an 'owner' for modules and operations
|
|
772
|
+
#
|
|
773
|
+
#################################################################################################
|
|
774
|
+
|
|
775
|
+
#
|
|
776
|
+
class Manifest:
    '''
    Registry of appended modules and of their operations keyed by name.
    '''

    def __init__(self):
        self.operations = {}
        self.modules = []

    def append_module(self, module):
        '''
        Appends a module and takes ownership of operations used to construct it.

        Each operation is stored under `operation.name()`. An operation that
        reuses an existing name replaces the earlier entry.
        '''
        self.modules.append(module)

        for op in module.operations:
            key = op.name()
            self.operations[key] = op
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
#################################################################################################
|