warp-lang 1.0.2-py3-none-manylinux2014_x86_64.whl → 1.1.0-py3-none-manylinux2014_x86_64.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +115 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3425 -3354
- warp/codegen.py +2878 -2792
- warp/config.py +40 -36
- warp/constants.py +45 -45
- warp/context.py +5194 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +383 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -277
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
- warp/examples/benchmarks/benchmark_launches.py +295 -295
- warp/examples/browse.py +29 -29
- warp/examples/core/example_dem.py +234 -219
- warp/examples/core/example_fluid.py +293 -267
- warp/examples/core/example_graph_capture.py +144 -126
- warp/examples/core/example_marching_cubes.py +188 -174
- warp/examples/core/example_mesh.py +174 -155
- warp/examples/core/example_mesh_intersect.py +205 -193
- warp/examples/core/example_nvdb.py +176 -170
- warp/examples/core/example_raycast.py +105 -90
- warp/examples/core/example_raymarch.py +199 -178
- warp/examples/core/example_render_opengl.py +185 -141
- warp/examples/core/example_sph.py +405 -387
- warp/examples/core/example_torch.py +222 -181
- warp/examples/core/example_wave.py +263 -248
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +407 -389
- warp/examples/fem/example_convection_diffusion.py +182 -168
- warp/examples/fem/example_convection_diffusion_dg.py +219 -209
- warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
- warp/examples/fem/example_deformed_geometry.py +177 -159
- warp/examples/fem/example_diffusion.py +201 -173
- warp/examples/fem/example_diffusion_3d.py +177 -152
- warp/examples/fem/example_diffusion_mgpu.py +221 -214
- warp/examples/fem/example_mixed_elasticity.py +244 -222
- warp/examples/fem/example_navier_stokes.py +259 -243
- warp/examples/fem/example_stokes.py +220 -192
- warp/examples/fem/example_stokes_transfer.py +265 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +260 -246
- warp/examples/optim/example_cloth_throw.py +222 -209
- warp/examples/optim/example_diffray.py +566 -536
- warp/examples/optim/example_drone.py +864 -835
- warp/examples/optim/example_inverse_kinematics.py +176 -168
- warp/examples/optim/example_inverse_kinematics_torch.py +185 -169
- warp/examples/optim/example_spring_cage.py +239 -231
- warp/examples/optim/example_trajectory.py +223 -199
- warp/examples/optim/example_walker.py +306 -293
- warp/examples/sim/example_cartpole.py +139 -129
- warp/examples/sim/example_cloth.py +196 -186
- warp/examples/sim/example_granular.py +124 -111
- warp/examples/sim/example_granular_collision_sdf.py +197 -186
- warp/examples/sim/example_jacobian_ik.py +236 -214
- warp/examples/sim/example_particle_chain.py +118 -105
- warp/examples/sim/example_quadruped.py +193 -180
- warp/examples/sim/example_rigid_chain.py +197 -187
- warp/examples/sim/example_rigid_contact.py +189 -177
- warp/examples/sim/example_rigid_force.py +127 -125
- warp/examples/sim/example_rigid_gyroscopic.py +109 -95
- warp/examples/sim/example_rigid_soft_contact.py +134 -122
- warp/examples/sim/example_soft_body.py +190 -177
- warp/fabric.py +337 -335
- warp/fem/__init__.py +60 -27
- warp/fem/cache.py +401 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +15 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +744 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +441 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1630 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +294 -292
- warp/fem/space/basis_space.py +488 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +267 -267
- warp/fem/space/grid_3d_function_space.py +305 -306
- warp/fem/space/hexmesh_function_space.py +350 -352
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +368 -369
- warp/fem/space/restriction.py +158 -160
- warp/fem/space/shape/__init__.py +13 -15
- warp/fem/space/shape/cube_shape_function.py +738 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +294 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +223 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1072 -1025
- warp/native/builtin.h +1560 -1560
- warp/native/bvh.cpp +398 -398
- warp/native/bvh.cu +525 -525
- warp/native/bvh.h +429 -429
- warp/native/clang/clang.cpp +495 -464
- warp/native/crt.cpp +31 -31
- warp/native/crt.h +334 -334
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1498 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +293 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/NanoVDB.h +4782 -4782
- warp/native/nanovdb/PNanoVDB.h +2553 -2553
- warp/native/nanovdb/PNanoVDBWrite.h +294 -294
- warp/native/noise.h +850 -850
- warp/native/quat.h +1084 -1084
- warp/native/rand.h +299 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1132 -1132
- warp/native/volume.cpp +297 -297
- warp/native/volume.cu +32 -32
- warp/native/volume.h +538 -538
- warp/native/volume_builder.cu +425 -425
- warp/native/volume_builder.h +19 -19
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2943 -2828
- warp/native/warp.h +313 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3217 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1910 -1991
- warp/sim/integrator_xpbd.py +3294 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1227 -1227
- warp/stubs.py +2109 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +22 -22
- warp/tests/aux_test_grad_customs.py +23 -23
- warp/tests/aux_test_reference.py +11 -11
- warp/tests/aux_test_reference_reference.py +10 -10
- warp/tests/aux_test_square.py +17 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +239 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +157 -157
- warp/tests/test_arithmetic.py +1124 -1124
- warp/tests/test_array.py +2417 -2326
- warp/tests/test_array_reduce.py +150 -150
- warp/tests/test_async.py +668 -656
- warp/tests/test_atomic.py +141 -141
- warp/tests/test_bool.py +204 -149
- warp/tests/test_builtins_resolution.py +1292 -1292
- warp/tests/test_bvh.py +164 -171
- warp/tests/test_closest_point_edge_edge.py +228 -228
- warp/tests/test_codegen.py +566 -553
- warp/tests/test_compile_consts.py +97 -101
- warp/tests/test_conditional.py +246 -246
- warp/tests/test_copy.py +232 -215
- warp/tests/test_ctypes.py +632 -632
- warp/tests/test_dense.py +67 -67
- warp/tests/test_devices.py +91 -98
- warp/tests/test_dlpack.py +530 -529
- warp/tests/test_examples.py +400 -378
- warp/tests/test_fabricarray.py +955 -955
- warp/tests/test_fast_math.py +62 -54
- warp/tests/test_fem.py +1277 -1278
- warp/tests/test_fp16.py +130 -130
- warp/tests/test_func.py +338 -337
- warp/tests/test_generics.py +571 -571
- warp/tests/test_grad.py +746 -640
- warp/tests/test_grad_customs.py +333 -336
- warp/tests/test_hash_grid.py +210 -164
- warp/tests/test_import.py +39 -39
- warp/tests/test_indexedarray.py +1134 -1134
- warp/tests/test_intersect.py +67 -67
- warp/tests/test_jax.py +307 -307
- warp/tests/test_large.py +167 -164
- warp/tests/test_launch.py +354 -354
- warp/tests/test_lerp.py +261 -261
- warp/tests/test_linear_solvers.py +191 -171
- warp/tests/test_lvalue.py +421 -493
- warp/tests/test_marching_cubes.py +65 -65
- warp/tests/test_mat.py +1801 -1827
- warp/tests/test_mat_lite.py +115 -115
- warp/tests/test_mat_scalar_ops.py +2907 -2889
- warp/tests/test_math.py +126 -193
- warp/tests/test_matmul.py +500 -499
- warp/tests/test_matmul_lite.py +410 -410
- warp/tests/test_mempool.py +188 -190
- warp/tests/test_mesh.py +284 -324
- warp/tests/test_mesh_query_aabb.py +228 -241
- warp/tests/test_mesh_query_point.py +692 -702
- warp/tests/test_mesh_query_ray.py +292 -303
- warp/tests/test_mlp.py +276 -276
- warp/tests/test_model.py +110 -110
- warp/tests/test_modules_lite.py +39 -39
- warp/tests/test_multigpu.py +163 -163
- warp/tests/test_noise.py +248 -248
- warp/tests/test_operators.py +250 -250
- warp/tests/test_options.py +123 -125
- warp/tests/test_peer.py +133 -137
- warp/tests/test_pinned.py +78 -78
- warp/tests/test_print.py +54 -54
- warp/tests/test_quat.py +2086 -2086
- warp/tests/test_rand.py +288 -288
- warp/tests/test_reload.py +217 -217
- warp/tests/test_rounding.py +179 -179
- warp/tests/test_runlength_encode.py +190 -190
- warp/tests/test_sim_grad.py +243 -0
- warp/tests/test_sim_kinematics.py +91 -97
- warp/tests/test_smoothstep.py +168 -168
- warp/tests/test_snippet.py +305 -266
- warp/tests/test_sparse.py +468 -460
- warp/tests/test_spatial.py +2148 -2148
- warp/tests/test_streams.py +486 -473
- warp/tests/test_struct.py +710 -675
- warp/tests/test_tape.py +173 -148
- warp/tests/test_torch.py +743 -743
- warp/tests/test_transient_module.py +87 -87
- warp/tests/test_types.py +556 -659
- warp/tests/test_utils.py +490 -499
- warp/tests/test_vec.py +1264 -1268
- warp/tests/test_vec_lite.py +73 -73
- warp/tests/test_vec_scalar_ops.py +2099 -2099
- warp/tests/test_verify_fp.py +94 -94
- warp/tests/test_volume.py +737 -736
- warp/tests/test_volume_write.py +255 -265
- warp/tests/unittest_serial.py +37 -37
- warp/tests/unittest_suites.py +363 -359
- warp/tests/unittest_utils.py +603 -578
- warp/tests/unused_test_misc.py +71 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +566 -561
- warp/torch.py +321 -295
- warp/types.py +4504 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
- warp_lang-1.1.0.dist-info/RECORD +352 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp_lang-1.0.2.dist-info/RECORD +0 -352
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/native/cutlass_gemm.cu
CHANGED
@@ -1,373 +1,373 @@
The viewer renders every line of this file as removed and then re-added with identical visible content, so the change is most likely whitespace- or line-ending-only. The file, as it appears in both 1.0.2 and 1.1.0:

/** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
 * NVIDIA CORPORATION and its licensors retain all intellectual property
 * and proprietary rights in and to this software, related documentation
 * and any modifications thereto. Any use, reproduction, disclosure or
 * distribution of this software and related documentation without an express
 * license agreement from NVIDIA CORPORATION is strictly prohibited.
 */

#include "builtin.h"
#include "temp_buffer.h"
#include "cuda_util.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/device_memory.h"

#define F16_STR "<f2"
#define F32_STR "<f4"
#define F64_STR "<f8"

namespace wp {

template <typename Gemm>
bool run_gemm(int m, int n, int k, int batch_count, const void* a, const void* b, const void* c, void* d, float alpha, float beta) {
    //
    // Initialize arguments
    //
    typename Gemm::EpilogueOutputOp::Params epilogue_params(
        (typename Gemm::EpilogueOutputOp::ElementCompute)alpha,
        (typename Gemm::EpilogueOutputOp::ElementCompute)beta);

    typename Gemm::Arguments arguments{
        batch_count == 1 ? cutlass::gemm::GemmUniversalMode::kGemm : cutlass::gemm::GemmUniversalMode::kBatched,
        cutlass::gemm::GemmCoord{m, n, k},                              // Problem size
        batch_count,
        epilogue_params,
        a, b, c, d,
        int64_t(m * k), int64_t(k * n), int64_t(m * n), int64_t(m * n), // Batch strides
        Gemm::LayoutA::packed({m, k}).stride(0), Gemm::LayoutB::packed({k, n}).stride(0), n, n
    };

    Gemm gemm;
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    ScopedTemporary<> workspace(WP_CURRENT_CONTEXT, workspace_size);
    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
    cutlass::Status status = gemm.initialize(arguments, workspace.buffer(), stream);

    if (status != cutlass::Status::kSuccess) {
        cudaError_t error = cudaGetLastError();
        std::cerr << "Error initializing GEMM: " << cudaGetErrorString(error) << "\n";
        return false;
    }

    //
    // Run the GEMM
    //

    status = gemm(stream);
    if (status != cutlass::Status::kSuccess) {
        cudaError_t error = cudaGetLastError();
        std::cerr << "Runtime error: " << cudaGetErrorString(error) << "\n";
        return false;
    }

    return true;
}

template <
    int ComputeCapability,
    typename Element_,
    typename LayoutA,
    typename LayoutB
>
struct DefaultGemmConfig;

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// Partial specialization for SM80 F64 Tensor Cores
template <typename LayoutA, typename LayoutB>
struct DefaultGemmConfig<80, double, LayoutA, LayoutB> {
    using Gemm = cutlass::gemm::device::GemmUniversal<
        double, LayoutA,                                              // ElementA and LayoutA
        double, LayoutB,                                              // ElementB and LayoutB
        double, cutlass::layout::RowMajor,                            // ElementC and LayoutC
        double,                                                       // ElementAccumulator
        cutlass::arch::OpClassTensorOp,                               // Operation type
        cutlass::arch::Sm80,                                          // Architecture
        cutlass::gemm::GemmShape<128, 128, 16>,                       // ThreadblockShape
        cutlass::gemm::GemmShape<32, 64, 16>,                         // WarpShape
        cutlass::gemm::GemmShape<8, 8, 4>,                            // Instruction Shape
        cutlass::epilogue::thread::LinearCombination<                 // Epilogue
            double,
            1,
            double,
            double>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
        3                                                             // Stages
    >;
};

// Partial specialization for SM80 F32 Tensor Cores
template <typename LayoutA, typename LayoutB>
struct DefaultGemmConfig<80, float, LayoutA, LayoutB> {
    using Gemm = cutlass::gemm::device::GemmUniversal<
        float, LayoutA,                                               // ElementA and LayoutA
        float, LayoutB,                                               // ElementB and LayoutB
        float, cutlass::layout::RowMajor,                             // ElementC and LayoutC
        float,                                                        // ElementAccumulator
        cutlass::arch::OpClassTensorOp,                               // Operation type
        cutlass::arch::Sm80,                                          // Architecture
        cutlass::gemm::GemmShape<256, 128, 16>,                       // ThreadblockShape
        cutlass::gemm::GemmShape<64, 64, 16>,                         // WarpShape
        cutlass::gemm::GemmShape<16, 8, 8>,                           // Instruction Shape
        cutlass::epilogue::thread::LinearCombination<                 // Epilogue
            float,
            128 / cutlass::sizeof_bits<float>::value,
            float,
            float>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
        3,                                                            // Stages
        4, 4,                                                         // AlignmentA and AlignmentB
        cutlass::arch::OpMultiplyAddFastF32                           // Math mode -- use 3xTF32
    >;
};

// Partial specialization for SM80 F16 Tensor Cores
template <typename LayoutA, typename LayoutB>
struct DefaultGemmConfig<80, cutlass::half_t, LayoutA, LayoutB> {
    using Gemm = cutlass::gemm::device::GemmUniversal<
        cutlass::half_t, LayoutA,                                     // ElementA and LayoutA
        cutlass::half_t, LayoutB,                                     // ElementB and LayoutB
        cutlass::half_t, cutlass::layout::RowMajor,                   // ElementC and LayoutC
        cutlass::half_t,                                              // ElementAccumulator
        cutlass::arch::OpClassTensorOp,                               // Operation type
        cutlass::arch::Sm80,                                          // Architecture
        cutlass::gemm::GemmShape<256, 128, 32>,                       // ThreadblockShape
        cutlass::gemm::GemmShape<64, 64, 32>,                         // WarpShape
        cutlass::gemm::GemmShape<16, 8, 16>,                          // Instruction Shape
        cutlass::epilogue::thread::LinearCombination<                 // Epilogue
            cutlass::half_t,
            128 / cutlass::sizeof_bits<cutlass::half_t>::value,
            cutlass::half_t,
            cutlass::half_t>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
        3                                                             // Stages
    >;
};

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// Partial specialization for SM75 F16 Tensor Cores
template <typename LayoutA, typename LayoutB>
struct DefaultGemmConfig<75, cutlass::half_t, LayoutA, LayoutB> {
    using Gemm = cutlass::gemm::device::GemmUniversal<
        cutlass::half_t, LayoutA,                                     // ElementA and LayoutA
        cutlass::half_t, LayoutB,                                     // ElementB and LayoutB
        cutlass::half_t, cutlass::layout::RowMajor,                   // ElementC and LayoutC
        cutlass::half_t,                                              // ElementAccumulator
        cutlass::arch::OpClassTensorOp,                               // Operation type
        cutlass::arch::Sm75,                                          // Architecture
        cutlass::gemm::GemmShape<256, 128, 32>,                       // ThreadblockShape
        cutlass::gemm::GemmShape<64, 64, 32>,                         // WarpShape
        cutlass::gemm::GemmShape<16, 8, 8>,                           // Instruction Shape
        cutlass::epilogue::thread::LinearCombination<                 // Epilogue
            cutlass::half_t,
            128 / cutlass::sizeof_bits<cutlass::half_t>::value,
            cutlass::half_t,
            cutlass::half_t>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
        2                                                             // Stages
    >;
};

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// Partial specialization for SM70 F16 Tensor Cores
template <typename LayoutA, typename LayoutB>
struct DefaultGemmConfig<70, cutlass::half_t, LayoutA, LayoutB> {
    using Gemm = cutlass::gemm::device::GemmUniversal<
        cutlass::half_t, LayoutA,                                     // ElementA and LayoutA
        cutlass::half_t, LayoutB,                                     // ElementB and LayoutB
        cutlass::half_t, cutlass::layout::RowMajor,                   // ElementC and LayoutC
        cutlass::half_t,                                              // ElementAccumulator
        cutlass::arch::OpClassTensorOp,                               // Operation type
        cutlass::arch::Sm70,                                          // Architecture
        cutlass::gemm::GemmShape<256, 128, 32>,                       // ThreadblockShape
        cutlass::gemm::GemmShape<64, 64, 32>,                         // WarpShape
        cutlass::gemm::GemmShape<8, 8, 4>,                            // Instruction Shape
        cutlass::epilogue::thread::LinearCombination<                 // Epilogue
            cutlass::half_t,
            128 / cutlass::sizeof_bits<cutlass::half_t>::value,
            cutlass::half_t,
            cutlass::half_t>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
        2                                                             // Stages
    >;
};

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// Partial specialization for SM50 SIMT
template <typename Element, typename LayoutA, typename LayoutB>
struct DefaultGemmConfig<50, Element, LayoutA, LayoutB> {
    using Gemm = cutlass::gemm::device::GemmUniversal<
        Element, LayoutA,                                             // ElementA and LayoutA
        Element, LayoutB,                                             // ElementB and LayoutB
        Element, cutlass::layout::RowMajor,                           // ElementC and LayoutC
        Element,                                                      // ElementAccumulator
        cutlass::arch::OpClassSimt,                                   // Operation type
        cutlass::arch::Sm50,                                          // Architecture
        cutlass::gemm::GemmShape<128, 128, 8>,                        // ThreadblockShape
        cutlass::gemm::GemmShape<32, 64, 8>,                          // WarpShape
        cutlass::gemm::GemmShape<1, 1, 1>,                            // Instruction Shape
        cutlass::epilogue::thread::LinearCombination<                 // Epilogue
            Element,
            1,
            Element,
            Element>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
        2                                                             // Stages
    >;
};

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

extern "C" {

WP_API
bool cutlass_gemm(
    void* context, int compute_capability,
    int m, int n, int k,
    const char* datatype_str,
    const void* a, const void* b, const void* c, void* d,
    float alpha, float beta,
    bool row_major_a, bool row_major_b,
    bool allow_tf32x3_arith,
    int batch_count) {

    std::string datatype(datatype_str);

    ContextGuard guard(context);

    // Specializations for using Tensor Cores and A/B RowMajor/ColumnMajor designations
    if (compute_capability == 80) {
        if (datatype == F64_STR) {
            if (row_major_a && row_major_b) {
                using Gemm = DefaultGemmConfig<80, double, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (!row_major_a && row_major_b) {
                using Gemm = DefaultGemmConfig<80, double, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (row_major_a && !row_major_b) {
                using Gemm = DefaultGemmConfig<80, double, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (!row_major_a && !row_major_b) {
                using Gemm = DefaultGemmConfig<80, double, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            }
        } else if (datatype == F32_STR && allow_tf32x3_arith) {
            if (row_major_a && row_major_b) {
                using Gemm = DefaultGemmConfig<80, float, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (!row_major_a && row_major_b) {
                using Gemm = DefaultGemmConfig<80, float, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (row_major_a && !row_major_b) {
                using Gemm = DefaultGemmConfig<80, float, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (!row_major_a && !row_major_b) {
                using Gemm = DefaultGemmConfig<80, float, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            }
        } else if (datatype == F16_STR) {
            if (row_major_a && row_major_b) {
                using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (!row_major_a && row_major_b) {
                using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (row_major_a && !row_major_b) {
                using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (!row_major_a && !row_major_b) {
                using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            }
        }
    } else if (compute_capability == 75) {
        if (datatype == F16_STR) {
            if (row_major_a && row_major_b) {
                using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (!row_major_a && row_major_b) {
                using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (row_major_a && !row_major_b) {
                using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (!row_major_a && !row_major_b) {
                using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            }
        }
    } else if (compute_capability == 70) {
        if (datatype == F16_STR) {
            if (row_major_a && row_major_b) {
                using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (!row_major_a && row_major_b) {
                using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (row_major_a && !row_major_b) {
                using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            } else if (!row_major_a && !row_major_b) {
                using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
            }
        }
    }

    // No Tensor Core capability available. Run a SIMT kernel
    if (datatype == F64_STR) {
        if (row_major_a && row_major_b) {
            using Gemm = DefaultGemmConfig<50, double, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
        } else if (!row_major_a && row_major_b) {
            using Gemm = DefaultGemmConfig<50, double, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
        } else if (row_major_a && !row_major_b) {
            using Gemm = DefaultGemmConfig<50, double, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
        } else if (!row_major_a && !row_major_b) {
            using Gemm = DefaultGemmConfig<50, double, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
        }
    } else if (datatype == F32_STR) {
        if (row_major_a && row_major_b) {
            using Gemm = DefaultGemmConfig<50, float, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
        } else if (!row_major_a && row_major_b) {
            using Gemm = DefaultGemmConfig<50, float, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
        } else if (row_major_a && !row_major_b) {
            using Gemm = DefaultGemmConfig<50, float, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
        } else if (!row_major_a && !row_major_b) {
            using Gemm = DefaultGemmConfig<50, float, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
        }
    } else if (datatype == F16_STR) {
        if (row_major_a && row_major_b) {
            using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
        } else if (!row_major_a && row_major_b) {
            using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
        } else if (row_major_a && !row_major_b) {
            using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
        } else if (!row_major_a && !row_major_b) {
            using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
        }
    }

    std::cerr << "Data type " << datatype << " is not currently supported." << std::endl;
    return false;
}

}

} // namespace wp
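
For orientation, a minimal sketch of how this C entry point is typically reached from Python. It assumes the wp.matmul helper that Warp exposes for CUTLASS-backed GEMM; the exact signature and keyword names are assumptions from the Warp documentation of this era, not confirmed by this diff, and the call requires a CUTLASS-enabled build running on a CUDA device:

import numpy as np
import warp as wp

wp.init()

m, n, k = 128, 128, 64
a = wp.array(np.random.rand(m, k).astype(np.float32), dtype=wp.float32, device="cuda")
b = wp.array(np.random.rand(k, n).astype(np.float32), dtype=wp.float32, device="cuda")
c = wp.zeros((m, n), dtype=wp.float32, device="cuda")
d = wp.zeros((m, n), dtype=wp.float32, device="cuda")

# Computes d = alpha * (a @ b) + beta * c. Internally this reaches the
# cutlass_gemm entry point above, which selects a DefaultGemmConfig
# specialization from the dtype string ("<f4" here), the row/column-major
# flags, and the device's compute capability, falling back to the SM50
# SIMT kernel when no Tensor Core path matches. (Signature assumed.)
wp.matmul(a, b, c, d, alpha=1.0, beta=0.0)

Note the dispatch design: each (architecture, dtype, layout) combination instantiates its own GemmUniversal template at compile time, so supporting a new configuration means adding a DefaultGemmConfig partial specialization plus one branch in cutlass_gemm.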