warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package registry advisory for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1276 @@
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ import enum
34
+ import copy
35
+ import numpy as np
36
+ from typeguard import typechecked
37
+ import cutlass
38
+ from pycutlass import *
39
+ from cuda import cuda
40
+
41
+
42
+ ################################################################################
43
+ #
44
+ # Data structure modeling a GEMM operation
45
+ #
46
+ ################################################################################
47
+
48
+
49
def transpose_layout(layout: cutlass.layout):
    """Return the transpose of a CUTLASS matrix layout.

    RowMajor maps to ColumnMajor and vice versa; any other layout is
    rejected.

    :param layout: the layout to transpose
    :raises ValueError: if *layout* is neither RowMajor nor ColumnMajor
    """
    if layout == cutlass.RowMajor:
        return cutlass.ColumnMajor
    if layout == cutlass.ColumnMajor:
        return cutlass.RowMajor
    raise ValueError("unsupported Layout {}".format(layout))
56
+
57
+
58
# @typechecked
class GemmArguments(ArgumentBase):
    """
    Argument wrapper for GEMM. It encodes problem information and
    user-provided tensors into the kernel's argument.

    :param operation: the GEMM operation to take the argument
    :type operation: :class:`pycutlass.GemmOperationUniversal` |
        :class:`pycutlass.GemmOperationGrouped`

    :param problem_size: GEMM problem size gemm(M, N, K)
    :type problem_size: :class:`cutlass.gemm.GemmCoord`

    :param A: tensor A
    :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param B: tensor B
    :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param C: tensor C
    :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param D: tensor D
    :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param gemm_mode: GEMM mode
    :type gemm_mode: :class:`cutlass.gemm.Mode`

    :param output_op: output operator, optional
    :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments`
    """

    def __init__(
        self, operation: 'GemmOperation', problem_size: 'cutlass.gemm.GemmCoord',
        A: 'Tensor', B: 'Tensor', C: 'Tensor', D: 'Tensor',
        gemm_mode: 'cutlass.gemm.Mode'=cutlass.gemm.Mode.Gemm, **kwargs):

        self.operation = operation

        # Cache operand layouts and element types from the operation
        # description; used below for leading-dimension computation and
        # tensor reordering.
        self.layout_A: cutlass.layout = operation.A.layout
        self.layout_B: cutlass.layout = operation.B.layout
        self.layout_C: cutlass.layout = operation.C.layout

        self.element_A = operation.A.element
        self.element_B = operation.B.element
        self.element_C = operation.C.element

        if (operation.C.layout in
            [cutlass.RowMajorInterleaved32, cutlass.ColumnMajorInterleaved32]):
            # reorder tensor B for interleaved layout output
            B = self.reorder_tensor_B(B, problem_size)

        # ArgumentBase takes ownership of the tensors and exposes them as
        # self.ptr_A / ptr_B / ptr_C / ptr_D device pointers.
        # NOTE(review): self.bias is presumably initialized (to False) by
        # ArgumentBase.__init__ — confirm, since it is read unconditionally
        # below.
        super().__init__(A, B, C, D, **kwargs)

        if operation.switched:
            # The operation was built with A/B swapped (e.g. to realize a
            # row-major output with a column-major kernel), so swap the
            # problem dimensions and pointers to match.
            self.problem_size = cutlass.gemm.GemmCoord(
                problem_size.n(), problem_size.m(), problem_size.k())
            self.ptr_A, self.ptr_B = self.ptr_B, self.ptr_A
        else:
            self.problem_size = cutlass.gemm.GemmCoord(
                problem_size.m(), problem_size.n(), problem_size.k())

        # if the number of elements in C = problem_size.n
        # C is treated as the bias
        if hasattr(self, "tensor_c_numel"):
            if (self.tensor_c_numel == self.problem_size.n() and
                self.problem_size.m() != 1): self.bias = True

        # get the leading dimension of each operand from its packed layout
        self.lda = operation.A.layout.packed(self.problem_size.mk()).stride()
        self.ldb = operation.B.layout.packed(self.problem_size.kn()).stride()
        self.ldc = operation.C.layout.packed(self.problem_size.mn()).stride()
        self.ldd = self.ldc

        # stride 0 trick: broadcast the bias row across all M rows by
        # giving C a zero leading dimension
        if self.bias:
            self.ldc = 0

        # In split-K parallel mode the partial products are combined by a
        # separate reduction kernel, so the user's output op is not applied
        # here; fall back to the default alpha=1, beta=0 epilogue.
        if 'output_op' in kwargs.keys() and \
            gemm_mode != cutlass.gemm.Mode.GemmSplitKParallel:
            self.output_op = kwargs['output_op']
        else:
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        # get number of slices on k dimension
        self.gemm_mode = gemm_mode
        if gemm_mode in [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel]:
            if 'split_k_slices' in kwargs.keys():
                self.batch_count = kwargs['split_k_slices']
            else:
                self.batch_count = 1
            self.split_k_slices = self.batch_count

        if gemm_mode in [cutlass.gemm.Mode.Batched, cutlass.gemm.Mode.Array]:
            if 'batch' in kwargs.keys():
                self.batch_count = kwargs['batch']
            else:
                self.batch_count = 1

        # Element (not byte) strides between consecutive batch items.
        self.batched_stride_A = self.problem_size.m() * self.problem_size.k()
        self.batched_stride_B = self.problem_size.n() * self.problem_size.k()
        self.batched_stride_C = self.problem_size.m() * self.problem_size.n()
        self.batched_stride_D = self.problem_size.m() * self.problem_size.n()
        if self.bias:
            # a bias is a single row of length N, repeated per batch item
            self.batched_stride_C = self.problem_size.n()

        # support GEMM Mode Array: build per-batch pointer tables on the
        # device instead of using a uniform batch stride
        if gemm_mode == cutlass.gemm.Mode.Array:
            self.ptr_A_array = []
            self.ptr_B_array = []
            self.ptr_C_array = []
            self.ptr_D_array = []

            ptr_A_addr = int(self.ptr_A)
            ptr_B_addr = int(self.ptr_B)
            ptr_C_addr = int(self.ptr_C)
            ptr_D_addr = int(self.ptr_D)

            # convert element strides to byte strides (DataTypeSize is in
            # bits, hence the // 8)
            stride_A = self.batched_stride_A * DataTypeSize[self.element_A] // 8
            stride_B = self.batched_stride_B * DataTypeSize[self.element_B] // 8
            stride_C = self.batched_stride_C * DataTypeSize[self.element_C] // 8
            stride_D = self.batched_stride_D * DataTypeSize[self.element_C] // 8
            for _ in range(self.batch_count):
                self.ptr_A_array.append(ptr_A_addr)
                self.ptr_B_array.append(ptr_B_addr)
                self.ptr_C_array.append(ptr_C_addr)
                self.ptr_D_array.append(ptr_D_addr)

                ptr_A_addr += stride_A
                ptr_B_addr += stride_B
                ptr_C_addr += stride_C
                ptr_D_addr += stride_D

            # upload the pointer tables to the device as int64 addresses
            self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64)
            self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64)
            self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64)
            self.ptr_D_array_buffer = todevice(self.ptr_D_array, dtype=np.int64)

        # Grouped operations defer initialization to GemmGroupedArguments;
        # only universal operations are finalized here.
        if isinstance(self.operation, GemmOperationUniversal):
            self.initialize()

    def reorder_tensor_B(self, tensor_B: 'np.ndarray',
                         problem_size: 'cutlass.gemm.GemmCoord'):
        """
        Reorder tensor_B for interleaved layout

        :param tensor_B: input tensor B
        :type tensor_B: numpy.ndarray
        :param problem_size: GEMM problem size
        :type problem_size: :class:`cutlass.gemm.GemmCoord`

        :return: reordered tensor B
        :rtype: numpy.ndarray
        """
        reordered_tensor_B = np.empty_like(tensor_B)
        tensor_ref_B = self.get_tensor_ref(
            tensor_B, self.element_B, self.layout_B, problem_size, "b"
        )
        reordered_tensor_ref_B = self.get_tensor_ref(
            reordered_tensor_B, self.element_B, self.layout_B, problem_size, "b"
        )
        # delegate the actual column reordering to the CUTLASS host helper
        cutlass.gemm.host.reorder_column(
            tensor_ref_B, reordered_tensor_ref_B, problem_size)
        return reordered_tensor_B

    def get_tensor_ref(
        self, tensor, dtype, tensor_layout, problem_size, operand):
        """Build a CUTLASS TensorRef for one GEMM operand.

        :param operand: one of "a", "b", "c" or "d"; selects which extent of
            the problem size (MxK, KxN or MxN) describes the tensor
        :raises ValueError: if *operand* is not a recognized operand name
        """
        if operand == "a":
            tensor_coord = problem_size.mk()
        elif operand == "b":
            tensor_coord = problem_size.kn()
        elif operand in ["c", "d"]:
            tensor_coord = problem_size.mn()
        else:
            raise ValueError("unknown operand: " + operand)

        layout = tensor_layout.packed(tensor_coord)

        return TensorRef(tensor, dtype, layout).tensor_ref

    def get_arguments(self):
        """Assemble the ctypes argument structure for the kernel.

        Stores ``(arguments, grid_tiled_shape, gemm_k_size)`` in
        ``self.arguments``; assumes :meth:`initialize` has already populated
        ``self.grid_tiled_shape`` and ``self.gemm_k_size`` via the runtime
        module's plan.
        """
        problem_size_ = GemmCoord_(self.problem_size)
        grid_tiled_shape_ = GemmCoord_(
            cutlass.gemm.GemmCoord(
                self.grid_tiled_shape.x, self.grid_tiled_shape.y,
                self.grid_tiled_shape.z
            )
        )
        if self.gemm_mode == cutlass.gemm.Mode.Array:
            # Array mode passes device-side pointer tables; the per-operand
            # direct pointers and batch strides are therefore zero.
            arguments = self.operation.argument_type(
                # Arguments from UniversalArgumentsBase
                self.gemm_mode, problem_size_, self.batch_count, 0,
                # Remaining arguments
                self.output_op,
                int(self.ptr_A_array_buffer.ptr),
                int(self.ptr_B_array_buffer.ptr),
                int(self.ptr_C_array_buffer.ptr),
                int(self.ptr_D_array_buffer.ptr),
                0, 0, 0,
                self.lda, self.ldb, self.ldc, self.ldd,
                self.lda, self.ldb, self.ldc, self.ldd,
                0, 0, 0
            )
        else:
            arguments = self.operation.argument_type(
                # Arguments from UniversalArgumentsBase
                self.gemm_mode, problem_size_, self.batch_count, self.batched_stride_D,
                # Remaining arguments
                self.output_op,
                int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D),
                self.batched_stride_A, self.batched_stride_B, self.batched_stride_C,
                self.lda, self.ldb, self.ldc, self.ldd,
                self.lda, self.ldb, self.ldc, self.ldd,
                0, 0, 0
            )

        self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size

    def initialize(self):
        """Plan the launch and build host/device workspaces for the kernel.

        Populates ``self.host_workspace``, ``self.device_workspace`` and
        ``self.launch_config``. In split-K parallel mode the D pointer is
        redirected into the device workspace so the reduction kernel can
        combine the partial results afterwards.
        """
        # get launch configuration
        launch_config = self.operation.rt_module.plan(self)

        # get the host and device workspace
        device_workspace_size = \
            self.operation.rt_module.get_device_workspace_size(self)

        if device_workspace_size > 0:
            self.workspace_buffer = device_mem_alloc(device_workspace_size)
            workspace_ptr = self.workspace_buffer.ptr
            # zero the workspace (cuMemsetD32 counts 4-byte words)
            err, = cuda.cuMemsetD32(
                workspace_ptr, 0, device_workspace_size // 4)
        else:
            workspace_ptr = None

        device_workspace = 0
        if (workspace_ptr is not None and
            self.gemm_mode == cutlass.gemm.Mode.GemmSplitKParallel):
            # in GEMM split-K parallel, the D pointer is redirected
            # to the workspace
            self.ptr_D = cuda.CUdeviceptr(workspace_ptr)
        elif (workspace_ptr is not None and
            self.gemm_mode == cutlass.gemm.Mode.Gemm):
            # in GEMM split-K serial, the workspace holds the semaphore
            # passed to the kernel
            device_workspace = workspace_ptr

        self.get_arguments()

        arguments, grid_tiled_shape, gemm_k_size = self.arguments
        # marshal the ctypes arguments into the raw host workspace bytes
        # expected by the runtime module
        res_arg = self.operation.rt_module.get_args(
            ctypes.byref(arguments), ctypes.c_void_p(int(device_workspace)))
        host_workspace = bytearray(res_arg.contents)

        device_workspace = None

        self.host_workspace = host_workspace
        self.device_workspace = device_workspace
        self.launch_config = launch_config
315
+
316
+
317
+ class GemmGroupedArguments:
318
+ """
319
+ Argument wrapper for GEMM Grouped. It encodes problem information and
320
+ user-provide tensors into the kernel's argument
321
+
322
+ :param operation: the GEMM Grouped operation to take the argument
323
+ :type operation: :class:`pycutlass.GemmOperationGrouped`
324
+
325
+ :param problem_size: list of GEMM problem size gemm(M, N, K)
326
+ :type operation: list[:class:`cutlass.gemm.GemmCoord`]
327
+
328
+ :param A: list of tensor A
329
+ :type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
330
+
331
+ :param B: list of tensor B
332
+ :type B: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
333
+
334
+ :param C: list of tensor C
335
+ :type C: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
336
+
337
+ :param D: list of tensor D
338
+ :type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
339
+
340
+ :param output_op: output operator, optional
341
+ :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments`
342
+ """
343
+ def __init__(
344
+ self, operation: 'GemmOperationGrouped',
345
+ problem_sizes: 'list[cutlass.gemm.GemmCoord]',
346
+ A: 'list[Tensor]', B: 'list[Tensor]', C: 'list[torch.Tensor]',
347
+ D: 'list[Tensor]', **kwargs):
348
+
349
+ # get number of problems in the group
350
+ self.problem_count = len(problem_sizes)
351
+
352
+ # check the input arguments
353
+ assert len(A) == self.problem_count
354
+ assert len(B) == self.problem_count
355
+ assert len(C) == self.problem_count
356
+ assert len(D) == self.problem_count
357
+
358
+ problem_size_host = []
359
+ self.ptr_A_host = []
360
+ self.ptr_B_host = []
361
+ self.ptr_C_host = []
362
+ self.ptr_D_host = []
363
+
364
+ lda_host = []
365
+ ldb_host = []
366
+ ldc_host = []
367
+ ldd_host = []
368
+
369
+ self.partitions = 1
370
+
371
+ self.operation = operation
372
+
373
+ # get the threadblock
374
+ threadblock_shape = operation.tile_description.threadblock_shape
375
+ self.threadblock_shape = cutlass.gemm.GemmCoord(
376
+ threadblock_shape[0], threadblock_shape[1], threadblock_shape[2])
377
+ self.threadblock_swizzle = operation.swizzling_functor
378
+
379
+ self.total_tiles = 0
380
+
381
+ self.gemm_arguments = []
382
+
383
+ # process the input arguments
384
+ for idx, problem_size in enumerate(problem_sizes):
385
+ M, N, K = problem_size.m(), problem_size.n(), problem_size.k()
386
+ temp_argument = GemmArguments(
387
+ operation=operation,
388
+ problem_size=cutlass.gemm.GemmCoord(M, N, K),
389
+ A=A[idx], B=B[idx], C=C[idx], D=D[idx],
390
+ )
391
+ self.gemm_arguments.append(temp_argument)
392
+
393
+ problem_size_host.append(
394
+ [temp_argument.problem_size.m(),
395
+ temp_argument.problem_size.n(),
396
+ temp_argument.problem_size.k()]
397
+ )
398
+
399
+ self.ptr_A_host.append(int(temp_argument.ptr_A))
400
+ lda_host.append(temp_argument.lda)
401
+
402
+ self.ptr_B_host.append(int(temp_argument.ptr_B))
403
+ ldb_host.append(temp_argument.ldb)
404
+
405
+ self.ptr_C_host.append(int(temp_argument.ptr_C))
406
+ ldc_host.append(temp_argument.ldc)
407
+
408
+ self.ptr_D_host.append(int(temp_argument.ptr_D))
409
+ ldd_host.append(temp_argument.ldd)
410
+
411
+ # get number of tiles
412
+ grid = self.threadblock_swizzle.get_grid_shape(
413
+ self.threadblock_swizzle.get_tiled_shape(
414
+ temp_argument.problem_size, self.threadblock_shape,
415
+ temp_argument.batch_count)
416
+ )
417
+ self.total_tiles += grid.x * grid.y * grid.z
418
+
419
+ self.problem_size_buffer = todevice(problem_size_host, np.int32)
420
+ self.ptr_A_buffer = todevice(self.ptr_A_host, np.int64)
421
+ self.ptr_B_buffer = todevice(self.ptr_B_host, np.int64)
422
+ self.ptr_C_buffer = todevice(self.ptr_C_host, np.int64)
423
+ self.ptr_D_buffer = todevice(self.ptr_D_host, np.int64)
424
+
425
+ self.lda_buffer = todevice(lda_host, np.int64)
426
+ self.ldb_buffer = todevice(ldb_host, np.int64)
427
+ self.ldc_buffer = todevice(ldc_host, np.int64)
428
+ self.ldd_buffer = todevice(ldd_host, np.int64)
429
+
430
+ if 'output_op' in kwargs.keys():
431
+ self.alpha = kwargs['output_op'].alpha
432
+ self.beta = kwargs['output_op'].beta
433
+ else:
434
+ self.alpha = 1.0
435
+ self.beta = 0.0
436
+
437
+ if 'output_op' in kwargs.keys():
438
+ self.output_op = kwargs['output_op']
439
+ else:
440
+ self.output_op = self.operation.epilogue_type(1.0, 0.0)
441
+
442
+
443
+ # get host problem size
444
+ self.host_problem_size_ptr = np.array(
445
+ problem_size_host, dtype=np.int32).__array_interface__['data'][0]
446
+
447
+ self.arguments = self.get_arguments()
448
+
449
+ self.initialize()
450
+
451
+ def get_arguments(self):
452
+ return self.operation.argument_type(
453
+ self.problem_size_buffer.ptr, self.problem_count, self.total_tiles,
454
+ self.output_op, self.ptr_A_buffer.ptr, self.ptr_B_buffer.ptr,
455
+ self.ptr_C_buffer.ptr, self.ptr_D_buffer.ptr, self.lda_buffer.ptr,
456
+ self.ldb_buffer.ptr, self.ldc_buffer.ptr, self.ldd_buffer.ptr,
457
+ ctypes.c_void_p(int(self.host_problem_size_ptr))
458
+ )
459
+
460
    def initialize(self):
        """Plan the kernel launch, allocate/zero the device workspace, and
        marshal the parameter struct into a host-side byte buffer."""
        # get launch configuration
        launch_config = self.operation.rt_module.plan(self)

        # get the host and device workspace
        device_workspace_size = \
            self.operation.rt_module.get_device_workspace_size(self)

        if device_workspace_size > 0:
            self.workspace_buffer = device_mem_alloc(device_workspace_size)
            workspace_ptr = self.workspace_buffer.ptr
            # zero the workspace; cuMemsetD32 writes 32-bit words, hence // 4
            # NOTE(review): 'err' is unpacked but never checked — confirm
            # whether a failed memset should raise here.
            err, = cuda.cuMemsetD32(
                workspace_ptr, 0, device_workspace_size // 4)
        else:
            workspace_ptr = None

        if self.operation.precompute_mode == SchedulerMode.Host:
            # host-side tile scheduling: precompute on the CPU and get back a
            # device pointer to the uploaded schedule
            device_workspace_ptr = self.operation.rt_module.host_precompute(
                self, self.operation.rt_module.get_workspace_size(self))
        else:
            device_workspace_ptr = 0

        # NOTE(review): 'workspace_ptr' computed above is not passed to
        # get_args — verify the device-scheduler path really needs no
        # workspace pointer here.
        result = self.operation.rt_module.get_args(
            ctypes.byref(self.arguments), self.total_tiles,
            ctypes.c_void_p(int(device_workspace_ptr))
        )
        # copy the C-side Params bytes into a Python-owned buffer
        host_workspace = bytearray(result.contents)

        device_workspace = None

        self.host_workspace = host_workspace
        self.device_workspace = device_workspace
        self.launch_config = launch_config
493
+
494
+ def sync(self):
495
+ err, = cudart.cudaDeviceSynchronize()
496
+ if err != cuda.CUresult.CUDA_SUCCESS:
497
+ raise RuntimeError("CUDA Error %s" % str(err))
498
+ for arg in self.gemm_arguments:
499
+ arg.sync(stream_sync=False)
500
+
501
+
502
+ ################################################################################
503
+ # Base class for GEMM runtime module
504
+ ################################################################################
505
+
506
class GemmRTbase(ExecutableOperation):
    """
    GemmRT manages the CUTLASS runtime components.

    Concrete subclasses supply an emitter, the ctypes argument types, and a
    ``plan`` implementation; this base holds the shared kernel template and
    threadblock/launch bookkeeping.
    """

    # C++ source template for the __global__ entry point; the operation name
    # and suffix are substituted in at emit time.
    KernelTemplate = r'''
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {

  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];

  // Declare pointer to dynamic shared memory.
  ${operation_name}${operation_suffix}::SharedStorage *shared_storage =
      reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);

  ${operation_name}${operation_suffix} op;

  op(params, *shared_storage);
}
    '''

    def __init__(self, operation: 'GemmOperation'):
        """Cache the threadblock shape/swizzle and thread count for planning."""
        super().__init__(operation)

        self.operation = operation
        threadblock_shape = operation.tile_description.threadblock_shape
        self.threadblock_shape = cutlass.gemm.GemmCoord(
            threadblock_shape[0], threadblock_shape[1], threadblock_shape[2])
        self.threadblock_swizzle = operation.swizzling_functor

        #: number of threads per threadblock
        self.threads: int = operation.tile_description.num_threads

    #
    def emit(self):
        """Emit the C++ instance definition via the subclass's emitter."""
        return self.emitter.emit(self.operation)

    #
    def can_implement(self, configuration, arguments):
        # subclasses must decide whether the kernel supports the arguments
        raise NotImplementedError()

    #
    def get_host_workspace_size(self, arguments):
        # subclasses must report the host-side Params size
        raise NotImplementedError()

    #
    def get_device_workspace_size(self, arguments):
        # default: no device workspace required
        return 0

    #
    def initialize(self):
        """Raise the kernel's dynamic shared-memory limit to the required
        capacity (needed when SharedStorage exceeds the default opt-in size).

        :raises RuntimeError: if ``cuFuncSetAttribute`` fails
        """
        err, = cuda.cuFuncSetAttribute(
            self.kernel,
            attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            value=self.shared_memory_capacity)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError('Cuda Error: {}'.format(err))
565
+
566
+
567
+ ################################################################################
568
+ # Runtime module for GEMM Universal
569
+ ################################################################################
570
+
571
+
572
class GemmRTUniversal(GemmRTbase):
    """
    GemmRTUniversal manages the CUTLASS runtime components for the
    GemmUniversal kernel: C++ host-side helpers, the ctypes argument types,
    and launch planning (including split-K grid sizing).
    """

    # C++ helpers compiled alongside the kernel: report Params/SharedStorage
    # sizes and serialize a Params struct into a byte array for the host.
    HostTemplate = r'''
extern "C" {
  // Get the size of params in bytes
  int ${operation_name}_get_param_size(){
    return sizeof(${operation_name}${operation_suffix}::Params);
  }

  // Get the size of dynamic shared memory in bytes
  int ${operation_name}_shared_memory_size() {
    return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
  }

  // Get the params as byte array
  char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int* workspace){
    ${operation_name}_base::Params* params;
    params = new ${operation_name}_base::Params(*argument,
                                                -1, // SM count. Only used for stream-K
                                                -1 // Occupancy. Only used for stream-K
                                                );

    // Semaphore holds the pointer to the workspace in the Params struct
    params->semaphore = workspace;

    char *bytes = ((char*)(params));
    char *output = new char[sizeof(${operation_name}_base::Params)];
    for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++)
      output[i] = bytes[i];

    return output;
  }
}
    '''

    def __init__(self, operation: 'GemmOperation'):
        """Build the emitter and the ctypes signature for `_get_params`."""
        super(GemmRTUniversal, self).__init__(operation)
        self.emitter = EmitGemmUniversalInstance(
            '_type', operation.direct_store, operation.visitor)

        # ctypes argument struct / epilogue params type derived from the
        # operation's epilogue functor
        self.argument_type, self.epilogue_type = get_gemm_arguments(operation.epilogue_functor)
        self.argtype = [
            ctypes.POINTER(self.argument_type),
            ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p
        ]

    def plan(self, arguments):
        """Compute the CUDA launch configuration for these arguments.

        For (serial/parallel) split-K modes the K extent is partitioned
        across ``batch_count`` slices, rounded to the operand alignment, and
        the grid's z dimension becomes the number of K slices.
        """
        grid = self.threadblock_swizzle.get_tiled_shape(
            arguments.problem_size, self.threadblock_shape, arguments.batch_count
        )

        gemm_k_size = arguments.problem_size.k()
        if (arguments.gemm_mode in
                [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel]):
            # K-slice granularity must respect the 128-bit access alignment
            # of both A and B elements
            alignk = max(max(128 // DataTypeSize[self.operation.A.element],
                             128 // DataTypeSize[self.operation.B.element]), 1)

            # ceil-divide K over batch_count slices, then round each slice up
            # to a multiple of alignk
            gemm_k_size = (((arguments.problem_size.k() + arguments.batch_count - 1) //
                            arguments.batch_count + alignk - 1) // alignk) * alignk

            if gemm_k_size:
                grid_z = (arguments.problem_size.k() +
                          gemm_k_size - 1) // gemm_k_size
                grid = cutlass.gemm.GemmCoord(grid.m(), grid.n(), grid_z)

        arguments.grid_tiled_shape = cutlass.dim3(grid.m(), grid.n(), grid.k())
        grid = self.threadblock_swizzle.get_grid_shape(grid)
        arguments.gemm_k_size = gemm_k_size
        return LaunchConfiguration(
            [grid.x, grid.y, grid.z],
            [self.threads, 1, 1],
            self.shared_memory_capacity)

    #
    def get_device_workspace_size(self, arguments: GemmArguments):
        """Return the device workspace size in bytes for split-K reductions.

        Parallel split-K needs a full-size partial-sum buffer per K slice;
        serial split-K needs one 32-bit semaphore per output tile.
        """
        workspace_bytes = 0
        if arguments.gemm_mode == cutlass.gemm.Mode.GemmSplitKParallel:
            # element-bits * elements * slices, converted to bytes (// 8)
            workspace_bytes = (DataTypeSize[arguments.operation.C.element]
                               * arguments.batched_stride_D * arguments.grid_tiled_shape.z // 8)
        elif (arguments.gemm_mode == cutlass.gemm.Mode.Gemm and
                arguments.split_k_slices > 1):
            # one int32 semaphore per (m, n) tile
            workspace_bytes = 4 * arguments.grid_tiled_shape.x * arguments.grid_tiled_shape.y

        # TODO: get extra workspace size
        # see https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/gemm/device/gemm_universal_base.h
        return workspace_bytes
663
+
664
+
665
+ ###################################################################################################
666
+ # Runtime module for GEMM Grouped
667
+ ###################################################################################################
668
+
669
+
670
class GemmRTGrouped(GemmRTbase):
    """
    GemmRTGrouped manages the CUTLASS runtime components for the grouped
    GEMM kernel, including optional host-side precomputation of the tile
    schedule (SchedulerMode.Host).
    """

    # C++ helpers: host tile-schedule precompute, Params/SharedStorage size
    # queries, and Params serialization into a byte array.
    HostTemplate = r'''
extern "C" {

  // precompute scheduling information
  char * ${operation_name}_precompute(${operation_name}_base::Arguments const &args, int tile_count, size_t workspace_bytes) {
    char* host_workspace = new char[workspace_bytes];
    ${operation_name}_base::ProblemVisitor::host_precompute(
      args.host_problem_sizes,
      args.problem_count,
      args.threadblock_count,
      (void*)host_workspace
    );
    return host_workspace;
  }

  // Get the size of params in bytes
  int ${operation_name}_get_param_size(){
    return sizeof(${operation_name}${operation_suffix}::Params);
  }

  // Get the size of dynamic shared memory in bytes
  int ${operation_name}_shared_memory_size() {
    return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
  }

  // Get the params as byte array
  char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int tile_count, void* workspace=nullptr){
    ${operation_name}_base::Params* params;
    params = new ${operation_name}_base::Params(*argument, workspace, tile_count);

    char *bytes = ((char*)(params));
    char *output = new char[sizeof(${operation_name}_base::Params)];
    for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++)
      output[i] = bytes[i];

    return output;
  }
}
    '''

    def __init__(self, operation: 'GemmOperation'):
        """Register the extra 'precompute' C function and build ctypes types."""
        super(GemmRTGrouped, self).__init__(operation)
        # extra host-side entry point exported by HostTemplate above
        self.extra_funcs = ['precompute']

        self.emitter = EmitGemmGroupedInstance('_type')
        self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor)
        self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p]

    def host_precompute(self, arguments, workspace_bytes):
        """Run the host-side tile-schedule precompute and upload the result.

        Returns the device pointer of the uploaded schedule buffer.
        """
        self.precompute.argtype = [
            self.argtype[0], ctypes.c_int, ctypes.c_longlong]
        self.precompute.restype = ctypes.POINTER(
            ctypes.c_byte * workspace_bytes)

        problem_info = self.precompute(ctypes.byref(
            arguments.arguments), arguments.total_tiles, workspace_bytes)
        problem_info_array = bytearray(problem_info.contents)

        # copy to device memory
        return rmm.DeviceBuffer.to_device(problem_info_array).ptr

    def plan(self, arguments):
        """One threadblock per tile; grid is simply the total tile count."""
        return LaunchConfiguration(
            [arguments.total_tiles, 1, 1],
            [self.threads, 1, 1], self.shared_memory_capacity)

    def get_workspace_size(self, arguments):
        """Size in bytes of the host-precomputed schedule buffer.

        NOTE(review): returns None implicitly for any other SchedulerMode —
        confirm Device/Host are the only possible modes here.
        """
        if self.operation.precompute_mode == SchedulerMode.Device:
            return 0
        elif self.operation.precompute_mode == SchedulerMode.Host:
            total_tiles = arguments.total_tiles
            entries_per_block = 1
            # NOTE(review): comment said "three int32_t" but 8 bytes are
            # reserved per entry — presumably two int32_t; verify against
            # the ProblemVisitor layout.
            return 8 * entries_per_block * total_tiles  # three int32_t
747
+
748
+
749
+ ################################################################################
750
+ # Runtime module for GEMM Grouped
751
+ ################################################################################
752
+
753
+ #
754
class GemmOperationBase:
    """
    CUTLASS GEMM operation.

    Holds the operand/tile descriptions and naming logic shared by the
    universal and grouped operation classes; subclasses attach the runtime
    module (``rt_module``).
    """

    def __init__(
            self, gemm_kind, arch, tile_description: TileDescription,
            A: TensorDescription, B: TensorDescription, C: TensorDescription,
            epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
        """Store descriptions and normalize layouts.

        :param gemm_kind: which GEMM flavour (universal, grouped, ...)
        :param arch: SM compute capability (e.g. 80)
        :param tile_description: threadblock/warp/instruction tiling
        :param A, B, C: operand tensor descriptions
        :param epilogue_functor: epilogue emitter object
        :param swizzling_functor: threadblock swizzle class (instantiated here)
        :param direct_store: optional keyword, use direct-store epilogue
        :param visitor: optional keyword, use epilogue-visitor kernel
        """
        #: operation kind
        self.operation_kind: OperationKind = OperationKind.Gemm
        #: compute capability
        self.arch: int = arch
        #: tile description object
        self.tile_description: TileDescription = tile_description
        #: gemm kind
        self.gemm_kind: GemmKind = gemm_kind

        # use deep copy to avoid overwriting the original TensorDescription
        if C.layout == cutlass.ColumnMajor:
            # column-major C is handled by swapping A/B and transposing all
            # layouts (presumably computing the transposed product — the
            # 'switched' flag records it for argument marshalling)
            #: Operand A
            self.A: TensorDescription = copy.deepcopy(B)
            #: Operand B
            self.B: TensorDescription = copy.deepcopy(A)
            #: Operand C
            self.C: TensorDescription = copy.deepcopy(C)
            self.A.layout = transpose_layout(self.A.layout)
            self.B.layout = transpose_layout(self.B.layout)
            self.C.layout = transpose_layout(self.C.layout)
            self.switched = True
        else:
            #: Operand A
            self.A: TensorDescription = copy.deepcopy(A)
            #: Operand B
            self.B: TensorDescription = copy.deepcopy(B)
            #: Operand C
            self.C: TensorDescription = copy.deepcopy(C)
            self.switched = False

        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor()

        # idiomatic kwargs.get replaces the original if/else pairs;
        # defaults are unchanged
        self.direct_store = kwargs.get("direct_store", False)
        self.visitor = kwargs.get("visitor", False)

    def run(self, arguments: GemmArguments) -> cuda.CUresult:
        """
        Configure and launch the cuda kernel with input arguments.

        :raises RuntimeError: if the driver reports a launch error
        """
        err = self.rt_module.run(
            arguments.host_workspace,
            arguments.device_workspace,
            arguments.launch_config)

        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError('CUDA Error %s' % str(err))

        return err

    def free(self):
        """Release the device workspace buffer, if one was allocated."""
        if hasattr(self, "workspace_buffer"):
            del self.workspace_buffer

    #
    def is_complex(self):
        """True if the math instruction operates on complex values."""
        complex_operators = [
            MathOperation.multiply_add_complex,
            MathOperation.multiply_add_complex_gaussian,
            MathOperation.multiply_add_complex_fast_f32
        ]
        return self.tile_description.math_instruction.math_operation in complex_operators

    #
    def is_planar_complex(self):
        """True for planar-complex GEMM kinds."""
        return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)

    #
    def accumulator_type(self):
        """Accumulator data type (promoted to complex when applicable)."""
        accum = self.tile_description.math_instruction.element_accumulator

        if self.is_complex():
            return get_complex_from_real(accum)

        return accum

    #
    def short_math_name(self):
        """Short accumulator-type tag; Gaussian complex gets a 'g' prefix."""
        if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
            return "g%s" % ShortDataTypeNames[self.accumulator_type()]
        return ShortDataTypeNames[self.accumulator_type()]

    #
    def core_name(self):
        ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
        inst_shape = ''
        intermediate_type = ''

        math_operations_map = {
            MathOperation.xor_popc: 'xor',
        }

        opcode_class = self.tile_description.math_instruction.opcode_class
        if opcode_class in (cutlass.OpClass.TensorOp, cutlass.OpClass.WmmaTensorOp):
            # tensor-op kernels encode the instruction shape (and any special
            # math op) in the name
            math_op = self.tile_description.math_instruction.math_operation
            math_op_string = math_operations_map.get(math_op, '')

            inst_shape = "%d%d%d" % tuple(
                self.tile_description.math_instruction.instruction_shape)
            inst_shape += math_op_string

            if self.tile_description.math_instruction.element_a != self.A.element and \
                    self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
                intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]

        return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])

    #
    def extended_name(self):
        ''' Append data types if they differ from compute type. '''
        accum = self.tile_description.math_instruction.element_accumulator
        if self.is_complex():
            extended_name = "${core_name}"
        else:
            if self.C.element != accum and self.A.element != accum:
                extended_name = "${element_c}_${core_name}_${element_a}"
            elif self.C.element == accum and self.A.element != accum:
                extended_name = "${core_name}_${element_a}"
            else:
                extended_name = "${core_name}"

        extended_name = SubstituteTemplate(extended_name, {
            'element_a': DataTypeNames[self.A.element],
            'element_c': DataTypeNames[self.C.element],
            'core_name': self.core_name()
        })

        return extended_name

    #
    def layout_name(self):
        """Two-letter layout tag (with complex transforms when applicable)."""
        if self.is_complex() or self.is_planar_complex():
            return "%s%s" % (
                ShortComplexLayoutNames[(
                    self.A.layout, self.A.complex_transform)],
                ShortComplexLayoutNames[(
                    self.B.layout, self.B.complex_transform)]
            )
        return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])

    #
    def procedural_name(self):
        ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
        threadblock = self.tile_description.procedural_name()

        opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]

        # NOTE(review): the original computed max(A, B, C alignment) here but
        # never used it — the emitted name intentionally (?) uses A's
        # alignment only; the dead computation has been removed.
        return SubstituteTemplate(
            "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}",
            {
                'opcode_class': opcode_class_name,
                'extended_name': self.extended_name(),
                'threadblock': threadblock,
                'layout': self.layout_name(),
                'alignment': "%d" % self.A.alignment,
            }
        )

    #
    def configuration_name(self):
        ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
        return self.procedural_name()
