warp-lang 0.10.1-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
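A comparison like the per-file summary below can also be reproduced locally from the two published wheels using only the Python standard library. The following sketch is illustrative only (the wheel paths and the chosen member are assumptions), not part of the registry tooling:

import difflib
import zipfile

def diff_wheel_member(old_whl, new_whl, member):
    # Return a unified diff for one text file that may appear in either wheel.
    def read_lines(path):
        with zipfile.ZipFile(path) as wheel:
            try:
                return wheel.read(member).decode("utf-8", "replace").splitlines(keepends=True)
            except KeyError:
                # the member exists in only one of the two wheels (added or removed file)
                return []
    return "".join(difflib.unified_diff(
        read_lines(old_whl), read_lines(new_whl),
        fromfile=old_whl + ":" + member, tofile=new_whl + ":" + member))

# Example (wheel filenames are assumptions about what was downloaded):
# print(diff_wheel_member("warp_lang-0.10.1-py3-none-win_amd64.whl",
#                         "warp_lang-0.11.0-py3-none-win_amd64.whl",
#                         "warp/config.py"))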
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
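The single hunk reproduced below adds one new 619-line file. Judging by the added-line count in the summary above, it corresponds to warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py, the epilogue visitor-tree parser that ships with the vendored CUTLASS scripts.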
@@ -0,0 +1,619 @@
+################################################################################
+#
+# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+################################################################################
+
+from typing import Generic, TypeVar
+from treelib import Tree
+import numpy as np
+
+from pycutlass import *
+import pycutlass
+
+import ast
+import textwrap
+import inspect
+
+################################################################################
+# Type annotation for input arguments
+################################################################################
+
+Ttype = TypeVar("Ttype")
+Dtype = TypeVar("Dtype")
+
+class NDArray(np.ndarray, Generic[Ttype, Dtype]):
+    pass
+
+################################################################################
+# Operations
+################################################################################
+
+operators = {
+    ast.Add: "Add",
+    ast.Div: "Div",
+    ast.Eq: "Equal",
+    ast.Mult: "Mult"
+}
+
+################################################################################
+# AST Node abstractions
+################################################################################
+class UnaryNode:
+    cnt = 0
+    # Concept: this is created by the BinOp Node in python ast
+    def __init__(self,
+        element_accumulator, element_compute, elements_per_access,
+        node, args) -> None:
+        if isinstance(node, BinOpNode):
+            self.op = node.op
+        elif isinstance(node, ast.Call):
+            if isinstance(node.func, ast.Name):
+                self.op = node.func.id
+            elif isinstance(node.func, ast.Attribute):
+                self.op = node.func.value.id
+            else:
+                raise TypeError
+        else:
+            raise TypeError
+        self.tag = "Unary" + self.op + str(UnaryNode.cnt)
+        self.id = self.op + str(UnaryNode.cnt)
+        self.args = args
+        UnaryNode.cnt += 1
+
+        self.type = "tensor"
+
+        self.epilogue_op = getattr(pycutlass, self.op)(element_compute)
+
+        # data types
+        self.element_accumulator = element_accumulator
+        self.element_compute = element_compute
+        self.elements_per_access = elements_per_access
+
+    def get_epilogue_node(self, visitors):
+        self.epilogue_node = UnaryOp(
+            self.element_accumulator, self.element_compute,
+            self.elements_per_access, *visitors, self.epilogue_op)
+
+    def get_argument(self, visitor_args, kwargs):
+        epilogue_ops = []
+        for arg in self.args:
+            try:
+                epilogue_ops.append(kwargs[arg])
+            except:
+                epilogue_ops.append(arg) # direct arguments like constant
+        self.argument = self.epilogue_node.argument_type(self.epilogue_op.argument_type(*epilogue_ops), *visitor_args)
+
+
+class BinOpNode:
+    cnt = 0
+    # Concept: this is created by the BinOp Node in python ast
+    def __init__(self,
+        element_accumulator, element_compute, elements_per_access,
+        node) -> None:
+        self.op = operators[type(node.op)]
+        self.tag = "Binary" + self.op + str(BinOpNode.cnt)
+        self.id = self.op + str(BinOpNode.cnt)
+        self.args = None
+        BinOpNode.cnt += 1
+
+        self.type = "tensor"
+
+        self.epilogue_op = getattr(pycutlass, "Vector"+self.op)(element_compute)
+
+        # data types
+        self.element_accumulator = element_accumulator
+        self.element_compute = element_compute
+        self.elements_per_access = elements_per_access
+
+    def get_epilogue_node(self, visitors):
+        self.epilogue_node = BinaryOp(
+            self.element_accumulator, self.element_compute,
+            self.elements_per_access, *visitors, self.epilogue_op)
+
+    def get_argument(self, visitor_args, kwargs):
+        self.argument = self.epilogue_node.argument_type(self.epilogue_op.argument_type(self.args), *visitor_args)
+
+
+class NameNode:
+    # Concept: this is created by the Name Node in python ast
+    def __init__(self, node) -> None:
+        try:
+            self.id = node.id
+        except:
+            self.id = node.targets[0].id
+        self.tag = self.id
+
+class ScalarInputNode(NameNode):
+    # Concept: scalar
+    def __init__(self, node) -> None:
+        super().__init__(node)
+        self.tag = "Scalar:" + self.tag
+        self.type = "scalar"
+
+class AccumulatorNode(NameNode):
+    # Concept: VisitorOpAccumulator
+    def __init__(self,
+        element_accumulator, elements_per_access, node) -> None:
+        super().__init__(node)
+        self.tag = "Accum:" + self.tag
+        self.type = "tensor"
+
+        self.element_accumulator = element_accumulator
+        self.elements_per_access = elements_per_access
+
+    def get_epilogue_node(self, visitors):
+        self.epilogue_node = AccumulatorOp(
+            self.element_accumulator, self.elements_per_access)
+
+    def get_argument(self, visitor_args, kwargs):
+        self.argument = self.epilogue_node.argument_type()
+
+class TensorInputNode(NameNode):
+    # Concept: VisitorOpTensorInput
+    def __init__(self, element_accumulator, node) -> None:
+        super().__init__(node)
+        self.tag = "TensorInput:" + self.tag
+        self.type = "tensor"
+        self.element_accumulator = element_accumulator
+
+    def get_epilogue_node(self, *args):
+        self.epilogue_node = TensorInputOp(self.element_accumulator)
+
+    def get_argument(self, visitor_args, kwargs):
+        self.argument = self.epilogue_node.argument_type(
+            kwargs[self.id + "_ptr"], kwargs["problem_size"][1],
+            kwargs["problem_size"][0] * kwargs["problem_size"][1])
+
+class RowBroadcastNode(NameNode):
+    # Concept: VisitorOpRowBroadcast
+    def __init__(self, element_accumulator, element_fragment, node) -> None:
+        super().__init__(node)
+        #
+        self.tag = "RowBroadcast:" + self.tag
+        self.type = "tensor"
+        self.element_accumulator = element_accumulator
+        self.element_fragment = element_fragment
+
+    def get_epilogue_node(self, *args):
+        self.epilogue_node = RowBroadcastOp(
+            self.element_accumulator, self.element_fragment)
+
+    def get_argument(self, visitor_args, kwargs):
+        self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][1])
+
+class ColumnBroadcastNode(NameNode):
+    # Concept: VisitorOpColumnBroadcast
+    def __init__(self, element_accumulator, element_fragment, node) -> None:
+        super().__init__(node)
+        self.tag = "ColumnBroadcast:" + self.tag
+        self.type = "tensor"
+        self.element_accumulator = element_accumulator
+        self.element_fragment = element_fragment
+
+    def get_epilogue_node(self, *args):
+        self.epilogue_node = ColumnBroadcastOp(
+            self.element_accumulator, self.element_fragment)
+
+    def get_argument(self, visitor_args, kwargs):
+        self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][0])
+
+class TensorOutputNode(NameNode):
+    # Concept: VisitorOpTensorOutput
+    def __init__(self, element_accumulator, node) -> None:
+        super().__init__(node)
+        self.tag = "TensorOutput:" + self.tag
+        self.type = "tensor"
+        self.element_accumulator = element_accumulator
+
+    def get_epilogue_node(self, visitors):
+        self.epilogue_node = TensorOutputOp(self.element_accumulator, *visitors)
+
+    def get_argument(self, visitor_args, kwargs):
+        self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][1], *visitor_args, kwargs["problem_size"][0] * kwargs["problem_size"][1])
+
+class RowReductionNode:
+    # Concept: RowReductionOp
+    def __init__(self, element_accumulator, element_reduction,
+        element_reduction_accumulator, id, factor) -> None:
+        #
+        self.id = id
+        self.tag = "RowReduction:" + self.id
+        self.type = "tensor"
+        self.element_accumulator = element_accumulator
+        self.element_reduction = element_reduction
+        self.element_reduction_accumulator = element_reduction_accumulator
+        self.factor = factor
+
+    def get_epilogue_node(self, visitors):
+        self.epilogue_node = RowReductionOp(
+            self.element_accumulator, self.element_reduction,
+            self.element_reduction_accumulator, *visitors)
+
+    def get_batch_stride(self, problem_size):
+        return problem_size[0] * ((problem_size[1] + self.factor - 1) // self.factor)
+
+    def get_argument(self, visitor_args, kwargs):
+        self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], *visitor_args, self.get_batch_stride(kwargs["problem_size"]))
+
+class ColumnReductionNode:
+    # Concept: ColumnReductionOp
+    def __init__(self, element_accumulator, element_reduction,
+        element_reduction_accumulator, id, factor) -> None:
+        #
+        self.id = id
+        self.tag = "ColumnReduction:" + self.id
+        self.type = "tensor"
+        self.element_accumulator = element_accumulator
+        self.element_reduction = element_reduction
+        self.element_reduction_accumulator = element_reduction_accumulator
+        self.factor = factor
+
+    def get_epilogue_node(self, visitors):
+        self.epilogue_node = ColumnReductionOp(
+            self.element_accumulator, self.element_reduction,
+            self.element_reduction_accumulator, *visitors)
+
+    def get_batch_stride(self, problem_size):
+        return problem_size[1] * ((problem_size[0] + self.factor - 1) // self.factor)
+
+    def get_argument(self, visitor_args, kwargs):
+        self.argument = self.epilogue_node.argument_type(kwargs[self.id + '_ptr'], *visitor_args, self.get_batch_stride(kwargs["problem_size"]))
+
+################################################################################
+# Epilogue parser function
+################################################################################
+class EpilogueAST(ast.NodeVisitor):
+    def __init__(self, epilogue,
+        tile_description,
+        element_accumulator, elements_per_access,
+        element_compute, element_output) -> None:
+        #
+
+        self.tile_description = tile_description
+        self.element_accumulator = element_accumulator
+        self.elements_per_access = elements_per_access
+        self.element_compute = element_compute
+        self.element_output = element_output
+        self.epilogue = epilogue
+
+        self.source = textwrap.dedent(inspect.getsource(epilogue.__call__))
+        self.ast_tree = ast.parse(self.source)
+        self.epilogue_tree = Tree()
+
+
+        # print(ast.dump(self.ast_tree, indent=4)) # For Debug purpose
+
+        # input arguments
+        self.input_args = {}
+        # return nodes
+        self.returns = []
+        # reduction source nodes
+        self.reduction_source = {}
+
+        # stack used to keep the parent node id
+        self.stack = []
+
+        # visit the AST
+        self.visit(self.ast_tree)
+
+    # visit the name node
+    def visit_Name(self, node):
+        # append the return ids into self.returns
+        if self.stack[-1] == "return":
+            self.returns.append(node.id)
+        else:
+            # accum is produced from accumulator node
+            if node.id == "accum":
+                name_node = AccumulatorNode(
+                    self.element_accumulator, self.elements_per_access, node)
+            else:
+                # for input nodes
+                if node.id in self.input_args.keys():
+                    type = self.input_args[node.id][0]
+                    if type == "tensor":
+                        name_node = TensorInputNode(self.element_accumulator, node)
+                    elif type == "row":
+                        name_node = RowBroadcastNode(self.element_accumulator, self.element_compute, node)
+                    elif type == "column":
+                        name_node = ColumnBroadcastNode(self.element_accumulator, self.element_compute, node)
+                    elif type == "scalar":
+                        name_node = ScalarInputNode(node)
+                    else:
+                        raise ValueError(type)
+                # for output nodes
+                else:
+                    name_node = TensorOutputNode(self.element_accumulator, node)
+            self.epilogue_tree.create_node(name_node.tag, name_node.id, data=name_node, parent=self.stack[-1])
+
+    def visit_Assign(self, node):
+        pre_assign_node = self.epilogue_tree.get_node(node.targets[0].id)
+        if pre_assign_node is None:
+            # The assign is to a root node
+            # skip the reduction nodes
+            if isinstance(node.value, ast.Call):
+                if isinstance(node.value.func, ast.Name):
+                    func_type = node.value.func.id
+                elif isinstance(node.value.func, ast.Attribute):
+                    func_type = node.value.func.value.id
+                else:
+                    raise TypeError
+                if func_type == 'reduction_op':
+                    self.reduction_source[node.value.args[0].id] = [node.value.args[1].value, node.value.args[2].value, node.targets[0].id]
+                    return
+            name_node = TensorOutputNode(self.element_accumulator, node)
+            self.epilogue_tree.create_node(name_node.tag, name_node.id, data=name_node)
+            self.stack.append(name_node.id)
+        else:
+            if node.targets[0].id in self.returns or node.targets[0].id in self.reduction_source.keys():
+                self.stack.append(node.targets[0].id)
+            else:
+                self.stack.append(pre_assign_node.predecessor(self.epilogue_tree.identifier))
+                self.epilogue_tree.remove_node(node.targets[0].id)
+
+        # get child tag
+        self.visit(node.value)
+        self.stack.pop()
+
+    def visit_Call(self, node):
+        if isinstance(node.func, ast.Name):
+            func_type = node.func.id
+        elif isinstance(node.func, ast.Attribute):
+            func_type = node.func.value.id
+        else:
+            raise TypeError
+        if func_type == "reduction_op":
+            self.visit(node.args[0])
+        else:
+            arg_list = []
+            for idx, arg in enumerate(node.args):
+                if idx == 0: continue
+                if isinstance(arg, ast.Constant):
+                    arg_list.append(arg.value)
+                elif isinstance(arg, ast.Name):
+                    arg_list.append(arg.id)
+                else:
+                    raise TypeError
+
+            unary_node = UnaryNode(self.element_accumulator, self.element_compute, self.elements_per_access, node, arg_list)
+            self.epilogue_tree.create_node(unary_node.tag, unary_node.id, parent=self.stack[-1], data=unary_node)
+            self.stack.append(unary_node.id)
+            self.visit(node.args[0])
+            self.stack.pop()
+
+    def visit_BinOp(self, node):
+        binop = BinOpNode(self.element_accumulator, self.element_compute,
+            self.elements_per_access, node)
+        self.epilogue_tree.create_node(binop.tag, binop.id, data=binop, parent=self.stack[-1])
+        self.stack.append(binop.id)
+        self.visit(node.left)
+        self.visit(node.right)
+        self.stack.pop()
+
+    def visit_Return(self, node):
+        self.stack.append("return")
+        self.visit(node.value)
+        self.stack.pop()
+
+    # # A function definition
+    def visit_FunctionDef(self, node: ast.FunctionDef):
+        # visit args
+        for arg in node.args.args:
+            if arg.arg == "self": continue
+            if isinstance(arg.annotation, ast.Constant):
+                self.input_args[arg.arg] = [arg.annotation.value, ]
+        # visit the assign in the reverse order
+        for idx in range(len(node.body)):
+            self.visit(node.body[-1-idx])
+
+    #
+    # Tree optimization pass
+    #
+
+    # pass 1: lower Binary to Unary
+    def pass_binary_2_unary(self, tree, nid):
+        node = tree.get_node(nid)
+        if isinstance(node.data, BinOpNode):
+            lhs_node = tree.get_node(node.successors(tree.identifier)[0])
+            left_type = lhs_node.data.type
+            rhs_node = tree.get_node(node.successors(tree.identifier)[1])
+            right_type = rhs_node.data.type
+
+            if left_type == "scalar" and right_type == "tensor":
+                node.data = UnaryNode(
+                    self.element_accumulator, self.element_compute,
+                    self.elements_per_access,
+                    node.data, [lhs_node.data.id,])
+                node.tag = node.data.tag
+                tree.remove_node(lhs_node.data.id)
+                self.pass_binary_2_unary(tree, rhs_node.data.id)
+
+            elif left_type == "tensor" and right_type == "scalar":
+                node.data = UnaryNode(
+                    self.element_accumulator, self.element_compute,
+                    self.elements_per_access,
+                    node.data, [rhs_node.id,])
+                node.tag = node.data.tag
+                tree.remove_node(rhs_node.data.id)
+                self.pass_binary_2_unary(tree, lhs_node.data.id)
+
+            else:
+                self.pass_binary_2_unary(tree, lhs_node.data.id)
+                self.pass_binary_2_unary(tree, rhs_node.data.id)
+        else:
+            for child in node.successors(tree.identifier):
+                self.pass_binary_2_unary(tree, child)
+
+    # pass 2: inject reduction nodes
+    def pass_inject_reduction(self, tree, nid):
+        node = tree.get_node(nid)
+        if isinstance(node.data, TensorOutputNode):
+            if node.data.id in self.reduction_source.keys():
+                direction = self.reduction_source[node.data.id][0]
+                target = self.reduction_source[node.data.id][-1]
+                if direction == 'row':
+                    reduction_node = RowReductionNode(
+                        self.element_accumulator, self.element_output,
+                        self.element_accumulator, target, self.tile_description.threadblock_shape[1])
+                elif direction == "column":
+                    reduction_node = ColumnReductionNode(
+                        self.element_accumulator, self.element_output,
+                        self.element_accumulator, target, self.tile_description.threadblock_shape[0])
+                else:
+                    raise ValueError(direction)
+                child_nid = node.successors(tree.identifier)[0]
+                # if this output node is injected only for reduction
+                if node.data.id not in self.returns:
+                    # get reduction config from disc
+                    node.data = reduction_node
+                    node.tag = reduction_node.tag
+                    self.pass_inject_reduction(tree, child_nid)
+                # if this output node is also a tensor output, inject reduction as its children
+                else:
+                    # get child node
+                    tree.create_node(reduction_node.tag, reduction_node.id, data=reduction_node, parent=node.data.id)
+                    tree.move_node(child_nid, reduction_node.id)
+                    child = tree.get_node(child_nid)
+                    for grand_child in child.successors(tree.identifier):
+                        self.pass_inject_reduction(tree, grand_child)
+            else:
+                for child in node.successors(tree.identifier):
+                    self.pass_inject_reduction(tree, child)
+        else:
+            for child in node.successors(tree.identifier):
+                self.pass_inject_reduction(tree, child)
+
+    def pass_inject_epilogue_op(self, tree, nid):
+        node = tree.get_node(nid)
+        visitors = []
+        for child in node.successors(tree.identifier):
+            visitors.append(self.pass_inject_epilogue_op(tree, child))
+
+        node.data.get_epilogue_node(visitors)
+        return node.data.epilogue_node
+
+    def get_arguments(self, tree, nid, kwargs):
+        node = tree.get_node(nid)
+        visitor_args = []
+        for child in node.successors(tree.identifier):
+            visitor_args.append(self.get_arguments(tree, child, kwargs))
+
+        node.data.get_argument(visitor_args, kwargs)
+        return node.data.argument
+
+class EpilogueVisitTree:
+    KernelTemplate = """
+${visitor}
+
+using ${operation_name}_EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorGeneric<${visitor_name}>;
+"""
+    def __init__(self, elementwise_functor, tile_description,
+        element_accumulator, elements_per_access,
+        element_compute, element_output) -> None:
+        #
+        # data types
+        self.tile_description = tile_description
+        self.element_accumulator = element_accumulator
+        self.elements_per_access = elements_per_access
+        self.element_compute = element_compute
+        self.element_output = element_output
+        # TODO: deprecate this
+        self.elementwise_functor = elementwise_functor
+        pass
+
+    def initialize(self):
+        function = EpilogueAST(self, self.tile_description,
+            self.element_accumulator, self.elements_per_access,
+            self.element_compute, self.element_output)
+        #
+        tree = function.epilogue_tree
+        self.tree = tree
+        # self.tree.show() # for debug
+        function.pass_binary_2_unary(self.tree, self.tree.root)
+        # self.tree.show() # for debug
+        function.pass_inject_reduction(self.tree, self.tree.root)
+        # self.tree.show() # for debug
+        function.pass_inject_epilogue_op(self.tree,self.tree.root)
+
+        visitor = self.tree.get_node(self.tree.root).data.epilogue_node
+        self.visitor = visitor
+
+        class _Argument(ctypes.Structure):
+            _fields_ = [
+                ("visitor_arg", visitor.argument_type)
+            ]
+            def __init__(self, **kwargs) -> None:
+                # process input args
+                _kwargs = {}
+                for input_key in function.input_args.keys():
+                    if input_key == "accum":
+                        continue
+                    if function.input_args[input_key][0] == "scalar":
+                        # _kwargs[input_key] = kwargs[input_key]
+                        continue
+                    # tensor input
+                    else:
+                        setattr(self, "buffer_tensor_" + input_key, NumpyFrontend.argument(kwargs[input_key], False))
+                        setattr(self, input_key + "_ptr", int(getattr(self, "buffer_tensor_" + input_key).ptr))
+                        _kwargs[input_key+"_ptr"] = getattr(self, input_key + "_ptr")
+                # process the return args
+                for ret in function.returns:
+                    setattr(self, "buffer_tensor_" + ret, NumpyFrontend.argument(kwargs[ret], True))
+                    setattr(self, ret + "_ptr", int(getattr(self, "buffer_tensor_" + ret).ptr))
+                    _kwargs[ret+"_ptr"] = getattr(self, ret + "_ptr")
+                    setattr(self, "host_tensor_" + ret, kwargs[ret])
+
+                _kwargs.update(kwargs)
+                function.get_arguments(tree, tree.root, _kwargs)
+                self.visitor_arg = tree.get_node(tree.root).data.argument
+
+            def sync(self, stream_sync=True):
+                if stream_sync:
+                    err, = cudart.cudaDeviceSynchronize()
+                    if err != cuda.CUresult.CUDA_SUCCESS:
+                        raise RuntimeError("CUDA Error %s" % str(err))
+
+                for ret in function.returns:
+                    err, = cuda.cuMemcpyDtoH(
+                        getattr(self, "host_tensor_" + ret), cuda.CUdeviceptr(getattr(self, ret + "_ptr")),
+                        getattr(self, "host_tensor_" + ret).size * getattr(self, "host_tensor_" + ret).itemsize
+                    )
+                    if err != cuda.CUresult.CUDA_SUCCESS:
+                        raise RuntimeError("CUDA Error %s" % str(err))
+                pass
+
+        self.epilogue_type = _Argument
+
+    def emit(self, operation):
+        values = {
+            'visitor': self.visitor.emit(operation),
+            'operation_name': operation.procedural_name(),
+            'visitor_name': self.visitor.instance_name
+        }
+        return SubstituteTemplate(self.KernelTemplate, values)
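For context, EpilogueAST above only reads the source of an epilogue's __call__ (inspect.getsource followed by ast.parse); the function is never executed. A minimal sketch of the kind of epilogue definition the parser accepts is shown below; the class, variable, and argument names are hypothetical, and the classes above are assumed to be importable (e.g. via from pycutlass import *):

class ScaledBiasEpilogue(EpilogueVisitTree):
    def __call__(self, accum: 'tensor', c: 'tensor', alpha: 'scalar', beta: 'scalar'):
        # string annotations classify inputs ('tensor', 'row', 'column', 'scalar');
        # 'accum' is mapped to AccumulatorNode, and returned names become tensor outputs
        d = alpha * accum + beta * c
        # visit_Assign only records the first three call arguments: the source tensor,
        # the reduction direction ('row' or 'column'), and the element-wise reduction op
        d_row_sum = reduction_op(d, 'row', 'Add', 128)
        return d, d_row_sum

# epilogue = ScaledBiasEpilogue(elementwise_functor, tile_description,
#                               element_accumulator, elements_per_access,
#                               element_compute, element_output)
# epilogue.initialize()   # parses __call__, runs the tree passes, builds epilogue_type

Scalar-by-tensor products such as alpha * accum are lowered by pass_binary_2_unary into unary visitor nodes, and the reduction_op assignment drives pass_inject_reduction; the fourth argument of reduction_op is ignored by this parser.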