warp-lang 1.6.2-py3-none-macosx_10_13_universal2.whl → 1.7.1-py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.
Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/native/cutlass_gemm.cu (deleted)
@@ -1,382 +0,0 @@
- /*
-  * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-  * SPDX-License-Identifier: Apache-2.0
-  *
-  * Licensed under the Apache License, Version 2.0 (the "License");
-  * you may not use this file except in compliance with the License.
-  * You may obtain a copy of the License at
-  *
-  * http://www.apache.org/licenses/LICENSE-2.0
-  *
-  * Unless required by applicable law or agreed to in writing, software
-  * distributed under the License is distributed on an "AS IS" BASIS,
-  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  * See the License for the specific language governing permissions and
-  * limitations under the License.
-  */
-
- #include "builtin.h"
- #include "temp_buffer.h"
- #include "cuda_util.h"
-
- #include "cutlass/cutlass.h"
- #include "cutlass/gemm/device/gemm_universal.h"
- #include "cutlass/util/device_memory.h"
-
- #define F16_STR "<f2"
- #define F32_STR "<f4"
- #define F64_STR "<f8"
-
- namespace wp {
-
- template <typename Gemm>
- bool run_gemm(int m, int n, int k, int batch_count, const void* a, const void* b, const void* c, void* d, float alpha, float beta) {
-     //
-     // Initialize arguments
-     //
-     typename Gemm::EpilogueOutputOp::Params epilogue_params(
-         (typename Gemm::EpilogueOutputOp::ElementCompute)alpha,
-         (typename Gemm::EpilogueOutputOp::ElementCompute)beta);
-
-     typename Gemm::Arguments arguments{
-         batch_count == 1 ? cutlass::gemm::GemmUniversalMode::kGemm : cutlass::gemm::GemmUniversalMode::kBatched ,
-         cutlass::gemm::GemmCoord{m, n, k}, // Problem size
-         batch_count,
-         epilogue_params,
-         a, b, c, d,
-         int64_t(m * k), int64_t(k * n), int64_t(m * n), int64_t(m * n), // Batch strides
-         Gemm::LayoutA::packed({m, k}).stride(0), Gemm::LayoutB::packed({k, n}).stride(0), n, n
-     };
-
-     Gemm gemm;
-     size_t workspace_size = Gemm::get_workspace_size(arguments);
-     ScopedTemporary<> workspace(WP_CURRENT_CONTEXT, workspace_size);
-     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
-     cutlass::Status status = gemm.initialize(arguments, workspace.buffer(), stream);
-
-     if (status != cutlass::Status::kSuccess) {
-         cudaError_t error = cudaGetLastError();
-         std::cerr << "Error initializing GEMM: " << cudaGetErrorString(error) << "\n";
-         return false;
-     }
-
-     //
-     // Run the GEMM
-     //
-
-     status = gemm(stream);
-     if (status != cutlass::Status::kSuccess) {
-         cudaError_t error = cudaGetLastError();
-         std::cerr << "Runtime error: " << cudaGetErrorString(error) << "\n";
-         return false;
-     }
-
-     return true;
- }
-
- template <
-     int ComputeCapability,
-     typename Element_,
-     typename LayoutA,
-     typename LayoutB
- >
- struct DefaultGemmConfig;
-
- //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
- // Partial specialization for SM80 F64 Tensor Cores
- template <typename LayoutA, typename LayoutB>
- struct DefaultGemmConfig<80, double, LayoutA, LayoutB> {
-     using Gemm = cutlass::gemm::device::GemmUniversal<
-         double, LayoutA, // ElementA and LayoutA
-         double, LayoutB, // ElementB and LayoutB
-         double, cutlass::layout::RowMajor, // ElementC and LayoutC
-         double, // ElementAccumulator
-         cutlass::arch::OpClassTensorOp, // Operation type
-         cutlass::arch::Sm80, // Architecture
-         cutlass::gemm::GemmShape<128, 128, 16>, // ThreadblockShape
-         cutlass::gemm::GemmShape<32, 64, 16>, // WarpShape
-         cutlass::gemm::GemmShape<8, 8, 4>, // Instruction Shape
-         cutlass::epilogue::thread::LinearCombination< // Epilogue
-             double,
-             1,
-             double,
-             double>,
-         cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
-         3 // Stages
-     >;
- };
-
- // Partial specialization for SM80 F32 Tensor Cores
- template <typename LayoutA, typename LayoutB>
- struct DefaultGemmConfig<80, float, LayoutA, LayoutB> {
-     using Gemm = cutlass::gemm::device::GemmUniversal<
-         float, LayoutA, // ElementA and LayoutA
-         float, LayoutB, // ElementB and LayoutB
-         float, cutlass::layout::RowMajor, // ElementC and LayoutC
-         float, // ElementAccumulator
-         cutlass::arch::OpClassTensorOp, // Operation type
-         cutlass::arch::Sm80, // Architecture
-         cutlass::gemm::GemmShape<256, 128, 16>, // ThreadblockShape
-         cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape
-         cutlass::gemm::GemmShape<16, 8, 8>, // Instruction Shape
-         cutlass::epilogue::thread::LinearCombination< // Epilogue
-             float,
-             128 / cutlass::sizeof_bits<float>::value,
-             float,
-             float>,
-         cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
-         3, // Stages
-         4, 4, // AlignmentA and AlignmentB
-         cutlass::arch::OpMultiplyAddFastF32 // Math mode -- use 3xTF32
-     >;
- };
-
- // Partial specialization for SM80 F16 Tensor Cores
- template <typename LayoutA, typename LayoutB>
- struct DefaultGemmConfig<80, cutlass::half_t, LayoutA, LayoutB> {
-     using Gemm = cutlass::gemm::device::GemmUniversal<
-         cutlass::half_t, LayoutA, // ElementA and LayoutA
-         cutlass::half_t, LayoutB, // ElementB and LayoutB
-         cutlass::half_t, cutlass::layout::RowMajor, // ElementC and LayoutC
-         cutlass::half_t, // ElementAccumulator
-         cutlass::arch::OpClassTensorOp, // Operation type
-         cutlass::arch::Sm80, // Architecture
-         cutlass::gemm::GemmShape<256, 128, 32>, // ThreadblockShape
-         cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape
-         cutlass::gemm::GemmShape<16, 8, 16>, // Instruction Shape
-         cutlass::epilogue::thread::LinearCombination< // Epilogue
-             cutlass::half_t,
-             128 / cutlass::sizeof_bits<cutlass::half_t>::value,
-             cutlass::half_t,
-             cutlass::half_t>,
-         cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
-         3 // Stages
-     >;
- };
-
- //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
- // Partial specialization for SM75 F16 Tensor Cores
- template <typename LayoutA, typename LayoutB>
- struct DefaultGemmConfig<75, cutlass::half_t, LayoutA, LayoutB> {
-     using Gemm = cutlass::gemm::device::GemmUniversal<
-         cutlass::half_t, LayoutA, // ElementA and LayoutA
-         cutlass::half_t, LayoutB, // ElementB and LayoutB
-         cutlass::half_t, cutlass::layout::RowMajor, // ElementC and LayoutC
-         cutlass::half_t, // ElementAccumulator
-         cutlass::arch::OpClassTensorOp, // Operation type
-         cutlass::arch::Sm75, // Architecture
-         cutlass::gemm::GemmShape<256, 128, 32>, // ThreadblockShape
-         cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape
-         cutlass::gemm::GemmShape<16, 8, 8>, // Instruction Shape
-         cutlass::epilogue::thread::LinearCombination< // Epilogue
-             cutlass::half_t,
-             128 / cutlass::sizeof_bits<cutlass::half_t>::value,
-             cutlass::half_t,
-             cutlass::half_t>,
-         cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
-         2 // Stages
-     >;
- };
-
- //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
- // Partial specialization for SM70 F16 Tensor Cores
- template <typename LayoutA, typename LayoutB>
- struct DefaultGemmConfig<70, cutlass::half_t, LayoutA, LayoutB> {
-     using Gemm = cutlass::gemm::device::GemmUniversal<
-         cutlass::half_t, LayoutA, // ElementA and LayoutA
-         cutlass::half_t, LayoutB, // ElementB and LayoutB
-         cutlass::half_t, cutlass::layout::RowMajor, // ElementC and LayoutC
-         cutlass::half_t, // ElementAccumulator
-         cutlass::arch::OpClassTensorOp, // Operation type
-         cutlass::arch::Sm70, // Architecture
-         cutlass::gemm::GemmShape<256, 128, 32>, // ThreadblockShape
-         cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape
-         cutlass::gemm::GemmShape<8, 8, 4>, // Instruction Shape
-         cutlass::epilogue::thread::LinearCombination< // Epilogue
-             cutlass::half_t,
-             128 / cutlass::sizeof_bits<cutlass::half_t>::value,
-             cutlass::half_t,
-             cutlass::half_t>,
-         cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
-         2 // Stages
-     >;
- };
-
- //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
- // Partial specialization for SM50 SIMT
- template <typename Element, typename LayoutA, typename LayoutB>
- struct DefaultGemmConfig<50, Element, LayoutA, LayoutB> {
-     using Gemm = cutlass::gemm::device::GemmUniversal<
-         Element, LayoutA, // ElementA and LayoutA
-         Element, LayoutB, // ElementB and LayoutB
-         Element, cutlass::layout::RowMajor, // ElementC and LayoutC
-         Element, // ElementAccumulator
-         cutlass::arch::OpClassSimt, // Operation type
-         cutlass::arch::Sm50, // Architecture
-         cutlass::gemm::GemmShape<128, 128, 8>, // ThreadblockShape
-         cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape
-         cutlass::gemm::GemmShape<1, 1, 1>, // Instruction Shape
-         cutlass::epilogue::thread::LinearCombination< // Epilogue
-             Element,
-             1,
-             Element,
-             Element>,
-         cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
-         2 // Stages
-     >;
- };
-
- //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
- extern "C" {
-
- WP_API
- bool cutlass_gemm(
-     void* context, int compute_capability,
-     int m, int n, int k,
-     const char* datatype_str,
-     const void* a, const void* b, const void* c, void* d,
-     float alpha, float beta,
-     bool row_major_a, bool row_major_b,
-     bool allow_tf32x3_arith,
-     int batch_count) {
-
-     std::string datatype(datatype_str);
-
-     ContextGuard guard(context);
-
-     // Specializations for using Tensor Cores and A/B RowMajor/ColumnMajor designations
-     if (compute_capability == 80) {
-         if (datatype == F64_STR) {
-             if (row_major_a && row_major_b) {
-                 using Gemm = DefaultGemmConfig<80, double, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (!row_major_a && row_major_b) {
-                 using Gemm = DefaultGemmConfig<80, double, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (row_major_a && !row_major_b) {
-                 using Gemm = DefaultGemmConfig<80, double, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (!row_major_a && !row_major_b) {
-                 using Gemm = DefaultGemmConfig<80, double, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             }
-         } else if (datatype == F32_STR && allow_tf32x3_arith) {
-             if (row_major_a && row_major_b) {
-                 using Gemm = DefaultGemmConfig<80, float, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (!row_major_a && row_major_b) {
-                 using Gemm = DefaultGemmConfig<80, float, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (row_major_a && !row_major_b) {
-                 using Gemm = DefaultGemmConfig<80, float, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (!row_major_a && !row_major_b) {
-                 using Gemm = DefaultGemmConfig<80, float, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             }
-         } else if (datatype == F16_STR) {
-             if (row_major_a && row_major_b) {
-                 using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (!row_major_a && row_major_b) {
-                 using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (row_major_a && !row_major_b) {
-                 using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (!row_major_a && !row_major_b) {
-                 using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             }
-         }
-     } else if (compute_capability == 75) {
-         if (datatype == F16_STR) {
-             if (row_major_a && row_major_b) {
-                 using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (!row_major_a && row_major_b) {
-                 using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (row_major_a && !row_major_b) {
-                 using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (!row_major_a && !row_major_b) {
-                 using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             }
-         }
-     } else if (compute_capability == 70) {
-         if (datatype == F16_STR) {
-             if (row_major_a && row_major_b) {
-                 using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (!row_major_a && row_major_b) {
-                 using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (row_major_a && !row_major_b) {
-                 using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             } else if (!row_major_a && !row_major_b) {
-                 using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-                 return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-             }
-         }
-     }
-
-     // No Tensor Core capability available. Run a SIMT kernel
-     if (datatype == F64_STR) {
-         if (row_major_a && row_major_b) {
-             using Gemm = DefaultGemmConfig<50, double, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-             return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-         } else if (!row_major_a && row_major_b) {
-             using Gemm = DefaultGemmConfig<50, double, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-             return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-         } else if (row_major_a && !row_major_b) {
-             using Gemm = DefaultGemmConfig<50, double, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-             return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-         } else if (!row_major_a && !row_major_b) {
-             using Gemm = DefaultGemmConfig<50, double, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-             return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-         }
-     } else if (datatype == F32_STR) {
-         if (row_major_a && row_major_b) {
-             using Gemm = DefaultGemmConfig<50, float, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-             return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-         } else if (!row_major_a && row_major_b) {
-             using Gemm = DefaultGemmConfig<50, float, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-             return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-         } else if (row_major_a && !row_major_b) {
-             using Gemm = DefaultGemmConfig<50, float, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-             return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-         } else if (!row_major_a && !row_major_b) {
-             using Gemm = DefaultGemmConfig<50, float, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-             return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-         }
-     } else if (datatype == F16_STR) {
-         if (row_major_a && row_major_b) {
-             using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
-             return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-         } else if (!row_major_a && row_major_b) {
-             using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
-             return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-         } else if (row_major_a && !row_major_b) {
-             using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
-             return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-         } else if (!row_major_a && !row_major_b) {
-             using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
-             return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
-         }
-     }
-
-     std::cerr << "Data type " << datatype << " is not currently supported." << std::endl;
-     return false;
- }
-
- }
-
- } // namespace wp