943
+
944
+
945
class GemmOperationUniversal(GemmOperationBase):
    """Universal GEMM operation: a GemmOperationBase wired to the
    GemmRTUniversal runtime module."""

    def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
                 epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
        super().__init__(GemmKind.Universal, arch, tile_description,
                         A, B, C, epilogue_functor, swizzling_functor, **kwargs)
        # build the runtime module and surface its ctypes argument/epilogue
        # types on the operation itself
        rt = GemmRTUniversal(self)
        self.rt_module = rt
        self.argument_type = rt.argument_type
        self.epilogue_type = rt.epilogue_type
953
+
954
+
955
class GemmOperationGrouped(GemmOperationBase):
    """Grouped GEMM operation: a GemmOperationBase wired to the
    GemmRTGrouped runtime module.

    Requires the keyword argument ``precompute_mode`` (a SchedulerMode).
    """

    def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
                 epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
        """
        :raises ValueError: if 'precompute_mode' is not supplied
        """
        super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description,
                                                   A, B, C, epilogue_functor, swizzling_functor, **kwargs)
        # Use an explicit raise instead of 'assert': asserts are stripped
        # under `python -O`, which would turn this into a bare KeyError.
        # (Also fixes the "arguement" typo in the original message.)
        if "precompute_mode" not in kwargs:
            raise ValueError("missing keyword argument 'precompute_mode'.")
        self.precompute_mode = kwargs["precompute_mode"]
        self.rt_module = GemmRTGrouped(self)
        self.argument_type = self.rt_module.argument_type
        self.epilogue_type = self.rt_module.epilogue_type
