warp-lang 1.6.2-py3-none-macosx_10_13_universal2.whl → 1.7.1-py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of warp-lang has been flagged as a potentially problematic release.
- warp/__init__.py +7 -1
- warp/autograd.py +12 -2
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +463 -372
- warp/codegen.py +196 -124
- warp/config.py +42 -6
- warp/context.py +496 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_cloth.py +1 -1
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/distributed/example_jacobi_mpi.py +507 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/field.py +11 -1
- warp/fem/field/nodal_field.py +56 -88
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +16 -13
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +7 -20
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
- warp/jax_experimental/ffi.py +702 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +312 -116
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +100 -11
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/render/render_opengl.py +19 -17
- warp/render/render_usd.py +93 -3
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +32 -19
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/inertia.py +189 -156
- warp/sim/integrator_euler.py +8 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +8 -5
- warp/sim/model.py +71 -25
- warp/sim/render.py +4 -0
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +217 -20
- warp/tests/__main__.py +0 -15
- warp/tests/assets/torus.usda +1 -1
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
- warp/tests/sim/test_inertia.py +161 -0
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/sim/test_xpbd.py +399 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_codegen.py +24 -3
- warp/tests/test_examples.py +40 -38
- warp/tests/test_fem.py +98 -14
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +577 -156
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +356 -151
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +336 -178
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -62
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +175 -666
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
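The listing above shows the JAX interop layer reorganized into a package: the old `warp/jax_experimental.py` module is kept as `custom_call.py`, a new XLA FFI path is added in `ffi.py` and `xla_ffi.py`, and three matching examples land under `warp/examples/interop/`. As a rough illustration of that entry point, here is a minimal sketch, assuming the new `warp.jax_experimental.ffi.jax_kernel` keeps the calling convention of the earlier `jax_kernel`; the kernel itself is illustrative and not taken from this diff:

```python
import jax
import jax.numpy as jnp
import warp as wp

# In 1.7 this import moves from warp.jax_experimental (custom-call based)
# to the new FFI implementation in warp.jax_experimental.ffi.
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def triple_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = 3.0 * x[tid]

# Wrap the Warp kernel as a JAX-callable op; for a simple elementwise
# kernel like this the output shape defaults to the input shape.
jax_triple = jax_kernel(triple_kernel)

@jax.jit
def f(x):
    return jax_triple(x)

print(f(jnp.arange(16, dtype=jnp.float32)))
```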
warp/native/cutlass_gemm.cu
DELETED
@@ -1,382 +0,0 @@
-/*
- * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "builtin.h"
-#include "temp_buffer.h"
-#include "cuda_util.h"
-
-#include "cutlass/cutlass.h"
-#include "cutlass/gemm/device/gemm_universal.h"
-#include "cutlass/util/device_memory.h"
-
-#define F16_STR "<f2"
-#define F32_STR "<f4"
-#define F64_STR "<f8"
-
-namespace wp {
-
-template <typename Gemm>
-bool run_gemm(int m, int n, int k, int batch_count, const void* a, const void* b, const void* c, void* d, float alpha, float beta) {
-    //
-    // Initialize arguments
-    //
-    typename Gemm::EpilogueOutputOp::Params epilogue_params(
-        (typename Gemm::EpilogueOutputOp::ElementCompute)alpha,
-        (typename Gemm::EpilogueOutputOp::ElementCompute)beta);
-
-    typename Gemm::Arguments arguments{
-        batch_count == 1 ? cutlass::gemm::GemmUniversalMode::kGemm : cutlass::gemm::GemmUniversalMode::kBatched,
-        cutlass::gemm::GemmCoord{m, n, k},                               // Problem size
-        batch_count,
-        epilogue_params,
-        a, b, c, d,
-        int64_t(m * k), int64_t(k * n), int64_t(m * n), int64_t(m * n),  // Batch strides
-        Gemm::LayoutA::packed({m, k}).stride(0), Gemm::LayoutB::packed({k, n}).stride(0), n, n
-    };
-
-    Gemm gemm;
-    size_t workspace_size = Gemm::get_workspace_size(arguments);
-    ScopedTemporary<> workspace(WP_CURRENT_CONTEXT, workspace_size);
-    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
-    cutlass::Status status = gemm.initialize(arguments, workspace.buffer(), stream);
-
-    if (status != cutlass::Status::kSuccess) {
-        cudaError_t error = cudaGetLastError();
-        std::cerr << "Error initializing GEMM: " << cudaGetErrorString(error) << "\n";
-        return false;
-    }
-
-    //
-    // Run the GEMM
-    //
-
-    status = gemm(stream);
-    if (status != cutlass::Status::kSuccess) {
-        cudaError_t error = cudaGetLastError();
-        std::cerr << "Runtime error: " << cudaGetErrorString(error) << "\n";
-        return false;
-    }
-
-    return true;
-}
-
-template <
-    int ComputeCapability,
-    typename Element_,
-    typename LayoutA,
-    typename LayoutB
->
-struct DefaultGemmConfig;
-
-//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Partial specialization for SM80 F64 Tensor Cores
-template <typename LayoutA, typename LayoutB>
-struct DefaultGemmConfig<80, double, LayoutA, LayoutB> {
-    using Gemm = cutlass::gemm::device::GemmUniversal<
-        double, LayoutA,                                               // ElementA and LayoutA
-        double, LayoutB,                                               // ElementB and LayoutB
-        double, cutlass::layout::RowMajor,                             // ElementC and LayoutC
-        double,                                                        // ElementAccumulator
-        cutlass::arch::OpClassTensorOp,                                // Operation type
-        cutlass::arch::Sm80,                                           // Architecture
-        cutlass::gemm::GemmShape<128, 128, 16>,                        // ThreadblockShape
-        cutlass::gemm::GemmShape<32, 64, 16>,                          // WarpShape
-        cutlass::gemm::GemmShape<8, 8, 4>,                             // Instruction Shape
-        cutlass::epilogue::thread::LinearCombination<                  // Epilogue
-            double,
-            1,
-            double,
-            double>,
-        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,  // Swizzling
-        3                                                              // Stages
-    >;
-};
-
-// Partial specialization for SM80 F32 Tensor Cores
-template <typename LayoutA, typename LayoutB>
-struct DefaultGemmConfig<80, float, LayoutA, LayoutB> {
-    using Gemm = cutlass::gemm::device::GemmUniversal<
-        float, LayoutA,                                                // ElementA and LayoutA
-        float, LayoutB,                                                // ElementB and LayoutB
-        float, cutlass::layout::RowMajor,                              // ElementC and LayoutC
-        float,                                                         // ElementAccumulator
-        cutlass::arch::OpClassTensorOp,                                // Operation type
-        cutlass::arch::Sm80,                                           // Architecture
-        cutlass::gemm::GemmShape<256, 128, 16>,                        // ThreadblockShape
-        cutlass::gemm::GemmShape<64, 64, 16>,                          // WarpShape
-        cutlass::gemm::GemmShape<16, 8, 8>,                            // Instruction Shape
-        cutlass::epilogue::thread::LinearCombination<                  // Epilogue
-            float,
-            128 / cutlass::sizeof_bits<float>::value,
-            float,
-            float>,
-        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,  // Swizzling
-        3,                                                             // Stages
-        4, 4,                                                          // AlignmentA and AlignmentB
-        cutlass::arch::OpMultiplyAddFastF32                            // Math mode -- use 3xTF32
-    >;
-};
-
-// Partial specialization for SM80 F16 Tensor Cores
-template <typename LayoutA, typename LayoutB>
-struct DefaultGemmConfig<80, cutlass::half_t, LayoutA, LayoutB> {
-    using Gemm = cutlass::gemm::device::GemmUniversal<
-        cutlass::half_t, LayoutA,                                      // ElementA and LayoutA
-        cutlass::half_t, LayoutB,                                      // ElementB and LayoutB
-        cutlass::half_t, cutlass::layout::RowMajor,                    // ElementC and LayoutC
-        cutlass::half_t,                                               // ElementAccumulator
-        cutlass::arch::OpClassTensorOp,                                // Operation type
-        cutlass::arch::Sm80,                                           // Architecture
-        cutlass::gemm::GemmShape<256, 128, 32>,                        // ThreadblockShape
-        cutlass::gemm::GemmShape<64, 64, 32>,                          // WarpShape
-        cutlass::gemm::GemmShape<16, 8, 16>,                           // Instruction Shape
-        cutlass::epilogue::thread::LinearCombination<                  // Epilogue
-            cutlass::half_t,
-            128 / cutlass::sizeof_bits<cutlass::half_t>::value,
-            cutlass::half_t,
-            cutlass::half_t>,
-        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,  // Swizzling
-        3                                                              // Stages
-    >;
-};
-
-//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Partial specialization for SM75 F16 Tensor Cores
-template <typename LayoutA, typename LayoutB>
-struct DefaultGemmConfig<75, cutlass::half_t, LayoutA, LayoutB> {
-    using Gemm = cutlass::gemm::device::GemmUniversal<
-        cutlass::half_t, LayoutA,                                      // ElementA and LayoutA
-        cutlass::half_t, LayoutB,                                      // ElementB and LayoutB
-        cutlass::half_t, cutlass::layout::RowMajor,                    // ElementC and LayoutC
-        cutlass::half_t,                                               // ElementAccumulator
-        cutlass::arch::OpClassTensorOp,                                // Operation type
-        cutlass::arch::Sm75,                                           // Architecture
-        cutlass::gemm::GemmShape<256, 128, 32>,                        // ThreadblockShape
-        cutlass::gemm::GemmShape<64, 64, 32>,                          // WarpShape
-        cutlass::gemm::GemmShape<16, 8, 8>,                            // Instruction Shape
-        cutlass::epilogue::thread::LinearCombination<                  // Epilogue
-            cutlass::half_t,
-            128 / cutlass::sizeof_bits<cutlass::half_t>::value,
-            cutlass::half_t,
-            cutlass::half_t>,
-        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,  // Swizzling
-        2                                                              // Stages
-    >;
-};
-
-//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Partial specialization for SM70 F16 Tensor Cores
-template <typename LayoutA, typename LayoutB>
-struct DefaultGemmConfig<70, cutlass::half_t, LayoutA, LayoutB> {
-    using Gemm = cutlass::gemm::device::GemmUniversal<
-        cutlass::half_t, LayoutA,                                      // ElementA and LayoutA
-        cutlass::half_t, LayoutB,                                      // ElementB and LayoutB
-        cutlass::half_t, cutlass::layout::RowMajor,                    // ElementC and LayoutC
-        cutlass::half_t,                                               // ElementAccumulator
-        cutlass::arch::OpClassTensorOp,                                // Operation type
-        cutlass::arch::Sm70,                                           // Architecture
-        cutlass::gemm::GemmShape<256, 128, 32>,                        // ThreadblockShape
-        cutlass::gemm::GemmShape<64, 64, 32>,                          // WarpShape
-        cutlass::gemm::GemmShape<8, 8, 4>,                             // Instruction Shape
-        cutlass::epilogue::thread::LinearCombination<                  // Epilogue
-            cutlass::half_t,
-            128 / cutlass::sizeof_bits<cutlass::half_t>::value,
-            cutlass::half_t,
-            cutlass::half_t>,
-        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,  // Swizzling
-        2                                                              // Stages
-    >;
-};
-
-//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Partial specialization for SM50 SIMT
-template <typename Element, typename LayoutA, typename LayoutB>
-struct DefaultGemmConfig<50, Element, LayoutA, LayoutB> {
-    using Gemm = cutlass::gemm::device::GemmUniversal<
-        Element, LayoutA,                                              // ElementA and LayoutA
-        Element, LayoutB,                                              // ElementB and LayoutB
-        Element, cutlass::layout::RowMajor,                            // ElementC and LayoutC
-        Element,                                                       // ElementAccumulator
-        cutlass::arch::OpClassSimt,                                    // Operation type
-        cutlass::arch::Sm50,                                           // Architecture
-        cutlass::gemm::GemmShape<128, 128, 8>,                         // ThreadblockShape
-        cutlass::gemm::GemmShape<32, 64, 8>,                           // WarpShape
-        cutlass::gemm::GemmShape<1, 1, 1>,                             // Instruction Shape
-        cutlass::epilogue::thread::LinearCombination<                  // Epilogue
-            Element,
-            1,
-            Element,
-            Element>,
-        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,  // Swizzling
-        2                                                              // Stages
-    >;
-};
-
-//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
-extern "C" {
-
-WP_API
-bool cutlass_gemm(
-    void* context, int compute_capability,
-    int m, int n, int k,
-    const char* datatype_str,
-    const void* a, const void* b, const void* c, void* d,
-    float alpha, float beta,
-    bool row_major_a, bool row_major_b,
-    bool allow_tf32x3_arith,
-    int batch_count) {
-
-    std::string datatype(datatype_str);
-
-    ContextGuard guard(context);
-
-    // Specializations for using Tensor Cores and A/B RowMajor/ColumnMajor designations
-    if (compute_capability == 80) {
-        if (datatype == F64_STR) {
-            if (row_major_a && row_major_b) {
-                using Gemm = DefaultGemmConfig<80, double, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (!row_major_a && row_major_b) {
-                using Gemm = DefaultGemmConfig<80, double, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (row_major_a && !row_major_b) {
-                using Gemm = DefaultGemmConfig<80, double, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (!row_major_a && !row_major_b) {
-                using Gemm = DefaultGemmConfig<80, double, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            }
-        } else if (datatype == F32_STR && allow_tf32x3_arith) {
-            if (row_major_a && row_major_b) {
-                using Gemm = DefaultGemmConfig<80, float, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (!row_major_a && row_major_b) {
-                using Gemm = DefaultGemmConfig<80, float, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (row_major_a && !row_major_b) {
-                using Gemm = DefaultGemmConfig<80, float, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (!row_major_a && !row_major_b) {
-                using Gemm = DefaultGemmConfig<80, float, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            }
-        } else if (datatype == F16_STR) {
-            if (row_major_a && row_major_b) {
-                using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (!row_major_a && row_major_b) {
-                using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (row_major_a && !row_major_b) {
-                using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (!row_major_a && !row_major_b) {
-                using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            }
-        }
-    } else if (compute_capability == 75) {
-        if (datatype == F16_STR) {
-            if (row_major_a && row_major_b) {
-                using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (!row_major_a && row_major_b) {
-                using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (row_major_a && !row_major_b) {
-                using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (!row_major_a && !row_major_b) {
-                using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            }
-        }
-    } else if (compute_capability == 70) {
-        if (datatype == F16_STR) {
-            if (row_major_a && row_major_b) {
-                using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (!row_major_a && row_major_b) {
-                using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (row_major_a && !row_major_b) {
-                using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            } else if (!row_major_a && !row_major_b) {
-                using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-                return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-            }
-        }
-    }
-
-    // No Tensor Core capability available. Run a SIMT kernel
-    if (datatype == F64_STR) {
-        if (row_major_a && row_major_b) {
-            using Gemm = DefaultGemmConfig<50, double, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-        } else if (!row_major_a && row_major_b) {
-            using Gemm = DefaultGemmConfig<50, double, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-        } else if (row_major_a && !row_major_b) {
-            using Gemm = DefaultGemmConfig<50, double, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-        } else if (!row_major_a && !row_major_b) {
-            using Gemm = DefaultGemmConfig<50, double, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-        }
-    } else if (datatype == F32_STR) {
-        if (row_major_a && row_major_b) {
-            using Gemm = DefaultGemmConfig<50, float, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-        } else if (!row_major_a && row_major_b) {
-            using Gemm = DefaultGemmConfig<50, float, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-        } else if (row_major_a && !row_major_b) {
-            using Gemm = DefaultGemmConfig<50, float, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-        } else if (!row_major_a && !row_major_b) {
-            using Gemm = DefaultGemmConfig<50, float, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-        }
-    } else if (datatype == F16_STR) {
-        if (row_major_a && row_major_b) {
-            using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-        } else if (!row_major_a && row_major_b) {
-            using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-        } else if (row_major_a && !row_major_b) {
-            using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-        } else if (!row_major_a && !row_major_b) {
-            using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-            return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-        }
-    }
-
-    std::cerr << "Data type " << datatype << " is not currently supported." << std::endl;
-    return false;
-}
-
-}
-
-} // namespace wp
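With `cutlass_gemm.cpp`/`cutlass_gemm.cu` and `test_matmul*.py` removed, the CUTLASS-backed GEMM path is gone from this release; the tile primitives (exercised by `warp/examples/tile/example_tile_matmul.py` and the relocated `warp/tests/tile/` suite, including the MathDx-backed `test_tile_mathdx.py`) are its replacement. A minimal sketch of a tile-based GEMM on Warp's public tile API follows; the tile sizes, array shapes, and names are illustrative rather than taken from this diff, and tile launches assume a CUDA device:

```python
import numpy as np
import warp as wp

TILE_M, TILE_N, TILE_K = 32, 32, 32
TILE_THREADS = 64  # threads cooperating on each output tile

@wp.kernel
def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
    # each launched block computes one TILE_M x TILE_N output tile
    i, j = wp.tid()
    acc = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
    count = int(A.shape[1] / TILE_K)
    for k in range(count):
        a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
        b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
        wp.tile_matmul(a, b, acc)  # acc += a @ b
    wp.tile_store(C, acc, offset=(i * TILE_M, j * TILE_N))

M = N = K = 256
a = wp.array(np.random.rand(M, K), dtype=float)
b = wp.array(np.random.rand(K, N), dtype=float)
c = wp.zeros((M, N), dtype=float)

wp.launch_tiled(tile_gemm, dim=(M // TILE_M, N // TILE_N), inputs=[a, b, c], block_dim=TILE_THREADS)
```

Unlike the deleted `cutlass_gemm()` entry point, which selected a precompiled configuration per compute capability, the tile path generates the GEMM inside the user's kernel, which is presumably what lets this wheel drop the CUTLASS sources entirely.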