warp-lang 0.10.1-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/cutlass/tools/library/scripts/pycutlass/setup.py (new file)
@@ -0,0 +1,80 @@
+import distutils.cmd
+from setuptools import setup
+import setuptools.command.build_py
+import os
+
+# build rmm dependency
+class BuildRMM(distutils.cmd.Command):
+    user_options = []
+    def initialize_options(self):
+        pass
+    def finalize_options(self):
+        pass
+    def run(self):
+        try:
+            import rmm
+        except ImportError:
+            print("installing rmm")
+            os.system("git clone -b branch-22.08 --recurse-submodules https://github.com/rapidsai/rmm.git")
+            os.chdir("./rmm")
+            os.system("./build.sh librmm rmm")
+            os.chdir("./python")
+            os.system("python setup.py build_ext --inplace")
+            os.system("python setup.py install")
+
+cutlass_path = os.getenv('CUTLASS_PATH')
+assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined."
+cuda_install_path = os.getenv('CUDA_INSTALL_PATH')
+assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined."
+
+ext_modules = []
+
+try:
+    from pybind11.setup_helpers import Pybind11Extension, build_ext
+    include_dirs = [
+        cutlass_path + "/include",
+        cuda_install_path + "/include",
+        cutlass_path + "/tools/util/include",
+        cutlass_path + "/test",
+        cutlass_path + "/tools/library/scripts/pycutlass/googletest/googletest/include"
+    ]
+
+    ext_modules = [
+        Pybind11Extension("cutlass",
+            ["src/cpp/cutlass.cpp"],
+            include_dirs=include_dirs,
+            extra_compile_args=["-fpermissive", "-w"])
+    ]
+except ImportError:
+    pass
+
+setup(
+    name="PyCutlass",
+    version="0.0.1",
+    author="Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall",
+    author_email="zhaodongc@nvidia.com",
+    description="Python interface for CUTLASS",
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+    ],
+    package_dir={"": "src"},
+    packages=['pycutlass', 'pycutlass.utils', 'pycutlass.test'],
+    setup_requires=["pybind11", "numpy<1.23"],
+    install_requires=[
+        "numpy<1.23",
+        'pybind11',
+        'cuda-python<11.7.0',
+        'typeguard',
+        'bfloat16',
+        'typing',
+        'scikit-build',
+        'treelib'
+    ],
+    cmdclass={
+        'rmm': BuildRMM
+    },
+    ext_modules=ext_modules,
+    python_requires=">=3.6",
+)
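The vendored PyCutlass setup script above wires a custom distutils command (`rmm`) into `setup(cmdclass=...)` and only builds the pybind11 `cutlass` extension when pybind11 can be imported. A minimal, hypothetical sketch of that cmdclass pattern follows; all names and metadata here are made up for illustration and are not part of the package.

# Hypothetical, self-contained sketch of the cmdclass pattern used by BuildRMM above.
# Save as setup.py and run "python setup.py hello" (analogous to "python setup.py rmm").
import distutils.cmd
from setuptools import setup


class HelloCommand(distutils.cmd.Command):
    """Example custom command; BuildRMM clones and builds RMM at this point instead."""

    user_options = []  # no command-line options, just like BuildRMM

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self):
        print("custom command body runs here")


setup(
    name="example-pkg",  # placeholder metadata
    version="0.0.0",
    cmdclass={"hello": HelloCommand},
)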
warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py (new file)
@@ -0,0 +1,48 @@
+import re
+
+
+def SubstituteTemplate(template, values):
+    text = template
+    changed = True
+    while changed:
+        changed = False
+        for key, value in values.items():
+            regex = "\\$\\{%s\\}" % key
+            newtext = re.sub(regex, value, text)
+            if newtext != text:
+                changed = True
+            text = newtext
+    return text
+
+from pycutlass.type_hint import *
+from pycutlass.tensor_ref import *
+from pycutlass.operation import *
+from pycutlass.epilogue import *
+from pycutlass.parser import *
+from pycutlass.compiler import ArtifactManager
+from pycutlass.memory_manager import *
+from pycutlass.arguments import *
+from pycutlass.library import *
+from pycutlass.c_types import *
+from pycutlass.gemm_operation import *
+from pycutlass.conv2d_operation import *
+from pycutlass.compiler import *
+from pycutlass.utils import *
+from pycutlass.frontend import *
+from pycutlass.reduction_operation import *
+from pycutlass.compiler import *
+
+# module-wide variables
+
+import sys
+this = sys.modules[__name__]
+
+# artifact manager
+this.compiler = ArtifactManager()
+
+def get_memory_pool(init_pool_size=0, max_pool_size=2**34):
+    this.memory_pool = PoolMemoryManager(
+        init_pool_size=init_pool_size,
+        max_pool_size=max_pool_size
+    )
+    return this.memory_pool
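The module above also ships the small `SubstituteTemplate` helper used when emitting kernel source. A self-contained sketch of its behaviour follows; the helper is re-declared so the snippet runs without pycutlass installed, and the template string is made up for illustration.

# Re-declaration of SubstituteTemplate from the diff above, for a standalone demo.
import re

def SubstituteTemplate(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text

# ${...} placeholders are substituted repeatedly until the text stops changing,
# so a value may itself introduce further placeholders.
template = 'extern "C" void ${name}_${arch}();'  # illustrative template only
print(SubstituteTemplate(template, {"name": "gemm_f16", "arch": "sm80"}))
# prints: extern "C" void gemm_f16_sm80();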
warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py (new file)
@@ -0,0 +1,118 @@
+#################################################################################################
+#
+# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################################
+from .frontend import CupyFrontend
+from typeguard import typechecked
+from pycutlass.frontend import *
+from typing import Union
+import numpy as np
+from cuda import cuda
+try:
+    import torch
+    torch_available = True
+except ImportError:
+    torch_available = False
+from cuda import cudart
+try:
+    import cupy as cp
+    cupy_available = True
+except ImportError:
+    cupy_available = False
+
+
+# @typechecked
+class ArgumentBase:
+    """
+    Base class for operation arguments
+    """
+
+    def __init__(self,
+        A: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
+        B: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
+        C: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
+        D: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
+        **kwargs) -> None:
+
+        # tensor_C can be interpreted as the bias with bias=True in keyword args
+        if "bias" in kwargs.keys():
+            self.bias = kwargs["bias"]
+        else:
+            # by default, tensor_C is not bias
+            self.bias = False
+
+        # preprocessing input tensors
+        if isinstance(A, np.ndarray):
+            self.host_D = D
+            self.buffer_A = NumpyFrontend.argument(A, False)
+            self.buffer_B = NumpyFrontend.argument(B, False)
+            self.buffer_C = NumpyFrontend.argument(C, False)
+            self.buffer_D = NumpyFrontend.argument(D, True)
+            self.ptr_A = self.buffer_A.ptr
+            self.ptr_B = self.buffer_B.ptr
+            self.ptr_C = self.buffer_C.ptr
+            self.ptr_D = self.buffer_D.ptr
+            # number of elements in C
+            self.tensor_c_numel = C.size
+        elif torch_available and isinstance(A, torch.Tensor):
+            self.ptr_A = TorchFrontend.argument(A)
+            self.ptr_B = TorchFrontend.argument(B)
+            self.ptr_C = TorchFrontend.argument(C)
+            self.ptr_D = TorchFrontend.argument(D)
+            # number of elements in C
+            self.tensor_c_numel = C.numel()
+        elif isinstance(A, cuda.CUdeviceptr):
+            self.ptr_A = A
+            self.ptr_B = B
+            self.ptr_C = C
+            self.ptr_D = D
+
+        elif cupy_available and isinstance(A, cp.ndarray):
+            self.ptr_A = CupyFrontend.argument(A)
+            self.ptr_B = CupyFrontend.argument(B)
+            self.ptr_C = CupyFrontend.argument(C)
+            self.ptr_D = CupyFrontend.argument(D)
+            # number of elements in C
+            self.tensor_c_numel = C.size
+        else:
+            raise TypeError(
+                "Unsupported Frontend. Only support numpy and torch")
+
+    def sync(self, stream_sync=True):
+        if stream_sync:
+            err, = cudart.cudaDeviceSynchronize()
+            if err != cuda.CUresult.CUDA_SUCCESS:
+                raise RuntimeError("CUDA Error %s" % str(err))
+
+        if hasattr(self, "host_D"):
+            err, = cuda.cuMemcpyDtoH(
+                self.host_D, self.ptr_D, self.host_D.size * self.host_D.itemsize)
+            if err != cuda.CUresult.CUDA_SUCCESS:
+                raise RuntimeError("CUDA Error %s" % str(err))
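ArgumentBase above dispatches on the type of A: NumPy arrays are staged into device buffers (with D copied back to host_D in sync()), torch and cupy tensors go through their respective frontends, and raw cuda.CUdeviceptr handles are passed through unchanged. Below is a hedged, NumPy-only sketch of that dispatch; the dummy buffer and helper names are invented here because NumpyFrontend is not shown in this diff.

# Hypothetical sketch of the ArgumentBase-style frontend dispatch (numpy branch only).
import numpy as np

class DummyBuffer:
    """Stand-in for the device buffer returned by NumpyFrontend.argument."""
    def __init__(self, array):
        self.ptr = array.ctypes.data  # host address used in place of a device pointer

def make_arguments(A, B, C, D, **kwargs):
    args = {"bias": kwargs.get("bias", False)}  # tensor_C doubles as the bias when bias=True
    if isinstance(A, np.ndarray):
        args["host_D"] = D  # kept so results could be copied back on sync()
        for name, tensor in (("A", A), ("B", B), ("C", C), ("D", D)):
            args["ptr_" + name] = DummyBuffer(tensor).ptr
        args["tensor_c_numel"] = C.size
        return args
    raise TypeError("only the numpy branch is sketched here; see ArgumentBase for the others")

A = B = C = D = np.zeros((8, 8), dtype=np.float32)
print(sorted(make_arguments(A, B, C, D, bias=True)))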
warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py (new file)
@@ -0,0 +1,241 @@
+#################################################################################################
+#
+# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################################
+
+import ctypes
+from pycutlass.library import *
+
+# 12B
+
+
+class GemmCoord_(ctypes.Structure):
+    _fields_ = [
+        ("m", ctypes.c_int),
+        ("n", ctypes.c_int),
+        ("k", ctypes.c_int)
+    ]
+
+    def __init__(self, gemm_coord) -> None:
+        for field_name, _ in self._fields_:
+            setattr(self, field_name, getattr(gemm_coord, field_name)())
+
+
+class MatrixCoord_(ctypes.Structure):
+    _fields_ = [
+        ("row", ctypes.c_int),
+        ("column", ctypes.c_int)
+    ]
+
+
+dtype2ctype = {
+    cutlass.float16: ctypes.c_uint16,
+    cutlass.float32: ctypes.c_float,
+    cutlass.float64: ctypes.c_double,
+    cutlass.int32: ctypes.c_int32
+}
+
+
+def get_gemm_arguments(epilogue_functor):
+
+    _EpilogueOutputOpParams = epilogue_functor.epilogue_type
+
+    class _GemmArguments(ctypes.Structure):
+        _fields_ = [
+            # Arguments from UniversalArgumentsBase
+            ("mode", ctypes.c_int),
+            ("problem_size", GemmCoord_),
+            ("batch_count", ctypes.c_int),
+            ("batch_stride_D", ctypes.c_longlong),
+            # Remaining arguments
+            ("epilogue", _EpilogueOutputOpParams),
+            ("ptr_A", ctypes.c_void_p),
+            ("ptr_B", ctypes.c_void_p),
+            ("ptr_C", ctypes.c_void_p),
+            ("ptr_D", ctypes.c_void_p),
+            ("batch_stride_A", ctypes.c_longlong),
+            ("batch_stride_B", ctypes.c_longlong),
+            ("batch_stride_C", ctypes.c_longlong),
+            ("stride_a", ctypes.c_longlong),
+            ("stride_b", ctypes.c_longlong),
+            ("stride_c", ctypes.c_longlong),
+            ("stride_d", ctypes.c_longlong),
+            ("lda", ctypes.c_longlong),
+            ("ldb", ctypes.c_longlong),
+            ("ldc", ctypes.c_longlong),
+            ("ldd", ctypes.c_longlong),
+            ("ptr_gather_A_indices", ctypes.c_void_p),
+            ("ptr_gether_B_indices", ctypes.c_void_p),
+            ("ptr_scatter_D_indices", ctypes.c_void_p)
+        ]
+
+    return _GemmArguments, _EpilogueOutputOpParams
+
+
+###########################################################################################
+# GEMM Grouped
+###########################################################################################
+
+# include/cutlass/gemm/kernel/gemm_grouped.h
+
+def get_gemm_grouped_arguments(epilogue_functor):
+    _EpilogueOutputOpParams = epilogue_functor.epilogue_type
+
+    class _GEMMGroupedArguments(ctypes.Structure):
+        _fields_ = [
+            ("problem_sizes", ctypes.c_void_p),
+            ("problem_count", ctypes.c_int),
+            ("threadblock_count", ctypes.c_int),
+            ("output_op", _EpilogueOutputOpParams),
+            ("ptr_A", ctypes.c_void_p),
+            ("ptr_B", ctypes.c_void_p),
+            ("ptr_C", ctypes.c_void_p),
+            ("ptr_D", ctypes.c_void_p),
+            ("lda", ctypes.c_void_p),
+            ("ldb", ctypes.c_void_p),
+            ("ldc", ctypes.c_void_p),
+            ("ldd", ctypes.c_void_p),
+            ("host_problem_sizes", ctypes.c_void_p)
+        ]
+
+    return _GEMMGroupedArguments, _EpilogueOutputOpParams
+
+############################################################################################
+# Convolution2D
+############################################################################################
+
+
+# We use the arguments as the interface
+
+
+# include/cutlass/conv/conv2d_problem_size.h
+# 64B
+class Conv2DProblemSize(ctypes.Structure):
+    _fields_ = [
+        ("N", ctypes.c_int),
+        ("H", ctypes.c_int),
+        ("W", ctypes.c_int),
+        ("C", ctypes.c_int),
+        ("P", ctypes.c_int),
+        ("Q", ctypes.c_int),
+        ("K", ctypes.c_int),
+        ("R", ctypes.c_int),
+        ("S", ctypes.c_int),
+        ("pad_h", ctypes.c_int),
+        ("pad_w", ctypes.c_int),
+        ("stride_h", ctypes.c_int),
+        ("stride_w", ctypes.c_int),
+        ("dilation_h", ctypes.c_int),
+        ("dilation_w", ctypes.c_int),
+        ("mode", ctypes.c_int),  # kCrossCorrelation: 0, kConvolution: 1
+        ("split_k_slices", ctypes.c_int),
+        ("groups", ctypes.c_int)
+    ]
+
+    def __init__(self, problem_size) -> None:
+        for field_name, _ in self._fields_:
+            setattr(self, field_name, getattr(problem_size, field_name))
+
+
+# include/cutlass/layout/tensor.h
+# 12B
+class Layout4D(ctypes.Structure):
+    _fields_ = [
+        ("stride", ctypes.c_int * 3)
+    ]
+
+    def __init__(self, tensor_ref):
+        stride = tensor_ref.stride()
+        setattr(self, "stride", (stride.at(0), stride.at(1), stride.at(2)))
+
+# TODO: Tensor 5-D takes ("stride", ctypes.c_int * 4)
+
+
+# include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h
+# TensorRef is basically cutlass::TensorRef<Element, Layout>;
+# include/cutlass/tensor_ref.h
+# 24B
+class TensorRef_(ctypes.Structure):
+    _fields_ = [
+        ("ptr", ctypes.c_void_p),
+        ("layout", Layout4D)
+    ]
+
+    def __init__(self, tensor_ref):
+        setattr(self, "ptr", tensor_ref.data())
+        setattr(self, "layout", Layout4D(tensor_ref.layout()))
+
+
+class TensorRef2D_(ctypes.Structure):
+    _fields_ = [
+        ("ptr", ctypes.c_void_p),
+        ("stride", ctypes.c_int)
+    ]
+
+
+# include/cutlass/conv/kernel/implicit_gemm_convolution.h
+# split_k_mode: kNone: 0, kSerial: 1, kParallel: 2, kParallelSerial: 3, kInvalid: 4
+
+def get_conv2d_arguments(epilogue_functor):
+    _EpilogueOutputOpParams = epilogue_functor.epilogue_type
+
+    class _Conv2dArguments(ctypes.Structure):
+        _fields_ = [
+            ("problem_size", Conv2DProblemSize),  # 0
+            ("ref_A", TensorRef_),  # 72
+            ("ref_B", TensorRef_),  # 96
+            ("ref_C", TensorRef_),  # 120
+            ("ref_D", TensorRef_),  # 144
+            ("output_op", _EpilogueOutputOpParams),  # 168
+            ("split_k_mode", ctypes.c_int)  # 192
+        ]
+
+    return _Conv2dArguments, _EpilogueOutputOpParams
+
+
+############################################################################################
+# Reduction
+############################################################################################
+
+
+def get_reduction_params(epilogue_functor):
+    _EpilogueOutputParams = epilogue_functor.epilogue_type
+
+    class _ReductionParams(ctypes.Structure):
+        _fields_ = [
+            ("problem_size", MatrixCoord_),
+            ("partitions", ctypes.c_int),
+            ("partition_stride", ctypes.c_longlong),
+            ("workspace", TensorRef2D_),
+            ("destination", TensorRef2D_),
+            ("source", TensorRef2D_),
+            ("output_op", _EpilogueOutputParams)
+        ]
+    return _ReductionParams, _EpilogueOutputParams
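These ctypes mirrors exist so argument blocks can be passed by value to the compiled CUTLASS kernels with the same memory layout as the corresponding C++ structs (the "# 12B" / "# 24B" comments record the expected struct sizes in bytes). A small self-contained check of that idea follows, with MatrixCoord_ re-declared from the diff above so it runs without CUTLASS or the cutlass extension installed.

# MatrixCoord_ re-declared from the diff above for a standalone layout check.
import ctypes

class MatrixCoord_(ctypes.Structure):
    _fields_ = [
        ("row", ctypes.c_int),
        ("column", ctypes.c_int),
    ]

coord = MatrixCoord_(row=128, column=256)
# Two 32-bit ints, no padding: 8 bytes total, matching the size of a
# cutlass::MatrixCoord (two ints) passed by value on the C++ side.
assert ctypes.sizeof(MatrixCoord_) == 8
print(coord.row, coord.column, ctypes.sizeof(coord))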