966
+
967
+ ###################################################################################################
968
+ #
969
+ # Emits single instances of a CUTLASS device-wide operator
970
+ #
971
+ ###################################################################################################
972
+
973
+ #
974
class EmitGemmUniversalInstance:
    ''' Responsible for emitting a CUTLASS template definition'''

    def __init__(self, operation_suffix='', direct_store=False, visitor=False):
        """Select includes and the C++ template variant (plain universal,
        direct-store epilogue, or epilogue-visitor)."""
        self.operation_suffix = operation_suffix
        self.direct_store = direct_store
        self.visitor = visitor
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/numeric_types.h",
            "cutlass/arch/arch.h",
            "cutlass/arch/mma.h",
            "cutlass/layout/matrix.h",
            "cutlass/gemm/device/gemm.h",
            "cutlass/gemm/device/gemm_universal_adapter.h",
            "cutlass/gemm/kernel/default_gemm_universal.h",
        ]
        if self.visitor:
            # extra headers for the epilogue-visitor kernel variant
            self.includes += [
                "gemm/gemm_universal_with_visitor.h",
                "epilogue/epilogue_visitor_with_layernorm.h",
                "epilogue/epilogue_visitor_generic.h"
            ]
        if self.direct_store:
            self.includes.append(
                "cutlass/epilogue/threadblock/default_epilogue_direct_store.h")
        # default variant: DefaultGemmUniversal instantiated directly
        self.gemm_template_interleaved = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
  typename cutlass::gemm::kernel::DefaultGemmUniversal<
    ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
    ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor},
    ${swizzling_functor},
    ${stages},
    ${math_operation}
>::GemmKernel;

// Define named type
struct ${operation_name}${operation_suffix} :
  public ${operation_name}_base { };
"""
        # variant that swaps in a direct-store epilogue around the default
        self.gemm_template_direct_store = """
// Gemm operator ${operation_name}
using ${operation_name}_default =
  typename cutlass::gemm::kernel::DefaultGemmUniversal<
    ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
    ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor},
    ${swizzling_functor},
    ${stages},
    ${math_operation}
>::GemmKernel;

using ${operation_name}_base =
  cutlass::gemm::kernel::GemmUniversal<
    ${operation_name}_default::Mma,
    cutlass::epilogue::threadblock::DefaultEpilogueDirectStore<
      ${operation_name}_default::Epilogue
    >::Epilogue,
    ${operation_name}_default::ThreadblockSwizzle
  >;

// Define named type
struct ${operation_name}${operation_suffix} :
  public ${operation_name}_base { };
"""
        # variant that wraps the default epilogue in a user-supplied visitor
        self.gemm_template_visitor = """
// Gemm operator ${operation_name}
using ${operation_name}_default =
  typename cutlass::gemm::kernel::DefaultGemmUniversal<
    ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
    ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${elementwise_epilogue_functor},
    ${swizzling_functor},
    ${stages},
    ${math_operation}
>::GemmKernel;

${epilogue_visitor}

using ${operation_name}_Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
  ${operation_name}_EpilogueVisitor,
  typename ${operation_name}_default::Epilogue>::Epilogue;

using ${operation_name}_base =
  cutlass::gemm::kernel::GemmUniversalwithEpilogueVisitor<
    ${operation_name}_default::Mma,
    ${operation_name}_Epilogue,
    ${operation_name}_default::ThreadblockSwizzle
  >;

// Define named type
struct ${operation_name}${operation_suffix} :
  public ${operation_name}_base { };
"""

    #
    def instance_template(self):
        """Template for registering the emitted kernel with a manifest."""
        return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<
    cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  >("${operation_name}"));
${compile_guard_end}
"""

    #
    def emit(self, operation):
        """Render the C++ instance definition for *operation*.

        Substitutes tile/warp/instruction shapes, element types, layouts and
        the epilogue into whichever template variant this emitter was
        configured with.
        """
        threadblock_shape = operation.tile_description.threadblock_shape
        warp_count = operation.tile_description.warp_count

        # per-warp tile = threadblock tile / warp count, per dimension
        warp_shape = [threadblock_shape[idx] // warp_count[idx]
                      for idx in range(3)]

        # transpose_layouts = {
        #     cutlass.layout.ColumnMajorcutlass.layout.ColumnMajor,
        #     cutlass.layout.RowMajorcutlass.layout.RowMajor
        # }

        # if operation.A.layout in transpose_layouts.keys() and \
        #     operation.B.layout in transpose_layouts.keys() and \
        #     operation.C.layout in transpose_layouts.keys():

        #     instance_layout_A = transpose_layouts[operation.A.layout]
        #     instance_layout_B = transpose_layouts[operation.B.layout]
        #     instance_layout_C = transpose_layouts[operation.C.layout]

        #     gemm_template = self.gemm_template
        # else:
        instance_layout_A, instance_layout_B, instance_layout_C = \
            (operation.A.layout, operation.B.layout, operation.C.layout)
        # pick the template variant configured in __init__
        if self.direct_store:
            gemm_template = self.gemm_template_direct_store
        elif self.visitor:
            gemm_template = self.gemm_template_visitor
        else:
            gemm_template = self.gemm_template_interleaved
        #

        values = {
            'operation_name': operation.procedural_name(),
            'operation_suffix': self.operation_suffix,
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[instance_layout_A],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[instance_layout_B],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[instance_layout_C],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
            'swizzling_functor': operation.swizzling_functor.tag(),
            'stages': str(operation.tile_description.stages),
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
            'transform_a': ComplexTransformTag[operation.A.complex_transform],
            'transform_b': ComplexTransformTag[operation.B.complex_transform],
            'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
        }

        # visitor kernels take the visitor body plus the elementwise functor;
        # plain kernels take the epilogue functor directly
        if self.visitor:
            values['epilogue_visitor'] = operation.epilogue_functor.emit(operation)
            values['elementwise_epilogue_functor'] = operation.epilogue_functor.elementwise_functor.emit()
        else:
            values['epilogue_functor'] = operation.epilogue_functor.emit()

        return SubstituteTemplate(gemm_template, values)
1173
+
1174
+ ###################################################################################################
1175
+
1176
+ #
1177
+
1178
+
1179
class EmitGemmGroupedInstance:
    """Responsible for emitting a CUTLASS template definition.

    Renders a ``cutlass::gemm::kernel::DefaultGemmGrouped`` instantiation for a
    grouped-GEMM operation description and wraps it in a named struct so the
    generated kernel can be referenced by ``operation.procedural_name()``.
    """

    def __init__(self, operation_suffix=''):
        # Optional suffix appended to the emitted struct's name.
        self.operation_suffix = operation_suffix
        # Headers the emitted C++ translation unit must include.
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/numeric_types.h",
            "cutlass/arch/arch.h",
            "cutlass/arch/mma.h",
            "cutlass/layout/matrix.h",
            "cutlass/gemm/kernel/gemm_grouped.h",
            "cutlass/gemm/kernel/default_gemm_grouped.h"
        ]
        # ${...} placeholders are filled in by SubstituteTemplate() in emit().
        self.gemm_template = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
  typename cutlass::gemm::kernel::DefaultGemmGrouped<
    ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
    ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor},
    ${swizzling_functor},
    ${stages},
    ${precompute_mode},
    ${math_operation}
>::GemmKernel;

// Define named type
struct ${operation_name}${operation_suffix} :
  public ${operation_name}_base { };
"""

    #
    def instance_template(self):
        """Return the manifest-registration snippet for a grouped GEMM kernel."""
        return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<
    cutlass::gemm::device::GemmGrouped<${operation_name}>
  >("${operation_name}"));
${compile_guard_end}
"""

    #
    def emit(self, operation):
        """Render ``self.gemm_template`` for *operation* and return the C++ source string."""
        tile = operation.tile_description
        math_inst = tile.math_instruction

        # Per-warp tile: threadblock tile divided elementwise by the warp count.
        warp_shape = [tb_dim // wc_dim for tb_dim, wc_dim in zip(tile.threadblock_shape, tile.warp_count)]

        layout_a = operation.A.layout
        layout_b = operation.B.layout
        layout_c = operation.C.layout

        # Support built-in epilogue functors or user-defined functions
        epilogue_functor = operation.epilogue_functor.emit()

        values = {
            'operation_name': operation.procedural_name(),
            'operation_suffix': self.operation_suffix,
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[layout_a],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[layout_b],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[layout_c],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[math_inst.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(tile.threadblock_shape[0]),
            'threadblock_shape_n': str(tile.threadblock_shape[1]),
            'threadblock_shape_k': str(tile.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(math_inst.instruction_shape[0]),
            'instruction_shape_n': str(math_inst.instruction_shape[1]),
            'instruction_shape_k': str(math_inst.instruction_shape[2]),
            'epilogue_functor': epilogue_functor,
            'swizzling_functor': operation.swizzling_functor.tag(),
            'stages': str(tile.stages),
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
            'transform_a': ComplexTransformTag[operation.A.complex_transform],
            'transform_b': ComplexTransformTag[operation.B.complex_transform],
            'precompute_mode': SchedulerModeTag[operation.precompute_mode],
            'math_operation': MathOperationTag[math_inst.math_operation]
        }

        return SubstituteTemplate(self.gemm_template, values)