warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,4625 @@
1
+ #
2
+ # \file generator.py
3
+ #
4
+ # \brief Generates the CUTLASS Library's instances
5
+ #
6
+
7
+ import enum
8
+ import os.path
9
+ import shutil
10
+ import argparse
11
+
12
+ from library import *
13
+ from manifest import *
14
+ from itertools import product
15
+
16
+ ###################################################################################################
17
+
18
+ #
19
def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):
    """Return whether a CUDA Toolkit version is at least ``major.minor.patch``.

    Args:
        semantic_ver_string: Dot-separated integer version such as ``"11.8"``
            or ``"12.0.76"``. An empty string selects the fallback version
            11.0.132.
        major: Minimum required major version.
        minor: Minimum required minor version.
        patch: Minimum required patch version (default 0).

    Returns:
        True when the version compares greater than or equal to
        ``[major, minor, patch]`` using lexicographic list comparison.

    Raises:
        ValueError: If a dot-separated component is not an integer.

    Components missing from ``semantic_ver_string`` are taken from the
    fallback, so ``"11.8"`` is treated as 11.8.132.
    """
    fallback = [11, 0, 132]

    if semantic_ver_string == '':
        version = fallback
    else:
        parsed = [int(token) for token in semantic_ver_string.split('.')]
        # Parsed components override the fallback position by position; any
        # components beyond the third are kept as extra trailing entries.
        version = parsed + fallback[len(parsed):]

    return version >= [major, minor, patch]
32
+
33
+
34
+ ###################################################################################################
35
+ ###################################################################################################
36
+
37
+ #
38
def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8):
    """Compute the maximum alignment usable by the epilogue.

    The product of every ``tile.threadblock_shape`` extent except the last
    (presumably the M x N output footprint of the tile) is integer-divided by
    the product of ``tile.warp_count``, then by 32, then by ``epilogue_steps``.
    The result is clamped to ``max_alignment``.

    Args:
        max_alignment: Upper bound on the returned alignment.
        tile: Object exposing ``threadblock_shape`` and ``warp_count`` sequences.
        epilogue_steps: Number of steps the per-thread work is split over.

    Returns:
        ``min(max_alignment, elements_per_thread)``.
    """
    threads_per_warp = 32  # presumably the CUDA warp size

    tile_elements = 1
    for extent in tile.threadblock_shape[:-1]:
        tile_elements *= extent

    warps = 1
    for count in tile.warp_count:
        warps *= count

    elements_per_thread = tile_elements // warps // threads_per_warp // epilogue_steps
    return min(max_alignment, elements_per_thread)
49
+
50
+ #
51
def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type,
                       alignment_constraints, complex_transforms = None,
                       epilogue_functor = EpilogueFunctor.LinearCombination,
                       swizzling_functor = SwizzlingFunctor.Identity8):
    """Register universal GEMM kernels in the manifest.

    One ``GemmOperation`` of kind ``GemmKind.Universal`` is created for every
    combination of layout, tile description, alignment and complex transform.
    Each is passed to ``manifest.append``.

    Args:
        manifest: Receives each operation; an empty ``kernel_filter`` limits
            generation to the first tile description and first alignment.
        layouts: Sequence of ``(layout_a, layout_b, layout_c)`` triples.
        tile_descriptions: Candidate tile descriptions.
        data_type: ``(element_a, element_b, element_c, element_epilogue)``.
        alignment_constraints: Candidate alignments for A and B.
        complex_transforms: Sequence of ``(transform_a, transform_b)`` pairs;
            defaults to a single pair of ``ComplexTransform.none``.
        epilogue_functor: Forwarded to every operation.
        swizzling_functor: Forwarded to every operation.

    Returns:
        List of the created operations, in manifest insertion order.
    """
    # NOTE: SwizzlingFunctor.StreamK (StreamK decomposition) is an alternative
    # default swizzle that was considered for basic GEMMs.
    if complex_transforms is None:
        complex_transforms = [(ComplexTransform.none, ComplexTransform.none)]

    element_a, element_b, element_c, element_epilogue = data_type

    # Without a kernel filter, only the first entries (by convention the
    # largest tile and largest alignment) are generated.
    if manifest.kernel_filter == '':
        tile_descriptions = [tile_descriptions[0]]
        alignment_constraints = [alignment_constraints[0]]

    created = []
    combinations = product(layouts, tile_descriptions, alignment_constraints, complex_transforms)
    for layout, tile_description, alignment, complex_transform in combinations:
        transform_a, transform_b = complex_transform[0], complex_transform[1]

        A = TensorDescription(element_a, layout[0], alignment, transform_a)
        B = TensorDescription(element_b, layout[1], alignment, transform_b)
        # The alignment of C is capped at 8.
        C = TensorDescription(element_c, layout[2], min(8, alignment))

        gemm = GemmOperation(GemmKind.Universal, tile_description.minimum_compute_capability,
                             tile_description, A, B, C, element_epilogue,
                             epilogue_functor, swizzling_functor)
        manifest.append(gemm)
        created.append(gemm)

    return created
87
+
88
+ #
89
def CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type,
                             alignment_constraints, complex_transforms = None,
                             epilogue_functor = EpilogueFunctor.LinearCombination,
                             swizzling_functor = SwizzlingFunctor.Identity8):
    """Register sparse GEMM kernels in the manifest.

    One ``GemmOperation`` of kind ``GemmKind.Sparse`` is created for every
    combination of layout, tile description, alignment and complex transform.
    Each is passed to ``manifest.append``.

    Args:
        manifest: Receives each operation; an empty ``kernel_filter`` limits
            generation to the first tile description and first alignment.
        layouts: Sequence of ``(layout_a, layout_b, layout_c)`` triples.
        tile_descriptions: Candidate tile descriptions.
        data_type: ``(element_a, element_b, element_c, element_epilogue)``.
        alignment_constraints: Candidate alignments for A and B.
        complex_transforms: Sequence of ``(transform_a, transform_b)`` pairs;
            defaults to a single pair of ``ComplexTransform.none``.
        epilogue_functor: Forwarded to every operation.
        swizzling_functor: Forwarded to every operation.

    Returns:
        List of the created operations, in manifest insertion order.
    """
    if complex_transforms is None:
        complex_transforms = [(ComplexTransform.none, ComplexTransform.none)]

    element_a, element_b, element_c, element_epilogue = data_type

    # Without a kernel filter, only the first entries (by convention the
    # largest tile and largest alignment) are generated.
    if manifest.kernel_filter == '':
        tile_descriptions = [tile_descriptions[0]]
        alignment_constraints = [alignment_constraints[0]]

    operations = []
    for layout, tile_description, alignment, complex_transform in product(
            layouts, tile_descriptions, alignment_constraints, complex_transforms):

        # The alignment of C is capped at 8.
        alignment_c = min(8, alignment)

        A = TensorDescription(element_a, layout[0], alignment, complex_transform[0])
        B = TensorDescription(element_b, layout[1], alignment, complex_transform[1])
        C = TensorDescription(element_c, layout[2], alignment_c)

        new_operation = GemmOperation(GemmKind.Sparse, tile_description.minimum_compute_capability,
                                      tile_description, A, B, C, element_epilogue,
                                      epilogue_functor, swizzling_functor)

        manifest.append(new_operation)
        operations.append(new_operation)

    return operations
125
+
126
+ #
127
def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_type,
                                    alignment_constraints, complex_transforms = None):
    """Register planar-complex GEMM kernels in the manifest.

    For both ``GemmKind.PlanarComplex`` and ``GemmKind.PlanarComplexArray``
    (outermost loop), one ``GemmOperation`` is created for every combination
    of layout, tile description, alignment and complex transform. Each is
    passed to ``manifest.append``. No epilogue or swizzling functor is
    forwarded, so ``GemmOperation``'s own defaults apply.

    Args:
        manifest: Receives each operation; an empty ``kernel_filter`` limits
            generation to the first tile description and first alignment.
        layouts: Sequence of ``(layout_a, layout_b, layout_c)`` triples.
        tile_descriptions: Candidate tile descriptions.
        data_type: ``(element_a, element_b, element_c, element_epilogue)``.
        alignment_constraints: Candidate alignments for A and B.
        complex_transforms: Sequence of ``(transform_a, transform_b)`` pairs;
            ``None`` (the default) selects a single pair of
            ``ComplexTransform.none``.

    Returns:
        List of the created operations, in manifest insertion order, matching
        the other ``Create*Operator`` helpers.
    """
    if complex_transforms is None:
        complex_transforms = [(ComplexTransform.none, ComplexTransform.none)]

    element_a, element_b, element_c, element_epilogue = data_type

    gemm_kinds = [GemmKind.PlanarComplex, GemmKind.PlanarComplexArray]

    # Without a kernel filter, only the first entries (by convention the
    # largest tile and largest alignment) are generated.
    if manifest.kernel_filter == '':
        tile_descriptions = [tile_descriptions[0]]
        alignment_constraints = [alignment_constraints[0]]

    operations = []
    for gemm_kind, layout, tile_description, alignment, complex_transform in product(
            gemm_kinds, layouts, tile_descriptions, alignment_constraints, complex_transforms):

        # The alignment of C is capped at 8.
        alignment_c = min(8, alignment)

        A = TensorDescription(element_a, layout[0], alignment, complex_transform[0])
        B = TensorDescription(element_b, layout[1], alignment, complex_transform[1])
        C = TensorDescription(element_c, layout[2], alignment_c)

        new_operation = GemmOperation(gemm_kind, tile_description.minimum_compute_capability,
                                      tile_description, A, B, C, element_epilogue)

        manifest.append(new_operation)
        operations.append(new_operation)

    return operations
158
+
159
+ #
160
def CreateGemmGroupedOperator(manifest, layouts, tile_descriptions, data_type,
                              alignment_constraints, complex_transforms = None,
                              epilogue_functor = EpilogueFunctor.LinearCombination,
                              swizzling_functor = SwizzlingFunctor.Identity8):
    """Register grouped GEMM kernels in the manifest.

    One ``GroupedGemmOperation`` of kind ``GemmKind.Grouped`` is created for
    every combination of layout, tile description, alignment and complex
    transform. Each is passed to ``manifest.append``.

    Args:
        manifest: Receives each operation; an empty ``kernel_filter`` limits
            generation to the first tile description and first alignment.
        layouts: Sequence of ``(layout_a, layout_b, layout_c)`` triples.
        tile_descriptions: Candidate tile descriptions.
        data_type: ``(element_a, element_b, element_c, element_epilogue)``.
        alignment_constraints: Candidate alignments for A and B.
        complex_transforms: Sequence of ``(transform_a, transform_b)`` pairs;
            defaults to a single pair of ``ComplexTransform.none``.
        epilogue_functor: Forwarded to every operation.
        swizzling_functor: Forwarded to every operation.

    Returns:
        List of the created operations, in manifest insertion order.
    """
    transforms = complex_transforms
    if transforms is None:
        transforms = [(ComplexTransform.none, ComplexTransform.none)]

    element_a, element_b, element_c, element_epilogue = data_type

    tiles = tile_descriptions
    alignments = alignment_constraints
    # Without a kernel filter, only the first entries (by convention the
    # largest tile and largest alignment) are generated.
    if manifest.kernel_filter == '':
        tiles = [tile_descriptions[0]]
        alignments = [alignment_constraints[0]]

    created = []
    for layout, tile, alignment, (transform_a, transform_b) in product(layouts, tiles, alignments, transforms):
        A = TensorDescription(element_a, layout[0], alignment, transform_a)
        B = TensorDescription(element_b, layout[1], alignment, transform_b)
        # The alignment of C is capped at 8.
        C = TensorDescription(element_c, layout[2], min(8, alignment))

        grouped = GroupedGemmOperation(GemmKind.Grouped, tile.minimum_compute_capability,
                                       tile, A, B, C, element_epilogue,
                                       epilogue_functor, swizzling_functor)
        manifest.append(grouped)
        created.append(grouped)

    return created
194
+
195
+ #
196
def CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, data_type,
                        alignment_constraints, blas_mode,
                        epilogue_functor = EpilogueFunctor.LinearCombination,
                        swizzling_functor = SwizzlingFunctor.Identity8):
    """Register rank-K and rank-2K update kernels (SYRK/HERK family) in the manifest.

    For every combination of layout, fill mode, tile description and alignment,
    one ``RankKOperation`` and then one ``Rank2KOperation`` are built from the
    same A and C descriptions. Both are passed to ``manifest.append``.

    Args:
        manifest: Receives each operation; an empty ``kernel_filter`` limits
            generation to the first tile description and first alignment.
        layouts: Sequence of ``(layout_a, layout_c)`` pairs.
        fill_modes: Fill modes for the symmetric output C.
        tile_descriptions: Candidate tile descriptions.
        data_type: ``(element_a, element_c, element_epilogue)``.
        alignment_constraints: Candidate alignments for A.
        blas_mode: With ``BlasMode.hermitian`` and a row-major A, A is conjugated.
        epilogue_functor: Forwarded to every operation.
        swizzling_functor: Forwarded to every operation.

    Returns:
        List of the created operations, in manifest insertion order.
    """
    element_a, element_c, element_epilogue = data_type

    # Without a kernel filter, only the first entries (by convention the
    # largest tile and largest alignment) are generated.
    if manifest.kernel_filter == '':
        tile_descriptions = [tile_descriptions[0]]
        alignment_constraints = [alignment_constraints[0]]

    created = []
    for layout, fill_mode, tile_description, alignment in product(
            layouts, fill_modes, tile_descriptions, alignment_constraints):

        # SYRK layouts (RowMajor, ColumnMajor) use no conjugation; HERK
        # conjugates A when A is row-major.
        conjugate_a = blas_mode == BlasMode.hermitian and layout[0] == LayoutType.RowMajor
        complex_transform = ComplexTransform.conj if conjugate_a else ComplexTransform.none

        A = TensorDescription(element_a, layout[0], alignment, complex_transform)
        # Alignment only applies to A in SYRK; C's alignment is fixed at 1.
        C = SymmetricTensorDescription(element_c, layout[1], fill_mode, 1)

        # Rank-K update first, then rank-2K update; both share A and C.
        for operation_class in (RankKOperation, Rank2KOperation):
            update = operation_class(RankKKind.Universal, tile_description.minimum_compute_capability,
                                     tile_description, A, C, element_epilogue,
                                     epilogue_functor, swizzling_functor, blas_mode)
            manifest.append(update)
            created.append(update)

    return created
241
+
242
+ #
243
def CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, data_type,
                       alignment_constraints, complex_transforms = None,
                       epilogue_functor = EpilogueFunctor.LinearCombination,
                       swizzling_functor = SwizzlingFunctor.Identity8):
    """Register TRMM kernels in the manifest.

    One ``TrmmOperation`` is created for every combination of layout, side
    mode, fill mode, diagonal type, tile description, alignment and complex
    transform. Each is passed to ``manifest.append``.

    Args:
        manifest: Receives each operation; an empty ``kernel_filter`` limits
            generation to the first tile description and first alignment.
        layouts: Sequence of ``(layout_a, layout_b, layout_c)`` triples.
        side_modes, fill_modes, diag_types: Describe the triangular matrix A.
        tile_descriptions: Candidate tile descriptions.
        data_type: ``(element_a, element_b, element_c, element_epilogue)``.
        alignment_constraints: Candidate alignments for A and B.
        complex_transforms: Sequence of single transforms applied to A
            (not pairs); defaults to ``[ComplexTransform.none]``.
        epilogue_functor: Forwarded to every operation.
        swizzling_functor: Forwarded to every operation.

    Returns:
        List of the created operations, in manifest insertion order.
    """
    if complex_transforms is None:
        complex_transforms = [ComplexTransform.none]

    element_a, element_b, element_c, element_epilogue = data_type

    # Without a kernel filter, only the first entries (by convention the
    # largest tile and largest alignment) are generated.
    if manifest.kernel_filter == '':
        tile_descriptions = [tile_descriptions[0]]
        alignment_constraints = [alignment_constraints[0]]

    created = []
    combinations = product(layouts, side_modes, fill_modes, diag_types,
                           tile_descriptions, alignment_constraints, complex_transforms)
    for layout, side_mode, fill_mode, diag_type, tile_description, alignment, complex_transform in combinations:
        A = TriangularTensorDescription(element_a, layout[0], side_mode, fill_mode, diag_type,
                                        alignment, complex_transform)
        B = TensorDescription(element_b, layout[1], alignment)
        # The alignment of C is capped at 8.
        C = TensorDescription(element_c, layout[2], min(8, alignment))

        trmm = TrmmOperation(TrmmKind.Universal, tile_description.minimum_compute_capability,
                             tile_description, A, B, C, element_epilogue,
                             epilogue_functor, swizzling_functor)
        manifest.append(trmm)
        created.append(trmm)

    return created
281
+
282
+ #
283
def CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, data_type,
                       alignment_constraints, blas_mode,
                       epilogue_functor = EpilogueFunctor.LinearCombination,
                       swizzling_functor = SwizzlingFunctor.Identity8):
    """Register SYMM/HEMM kernels in the manifest.

    Exactly one ``SymmOperation`` is created for every combination of layout,
    side mode, fill mode, tile description and alignment. Each is passed to
    ``manifest.append``.

    Args:
        manifest: Receives each operation; an empty ``kernel_filter`` limits
            generation to the first tile description and first alignment.
        layouts: Sequence of ``(layout_a_and_b, layout_c)`` pairs; A and B
            share the first layout.
        side_modes: Side modes for the symmetric matrix A.
        fill_modes: Fill modes for the symmetric matrix A.
        tile_descriptions: Candidate tile descriptions.
        data_type: ``(element_a, element_b, element_c, element_epilogue)``.
        alignment_constraints: Candidate alignments for B (C is capped at 8).
        blas_mode: Forwarded to every operation (symmetric vs. Hermitian).
        epilogue_functor: Forwarded to every operation.
        swizzling_functor: Forwarded to every operation.

    Returns:
        List of the created operations, in manifest insertion order.
    """
    element_a, element_b, element_c, element_epilogue = data_type

    # Without a kernel filter, only the first entries (by convention the
    # largest tile and largest alignment) are generated.
    if manifest.kernel_filter == '':
        tile_descriptions = [tile_descriptions[0]]
        alignment_constraints = [alignment_constraints[0]]

    operations = []
    for layout, side_mode, fill_mode, tile_description, alignment in product(
            layouts, side_modes, fill_modes, tile_descriptions, alignment_constraints):

        # SYMM supported layouts (RowMajor, ColumnMajor) with no conjugation.
        complex_transform = ComplexTransform.none

        alignment_a = 1  # No vectorized access for the symmetric matrix A.
        alignment_c = min(8, alignment)

        A = SymmetricTensorDescription(element_a, layout[0], fill_mode, alignment_a, complex_transform, side_mode)
        # Tensors A and B have the same data type and layout.
        B = TensorDescription(element_b, layout[0], alignment)
        C = TensorDescription(element_c, layout[1], alignment_c)

        # Register each SYMM/HEMM configuration once. Unlike the rank-K
        # helper, there is no second variant to emit, so appending the same
        # operation twice would only duplicate the kernel in the manifest.
        new_operation = SymmOperation(SymmKind.Universal, tile_description.minimum_compute_capability,
                                      tile_description, A, B, C, element_epilogue,
                                      epilogue_functor, swizzling_functor, blas_mode)

        manifest.append(new_operation)
        operations.append(new_operation)

    return operations
328
+
329
+ ###########################################################################################################
330
+ # ConvolutionOperator support variations
331
+ # ____________________________________________________________________
332
+ # ConvolutionalOperator | Analytic | Optimized
333
+ # ____________________________________________________________________
334
+ # | Fprop | (strided) | (strided)
335
+ # | Dgrad | (strided, unity*) | (strided, unity)
336
+ # | Wgrad | (strided) | (strided)
337
+ # ____________________________________________________________________
338
+ #
339
+ # Note: Operators marked (*) are supported but not generated, to keep the instantiated kernel count low
340
+ ###########################################################################################################
341
+ # Convolution for 2D operations
342
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints,
                         conv_kinds=[ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad],
                         epilogue_functor=EpilogueFunctor.LinearCombination,
                         swizzling_functor=SwizzlingFunctor.Identity4):
  """Register 2D convolution kernels with ``manifest`` and return them.

  For each (tile, alignment) combination, A/B/C tensor descriptions are built
  from ``layout`` and one ``Conv2dOperation`` is created per requested
  convolution kind, iterator algorithm and stride-support variant.

  Args:
    manifest: receives every operation through ``append``. When its
      ``kernel_filter`` is the empty string, only the first tile description
      and the first alignment constraint are used.
    layout: layouts of the A, B and C tensors (indexed 0, 1, 2).
    tile_descriptions: candidate threadblock tile descriptions.
    data_type: ``(element_a, element_b, element_c, element_epilogue)``.
    alignment_constraints: candidate alignments for A and B. C uses
      ``min(8, alignment)``.
    conv_kinds: which of ``ConvKind.Fprop`` / ``Dgrad`` / ``Wgrad`` to emit.
    epilogue_functor: epilogue used by every generated operation.
    swizzling_functor: threadblock swizzle for everything except strided Dgrad.

  Returns:
    The generated operations, in the order they were appended to ``manifest``.
  """
  element_a, element_b, element_c, element_epilogue = data_type

  # Analytic iterators are supported but not generated, to keep the kernel
  # count low. Only the optimized iterator is instantiated.
  iterator_algorithms = [IteratorAlgorithm.Optimized]

  # By default, keep only the largest tile size and the largest alignment.
  if manifest.kernel_filter == '':
    tile_descriptions = [tile_descriptions[0]]
    alignment_constraints = [alignment_constraints[0]]

  # Strided Dgrad uses a dedicated threadblock swizzle.
  # SwizzlingFunctor.StridedDgradHorizontal might be better for problem sizes
  # with a large activation channel count.
  strided_dgrad_swizzle = SwizzlingFunctor.StridedDgradIdentity1

  generated = []

  def register(op):
    # Keep the manifest and the returned list in lock-step.
    manifest.append(op)
    generated.append(op)

  for tile in tile_descriptions:
    for alignment in alignment_constraints:
      A = TensorDescription(element_a, layout[0], alignment)
      B = TensorDescription(element_b, layout[1], alignment)
      C = TensorDescription(element_c, layout[2], min(8, alignment))

      # Conv2d Fprop: strided support. NHWC TensorOp tiles additionally get
      # grouped-conv variants. A group mode of None means "use the
      # Conv2dOperation default" (the non-grouped kernel).
      if ConvKind.Fprop in conv_kinds:
        for algorithm in iterator_algorithms:
          group_modes = [None]
          if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC:
            group_modes.append(GroupMode.SingleGroup)
            # Only the analytic iterator supports MultipleGroup mode.
            if algorithm == IteratorAlgorithm.Analytic:
              group_modes.append(GroupMode.MultipleGroup)

          fprop_ops = []
          for mode in group_modes:
            extra = {} if mode is None else {'group_mode': mode}
            fprop_ops.append(Conv2dOperation(ConvKind.Fprop, algorithm, tile.minimum_compute_capability, tile,
                                             A, B, C, element_epilogue, StrideSupport.Strided,
                                             epilogue_functor, swizzling_functor, **extra))
          for op in fprop_ops:
            register(op)

      # Conv2d Dgrad: unity stride for every iterator, then strided variants
      # (analytic before optimized) with the strided-Dgrad swizzle.
      if ConvKind.Dgrad in conv_kinds:
        for algorithm in iterator_algorithms:
          register(Conv2dOperation(ConvKind.Dgrad, algorithm, tile.minimum_compute_capability, tile,
                                   A, B, C, element_epilogue, StrideSupport.Unity,
                                   epilogue_functor, swizzling_functor))

        for algorithm in (IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized):
          if algorithm in iterator_algorithms:
            register(Conv2dOperation(ConvKind.Dgrad, algorithm, tile.minimum_compute_capability, tile,
                                     A, B, C, element_epilogue, StrideSupport.Strided,
                                     epilogue_functor, strided_dgrad_swizzle))

      # Conv2d Wgrad: strided support for every iterator.
      if ConvKind.Wgrad in conv_kinds:
        for algorithm in iterator_algorithms:
          register(Conv2dOperation(ConvKind.Wgrad, algorithm, tile.minimum_compute_capability, tile,
                                   A, B, C, element_epilogue, StrideSupport.Strided,
                                   epilogue_functor, swizzling_functor))

  return generated
449
+
450
+ # Convolution for 2D operations specialized for fixed channel counts (IteratorAlgorithm.FixedChannels)
451
def CreateConv2dFixedChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts,
                                      conv_kinds=[ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad],
                                      epilogue_functor=EpilogueFunctor.LinearCombination,
                                      swizzling_functor=SwizzlingFunctor.Identity4):
  """Register fixed-channels Conv2d Fprop kernels with ``manifest`` and return them.

  For each (tile, channel count) combination, A and B use the channel count
  as their alignment and C uses ``EpilogueAlignment(channel_count, tile)``.
  Only Fprop kernels are generated. ``ConvKind.Dgrad`` / ``ConvKind.Wgrad``
  entries in ``conv_kinds`` have no effect here.

  Args:
    manifest: receives every operation through ``append``. When its
      ``kernel_filter`` is the empty string, only the first tile description
      and the first channel count are used.
    layout: layouts of the A, B and C tensors (indexed 0, 1, 2).
    tile_descriptions: candidate threadblock tile descriptions.
    data_type: ``(element_a, element_b, element_c, element_epilogue)``.
    channel_counts: candidate channel counts, also used as the A/B alignment.
    conv_kinds: Fprop kernels are emitted only if ``ConvKind.Fprop`` is listed
      (the default list is only read, never mutated).
    epilogue_functor: epilogue used by every generated operation.
    swizzling_functor: threadblock swizzle used by every generated operation.

  Returns:
    The generated operations, in the order they were appended to ``manifest``,
    consistent with ``CreateConv2dOperator``.
  """
  element_a, element_b, element_c, element_epilogue = data_type

  # This creator only instantiates the fixed-channels iterator.
  iterator_algorithms = [IteratorAlgorithm.FixedChannels]

  # By default, keep only the first tile size and the first channel count.
  if manifest.kernel_filter == '':
    tile_descriptions = [tile_descriptions[0]]
    channel_counts = [channel_counts[0]]

  operations = []

  for tile in tile_descriptions:
    for channel_count in channel_counts:
      alignment_c = EpilogueAlignment(channel_count, tile)

      A = TensorDescription(element_a, layout[0], channel_count)
      B = TensorDescription(element_b, layout[1], channel_count)
      C = TensorDescription(element_c, layout[2], alignment_c)

      # Conv2d Fprop with strided support (the only kind generated here).
      if ConvKind.Fprop in conv_kinds:
        for iterator_algorithm in iterator_algorithms:
          new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,
                                          A, B, C, element_epilogue, StrideSupport.Strided,
                                          epilogue_functor, swizzling_functor)

          manifest.append(new_operation)
          operations.append(new_operation)

  # The original fell off the end and returned None, unlike every sibling
  # creator. Return the list so callers can use it uniformly.
  return operations
494
+
495
+
496
+ # Convolution for 2D operations specialized for few channels
497
def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts,
                                    conv_kinds=[ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad],
                                    epilogue_functor=EpilogueFunctor.LinearCombination,
                                    swizzling_functor=SwizzlingFunctor.Identity4):
  """Register few-channels Conv2d Fprop kernels with ``manifest`` and return them.

  For each (tile, channel count) combination, A and B use the channel count
  as their alignment and C uses ``EpilogueAlignment(channel_count, tile)``.
  Only Fprop kernels are generated. ``ConvKind.Dgrad`` / ``ConvKind.Wgrad``
  entries in ``conv_kinds`` have no effect here.

  Args:
    manifest: receives every operation through ``append``. When its
      ``kernel_filter`` is the empty string, only the first tile description
      and the first channel count are used.
    layout: layouts of the A, B and C tensors (indexed 0, 1, 2).
    tile_descriptions: candidate threadblock tile descriptions.
    data_type: ``(element_a, element_b, element_c, element_epilogue)``.
    channel_counts: candidate channel counts, also used as the A/B alignment.
    conv_kinds: Fprop kernels are emitted only if ``ConvKind.Fprop`` is listed
      (the default list is only read, never mutated).
    epilogue_functor: epilogue used by every generated operation.
    swizzling_functor: threadblock swizzle used by every generated operation.

  Returns:
    The generated operations, in the order they were appended to ``manifest``,
    consistent with ``CreateConv2dOperator``.
  """
  element_a, element_b, element_c, element_epilogue = data_type

  # This creator only instantiates the few-channels iterator.
  iterator_algorithms = [IteratorAlgorithm.FewChannels]

  # By default, keep only the first tile size and the first channel count.
  if manifest.kernel_filter == '':
    tile_descriptions = [tile_descriptions[0]]
    channel_counts = [channel_counts[0]]

  operations = []

  for tile in tile_descriptions:
    for channel_count in channel_counts:
      alignment_c = EpilogueAlignment(channel_count, tile)

      A = TensorDescription(element_a, layout[0], channel_count)
      B = TensorDescription(element_b, layout[1], channel_count)
      C = TensorDescription(element_c, layout[2], alignment_c)

      # Conv2d Fprop with strided support (the only kind generated here).
      if ConvKind.Fprop in conv_kinds:
        for iterator_algorithm in iterator_algorithms:
          new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,
                                          A, B, C, element_epilogue, StrideSupport.Strided,
                                          epilogue_functor, swizzling_functor)

          manifest.append(new_operation)
          operations.append(new_operation)

  # The original fell off the end and returned None, unlike every sibling
  # creator. Return the list so callers can use it uniformly.
  return operations
538
+
539
+ # Convolution for 3D operations
540
def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignment,
                         conv_kinds=[ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad],
                         epilogue_functor=EpilogueFunctor.LinearCombination):
  """Register 3D convolution kernels with ``manifest`` and return them.

  All Fprop and Wgrad kernels (for every tile) are emitted first, followed by
  the Dgrad kernels (for every tile). For Dgrad, each tile always yields two
  kernels: an optimized unity-stride one and an analytic strided one.

  Args:
    manifest: receives every operation through ``append``. When its
      ``kernel_filter`` is the empty string, only the first tile description
      is used.
    layout: single layout shared by the A, B and C tensors.
    tile_descriptions: candidate threadblock tile descriptions.
    data_type: ``(element_a, element_b, element_c, element_epilogue)``.
    alignment: alignment of A and B. C uses ``min(8, alignment)``.
    conv_kinds: which of ``ConvKind.Fprop`` / ``Dgrad`` / ``Wgrad`` to emit.
    epilogue_functor: epilogue used by every generated operation, including
      Fprop (the original silently dropped it for Fprop).

  Returns:
    The generated operations, in the order they were appended to ``manifest``.
  """
  element_a, element_b, element_c, element_epilogue = data_type

  alignment_c = min(8, alignment)

  # Analytic Fprop/Wgrad iterators are supported but not generated, to keep
  # the kernel count low.
  iterator_algorithms = [IteratorAlgorithm.Optimized]

  # By default, keep only the largest tile size.
  if manifest.kernel_filter == '':
    tile_descriptions = [tile_descriptions[0]]

  operations = []

  def _emit(op):
    # Keep the manifest and the returned list in lock-step.
    manifest.append(op)
    operations.append(op)

  # Fprop and Wgrad for every tile size.
  for tile in tile_descriptions:
    A = TensorDescription(element_a, layout, alignment)
    B = TensorDescription(element_b, layout, alignment)
    C = TensorDescription(element_c, layout, alignment_c)

    if ConvKind.Fprop in conv_kinds:
      for iterator_algorithm in iterator_algorithms:
        # epilogue_functor is passed in the same positional slot as for
        # Wgrad/Dgrad below, so a caller-supplied epilogue is honoured.
        _emit(Conv3dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,
                              A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor))

    if ConvKind.Wgrad in conv_kinds:
      for iterator_algorithm in iterator_algorithms:
        _emit(Conv3dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,
                              A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor))

  # Dgrad for every tile size, emitted after all Fprop/Wgrad kernels.
  for tile in tile_descriptions:
    A = TensorDescription(element_a, layout, alignment)
    B = TensorDescription(element_b, layout, alignment)
    C = TensorDescription(element_c, layout, alignment_c)

    if ConvKind.Dgrad in conv_kinds:
      # Unity stride with the optimized iterator.
      _emit(Conv3dOperation(ConvKind.Dgrad, IteratorAlgorithm.Optimized, tile.minimum_compute_capability, tile,
                            A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor))

      # Strided support with the analytic iterator. Conv3d Dgrad's strided
      # support is naive and does not cut down redundant MMAs.
      _emit(Conv3dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,
                            A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor))

  return operations
614
+
615
+ # Convolution for Depthwise 2d conv
616
def CreateDepthwiseConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints,
                                  conv_kinds=[ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad],
                                  epilogue_functor=EpilogueFunctor.LinearCombination,
                                  swizzling_functor=SwizzlingFunctor.Identity4):
  """Register depthwise (``GroupMode.Depthwise``) Conv2d Fprop kernels and return them.

  Tile selection per iterator:
    * FixedStrideDilation: emitted only when neither ``tile.stride`` nor
      ``tile.dilation`` equals the ``[-1, -1]`` wildcard
      (uses ``StrideSupport.Fixed``).
    * Optimized: emitted only when both ``tile.stride`` and ``tile.dilation``
      equal ``[-1, -1]`` (uses ``StrideSupport.Strided``).
  A tile with one wildcard and one concrete value therefore yields nothing.
  Only Fprop is generated. Dgrad/Wgrad entries in ``conv_kinds`` are ignored.

  Args:
    manifest: receives every operation through ``append``. When its
      ``kernel_filter`` is the empty string, only the first tile description
      and the first alignment constraint are used.
    layout: layouts of the A, B and C tensors (indexed 0, 1, 2).
    tile_descriptions: candidate tiles exposing ``stride`` and ``dilation``.
    data_type: ``(element_a, element_b, element_c, element_epilogue)``.
    alignment_constraints: candidate alignments for A and B. C uses
      ``min(8, alignment)``.
    conv_kinds: Fprop kernels are emitted only if ``ConvKind.Fprop`` is listed.
    epilogue_functor: epilogue used by every generated operation.
    swizzling_functor: threadblock swizzle used by every generated operation.

  Returns:
    The generated operations, in the order they were appended to ``manifest``.
  """
  element_a, element_b, element_c, element_epilogue = data_type

  iterator_algorithms = [IteratorAlgorithm.FixedStrideDilation, IteratorAlgorithm.Optimized]

  # By default, keep only the largest tile size and the largest alignment.
  if manifest.kernel_filter == '':
    tile_descriptions = [tile_descriptions[0]]
    alignment_constraints = [alignment_constraints[0]]

  wildcard = [-1, -1]

  def stride_support_for(algorithm, tile):
    # Return the StrideSupport for `algorithm` on `tile`, or None to skip it.
    if algorithm == IteratorAlgorithm.FixedStrideDilation:
      if tile.stride == wildcard or tile.dilation == wildcard:
        return None
      return StrideSupport.Fixed
    if algorithm == IteratorAlgorithm.Optimized:
      if tile.stride != wildcard or tile.dilation != wildcard:
        return None
    return StrideSupport.Strided

  result = []
  for tile in tile_descriptions:
    for alignment in alignment_constraints:
      A = TensorDescription(element_a, layout[0], alignment)
      B = TensorDescription(element_b, layout[1], alignment)
      C = TensorDescription(element_c, layout[2], min(8, alignment))

      if ConvKind.Fprop not in conv_kinds:
        continue

      for algorithm in iterator_algorithms:
        stride_support = stride_support_for(algorithm, tile)
        if stride_support is None:
          continue

        op = Conv2dOperation(ConvKind.Fprop,
                             algorithm,
                             tile.minimum_compute_capability,
                             tile,
                             A, B, C,
                             element_epilogue,
                             stride_support,
                             epilogue_functor,
                             swizzling_functor,
                             group_mode=GroupMode.Depthwise)
        manifest.append(op)
        result.append(op)

  return result
671
+
672
+ ###################################################################################################
673
+ ###################################################################################################
674
+
675
+ #
676
def GenerateSM50_Simt(manifest, cuda_version):
  """Emit SM50 SIMT GEMM kernels (f32 and f64) plus f32 NHWC Conv2d kernels.

  ``cuda_version`` is accepted for the common generator signature but is not
  used here.
  """
  # Every column/row-major pairing of A and B; C is always column major.
  major_orders = (LayoutType.ColumnMajor, LayoutType.RowMajor)
  layouts = [
    (layout_a, layout_b, LayoutType.ColumnMajor)
    for layout_a in major_orders
    for layout_b in major_orders
  ]

  math_instructions = [
    MathInstruction([1, 1, 1], element, element, element,
                    OpcodeClass.Simt, MathOperation.multiply_add)
    for element in (DataType.f32, DataType.f64)
  ]

  min_cc, max_cc = 50, 1024
  alignment_constraints = [1]

  # (threadblock shape, warp count); every tile uses 2 stages.
  tile_shapes = (
    ((128, 128, 8), (4, 2, 1)),
    ((128,  64, 8), (2, 2, 1)),
    (( 64, 128, 8), (2, 2, 1)),
    (( 64,  64, 8), (2, 1, 1)),
    ((128,  32, 8), (2, 1, 1)),
    (( 32, 128, 8), (1, 2, 1)),
  )

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription(list(shape), 2, list(warps), math_inst, min_cc, max_cc)
      for shape, warps in tile_shapes
    ]

    accumulator = math_inst.element_accumulator
    data_type = [math_inst.element_a, math_inst.element_b, accumulator, accumulator]

    CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints)

    # Convolutions are only generated for single precision.
    if math_inst.element_a == DataType.f32:
      conv_layout = (LayoutType.TensorNHWC,) * 3
      CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
725
+ #
726
+
727
+ #
728
def GenerateSM50_Simt_complex(manifest, cuda_version):
  """Emit SM50 SIMT complex-f32 (cf32) GEMM and NHWC Conv2d kernels.

  The math instruction is an f32 ``multiply_add_complex``; every operand and
  the epilogue use ``DataType.cf32``. ``cuda_version`` is unused here.
  """
  # Every column/row-major pairing of A and B; C is always column major.
  major_orders = (LayoutType.ColumnMajor, LayoutType.RowMajor)
  layouts = [
    (layout_a, layout_b, LayoutType.ColumnMajor)
    for layout_a in major_orders
    for layout_b in major_orders
  ]

  math_instructions = [
    MathInstruction([1, 1, 1], DataType.f32, DataType.f32, DataType.f32,
                    OpcodeClass.Simt, MathOperation.multiply_add_complex),
  ]

  min_cc, max_cc = 50, 1024
  alignment_constraints = [1]

  # (threadblock shape, warp count); every tile uses 2 stages.
  tile_shapes = (
    ((128,  64, 8), (2, 2, 1)),
    (( 64, 128, 8), (2, 2, 1)),
    (( 64,  64, 8), (2, 1, 1)),
    ((128,  32, 8), (2, 1, 1)),
    (( 32, 128, 8), (1, 2, 1)),
    ((128, 128, 8), (4, 2, 1)),
  )

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription(list(shape), 2, list(warps), math_inst, min_cc, max_cc)
      for shape, warps in tile_shapes
    ]

    data_type = [DataType.cf32] * 4

    CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints)

    conv_layout = (LayoutType.TensorNHWC,) * 3
    CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
772
+ #
773
+
774
+ #
775
def GenerateSM50(manifest, cuda_version):
  """Emit all SM50 kernels: real-valued SIMT first, then complex SIMT."""
  for generate in (GenerateSM50_Simt, GenerateSM50_Simt_complex):
    generate(manifest, cuda_version)
778
+
779
+ ###################################################################################################
780
+ ###################################################################################################
781
+
782
+ #
783
def GenerateSM60_Simt(manifest, cuda_version):
  """Emit SM60 half-precision (f16) SIMT GEMM kernels.

  ``cuda_version`` is accepted for the common generator signature but is not
  used here.
  """
  # Every column/row-major pairing of A and B; C is always column major.
  major_orders = (LayoutType.ColumnMajor, LayoutType.RowMajor)
  layouts = [
    (layout_a, layout_b, LayoutType.ColumnMajor)
    for layout_a in major_orders
    for layout_b in major_orders
  ]

  math_instructions = [
    MathInstruction([1, 1, 1], DataType.f16, DataType.f16, DataType.f16,
                    OpcodeClass.Simt, MathOperation.multiply_add),
  ]

  min_cc, max_cc = 60, 1024
  alignment_constraints = [1]

  # (threadblock shape, warp count); every tile uses 2 stages.
  tile_shapes = (
    ((256, 128, 8), (4, 2, 1)),
    ((128, 256, 8), (4, 2, 1)),
    ((128, 128, 8), (4, 2, 1)),
    ((128,  64, 8), (2, 2, 1)),
    (( 64, 128, 8), (2, 2, 1)),
    (( 64,  64, 8), (2, 1, 1)),
    ((128,  32, 8), (2, 1, 1)),
    (( 32, 128, 8), (1, 2, 1)),
  )

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription(list(shape), 2, list(warps), math_inst, min_cc, max_cc)
      for shape, warps in tile_shapes
    ]

    accumulator = math_inst.element_accumulator
    data_type = [math_inst.element_a, math_inst.element_b, accumulator, accumulator]

    CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints)
825
+ #
826
def GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version):
  """Emit SM60 f16 SIMT depthwise Conv2d kernels built from direct-conv tiles.

  Tiles are generated for 3x3 and 5x5 filters across every
  (stride, dilation) combination listed below. ``cuda_version`` is unused.
  """
  math_instructions = [
    MathInstruction([1, 1, 1], DataType.f16, DataType.f16, DataType.f16,
                    OpcodeClass.Simt, MathOperation.multiply_add),
  ]

  min_cc, max_cc = 60, 1024
  alignment_constraints = [8]

  # Filter sizes; each list object is shared by all tiles using that filter.
  filters = ([3, 3], [5, 5])

  # [stride_h, stride_w] and [dilation_h, dilation_w];
  # [-1, -1] means "all sizes".
  strides = [[-1, -1], [1, 1], [2, 2]]
  dilations = [[-1, -1], [1, 1], [2, 2]]

  # (threadblock output shape [n, p, q], groups per threadblock, stages)
  block_configs = (
    ([1, 8, 8], 32, 3),
    ([1, 8, 8], 64, 3),
    ([1, 8, 8], 16, 3),
    ([1, 10, 10], 64, 2),
    ([1, 4, 4], 32, 4),
    ([1, 4, 4], 64, 4),
    ([1, 4, 4], 16, 4),
  )

  tile_descriptions = []
  for math_inst in math_instructions:
    for stride, dilation in product(strides, dilations):
      for filter_shape in filters:
        for npq, groups, stages in block_configs:
          # Arguments: threadblock output, filter, stages, stride, dilation, warp count.
          tile_descriptions.append(
            Direct2dConvFixedStrideDilationTileDescription(
              npq + [groups], filter_shape, stages, stride, dilation, [4, 1, 1],
              math_inst, min_cc, max_cc))

  # The data type is taken from the last math instruction.
  last_inst = math_instructions[-1]
  data_type = [
    last_inst.element_a,
    last_inst.element_b,
    last_inst.element_accumulator,
    last_inst.element_accumulator,
  ]

  conv_layout = (LayoutType.TensorNHWC,) * 3
  CreateDepthwiseConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
897
+ #
898
+
899
+ #
900
def GenerateSM60(manifest, cuda_version):
  """Emit all SM60 kernels: f16 SIMT GEMMs, then f16 SIMT depthwise convolutions."""
  for generate in (GenerateSM60_Simt, GenerateSM60_Simt_DepthwiseConv2d):
    generate(manifest, cuda_version)
903
+
904
+ ###################################################################################################
905
+ ###################################################################################################
906
+
907
+ #
908
def GenerateSM61_Simt(manifest, cuda_version):
  """Emit SM61 int8 SIMT GEMM kernels (s8 inputs, s32 accumulation).

  Two variants are generated per tile: C stored as the accumulator type, and
  C stored as the input type. ``cuda_version`` is unused here.
  """
  # Every column/row-major pairing of A and B; C is always column major.
  major_orders = (LayoutType.ColumnMajor, LayoutType.RowMajor)
  layouts = [
    (layout_a, layout_b, LayoutType.ColumnMajor)
    for layout_a in major_orders
    for layout_b in major_orders
  ]

  math_instructions = [
    MathInstruction([1, 1, 4], DataType.s8, DataType.s8, DataType.s32,
                    OpcodeClass.Simt, MathOperation.multiply_add),
  ]

  min_cc, max_cc = 61, 1024
  alignment_constraints = [1]

  # (threadblock shape, warp count); every tile uses 2 stages.
  tile_shapes = (
    ((128, 128, 32), (4, 2, 1)),
    ((128,  64, 32), (2, 2, 1)),
    (( 64, 128, 32), (2, 2, 1)),
    (( 64,  64, 32), (2, 1, 1)),
    ((128,  32, 32), (2, 1, 1)),
    (( 32, 128, 32), (1, 2, 1)),
  )

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription(list(shape), 2, list(warps), math_inst, min_cc, max_cc)
      for shape, warps in tile_shapes
    ]

    elem_a, elem_b, accumulator = math_inst.element_a, math_inst.element_b, math_inst.element_accumulator
    data_type = [elem_a, elem_b, accumulator, accumulator]
    data_type_mixed = [elem_a, elem_b, elem_a, accumulator]

    for types in (data_type, data_type_mixed):
      CreateGemmOperator(manifest, layouts, tile_descriptions, types, alignment_constraints)
957
+ #
958
+
959
+ #
960
def GenerateSM61(manifest, cuda_version):
  """Emit all SM61 kernels; currently only the int8 SIMT GEMM family."""
  generators = (GenerateSM61_Simt,)
  for generate in generators:
    generate(manifest, cuda_version)
962
+
963
+ ###################################################################################################
964
+ ###################################################################################################
965
+
966
+ #
967
def GenerateSM70_TensorOp_884(manifest, cuda_version):
  """Emit SM70-SM75 f16 TensorOp (8x8x4) GEMM and NHWC Conv2d kernels.

  Nothing is emitted unless ``CudaToolkitVersionSatisfies(cuda_version, 10, 1)``.
  For each math instruction, kernels with C in the accumulator type are
  generated. When the input and accumulator types differ, a second set with
  C in the input type is generated as well.
  """
  if not CudaToolkitVersionSatisfies(cuda_version, 10, 1):
    return

  # Every column/row-major pairing of A and B; C is always column major.
  major_orders = (LayoutType.ColumnMajor, LayoutType.RowMajor)
  layouts = [
    (layout_a, layout_b, LayoutType.ColumnMajor)
    for layout_a in major_orders
    for layout_b in major_orders
  ]

  # f16 inputs with f32 accumulation, then f16 accumulation.
  math_instructions = [
    MathInstruction([8, 8, 4], DataType.f16, DataType.f16, accumulator,
                    OpcodeClass.TensorOp, MathOperation.multiply_add)
    for accumulator in (DataType.f32, DataType.f16)
  ]

  min_cc, max_cc = 70, 75
  alignment_constraints = [8, 4, 2, 1]

  # (threadblock shape, warp count); every tile uses 2 stages.
  tile_shapes = (
    ((256, 128, 32), (4, 2, 1)),
    ((128, 256, 32), (2, 4, 1)),
    ((128, 128, 32), (2, 2, 1)),
    ((256,  64, 32), (4, 1, 1)),
    (( 64, 256, 32), (1, 4, 1)),
    (( 64, 128, 32), (2, 2, 1)),
    ((128,  64, 32), (2, 2, 1)),
    (( 64,  64, 32), (2, 2, 1)),
  )

  conv_layout = (LayoutType.TensorNHWC,) * 3

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription(list(shape), 2, list(warps), math_inst, min_cc, max_cc)
      for shape, warps in tile_shapes
    ]

    elem_a, elem_b, accumulator = math_inst.element_a, math_inst.element_b, math_inst.element_accumulator
    type_sets = [[elem_a, elem_b, accumulator, accumulator]]
    # Skip the mixed variant when it would duplicate the first one
    # (e.g. f16 accumulation).
    if elem_a != accumulator:
      type_sets.append([elem_a, elem_b, elem_a, accumulator])

    for types in type_sets:
      CreateGemmOperator(manifest, layouts, tile_descriptions, types, alignment_constraints)
      CreateConv2dOperator(manifest, conv_layout, tile_descriptions, types, alignment_constraints)
1036
+
1037
+ #
1038
def GenerateSM70_PlanarComplexTensorOp_884(manifest, cuda_version):
  """Emit SM70-SM75 planar-complex f16 TensorOp (8x8x4) GEMM kernels.

  Nothing is emitted unless ``CudaToolkitVersionSatisfies(cuda_version, 10, 1)``.
  Every pairing of plain/conjugated A and B operands is requested. When the
  input and accumulator types differ, a second set with C in the input type
  is generated as well.
  """
  if not CudaToolkitVersionSatisfies(cuda_version, 10, 1):
    return

  # Every column/row-major pairing of A and B; C is always column major.
  major_orders = (LayoutType.ColumnMajor, LayoutType.RowMajor)
  layouts = [
    (layout_a, layout_b, LayoutType.ColumnMajor)
    for layout_a in major_orders
    for layout_b in major_orders
  ]

  # (A transform, B transform): (none, none), (conj, none), (none, conj), (conj, conj).
  transforms = (ComplexTransform.none, ComplexTransform.conj)
  complex_transforms = [
    (transform_a, transform_b)
    for transform_b in transforms
    for transform_a in transforms
  ]

  # f16 inputs with f32 accumulation, then f16 accumulation.
  math_instructions = [
    MathInstruction([8, 8, 4], DataType.f16, DataType.f16, accumulator,
                    OpcodeClass.TensorOp, MathOperation.multiply_add)
    for accumulator in (DataType.f32, DataType.f16)
  ]

  min_cc, max_cc = 70, 75
  alignment_constraints = [8, 2, 1]

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
    ]

    elem_a, elem_b, accumulator = math_inst.element_a, math_inst.element_b, math_inst.element_accumulator
    type_sets = [[elem_a, elem_b, accumulator, accumulator]]
    # Skip the mixed variant when it would duplicate the first one
    # (e.g. f16 accumulation).
    if elem_a != accumulator:
      type_sets.append([elem_a, elem_b, elem_a, accumulator])

    for types in type_sets:
      CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions,
                                      types, alignment_constraints, complex_transforms)
1102
+
1103
+
1104
+ #
1105
def GenerateSM70_WmmaTensorOp_161616(manifest, cuda_version):
  """Emit SM70+ f16 WMMA TensorOp (16x16x16) GEMM kernels.

  When the input and accumulator types differ, a second set with C in the
  input type is generated as well. ``cuda_version`` is unused here.
  """
  # Every column/row-major pairing of A and B; C is always column major.
  major_orders = (LayoutType.ColumnMajor, LayoutType.RowMajor)
  layouts = [
    (layout_a, layout_b, LayoutType.ColumnMajor)
    for layout_a in major_orders
    for layout_b in major_orders
  ]

  # f16 inputs with f32 accumulation, then f16 accumulation.
  math_instructions = [
    MathInstruction([16, 16, 16], DataType.f16, DataType.f16, accumulator,
                    OpcodeClass.WmmaTensorOp, MathOperation.multiply_add)
    for accumulator in (DataType.f32, DataType.f16)
  ]

  min_cc, max_cc = 70, 1024
  alignment_constraints = [8]

  # (threadblock shape, warp count); every tile uses 2 stages.
  tile_shapes = (
    ((128, 128, 32), (4, 2, 1)),
    (( 64, 128, 32), (2, 2, 1)),
    ((128,  64, 32), (2, 2, 1)),
    (( 64,  64, 32), (2, 2, 1)),
  )

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription(list(shape), 2, list(warps), math_inst, min_cc, max_cc)
      for shape, warps in tile_shapes
    ]

    elem_a, elem_b, accumulator = math_inst.element_a, math_inst.element_b, math_inst.element_accumulator
    type_sets = [[elem_a, elem_b, accumulator, accumulator]]
    # Skip the mixed variant when it would duplicate the first one
    # (e.g. f16 accumulation).
    if elem_a != accumulator:
      type_sets.append([elem_a, elem_b, elem_a, accumulator])

    for types in type_sets:
      CreateGemmOperator(manifest, layouts, tile_descriptions, types, alignment_constraints)
1162
+
1163
+ #
1164
+ ##################################################################################################
1165
+ #
1166
+
1167
def GenerateSM70(manifest, cuda_version):
    """Run each enabled SM70 generator, in order, against *manifest*."""
    for generate in (GenerateSM70_TensorOp_884,
                     GenerateSM70_PlanarComplexTensorOp_884):
        generate(manifest, cuda_version)

    # GenerateSM70_WmmaTensorOp_161616 is deliberately not invoked: WMMA
    # GEMMs are disabled for now to limit build size.
1174
+
1175
+ ###################################################################################################
1176
+ ###################################################################################################
1177
+
1178
+ #
1179
def GenerateSM75_TensorOp_1688_FewChannels(manifest, cuda_version, math_inst):
    """Emit SM75 fixed-channel and few-channel NHWC Conv2d kernels for *math_inst*.

    CreateConv2dFixedChannelsOperator is called with [4, 8] and
    CreateConv2dFewChannelsOperator with [1, 2, 4] as their final argument.
    Both are called first with C in the accumulator type.  When the
    accumulator differs from the A element type, both are called again with
    C in the A element type.  *cuda_version* is not consulted here.
    """
    min_cc, max_cc = 75, 1024

    tiles = [
        TileDescription(shape, 2, warps, math_inst, min_cc, max_cc)
        for shape, warps in (
            ([128, 64, 32], [2, 4, 1]),
            ([256, 64, 32], [4, 2, 1]),
            ([128, 128, 32], [4, 2, 1]),
            ([64, 256, 32], [2, 4, 1]),
            ([64, 128, 32], [2, 4, 1]),
            ([64, 64, 32], [4, 2, 1]),
            ([64, 128, 64], [2, 2, 2]),
        )
    ]

    a_type = math_inst.element_a
    b_type = math_inst.element_b
    acc_type = math_inst.element_accumulator

    element_sets = [[a_type, b_type, acc_type, acc_type]]
    # Skip the C-in-input-type variant when it would be identical (e.g. f16 accumulation).
    if a_type != acc_type:
        element_sets.append([a_type, b_type, a_type, acc_type])

    nhwc = (LayoutType.TensorNHWC,) * 3
    for data_type in element_sets:
        CreateConv2dFixedChannelsOperator(manifest, nhwc, tiles, data_type, [4, 8])
        CreateConv2dFewChannelsOperator(manifest, nhwc, tiles, data_type, [1, 2, 4])
1218
+
1219
+ #
1220
def GenerateSM75_TensorOp_1688(manifest, cuda_version):
    """Emit SM75 16x8x8 f16 tensor-op GEMM and NHWC Conv2d kernels.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 10, 2).
    For each math instruction (f32 accumulation, then f16 accumulation) it
    calls CreateGemmOperator and CreateConv2dOperator, repeats both with C
    in the A element type when that differs from the accumulator, and then
    runs GenerateSM75_TensorOp_1688_FewChannels for the instruction.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 10, 2):
        return

    major_orders = (LayoutType.ColumnMajor, LayoutType.RowMajor)
    gemm_layouts = [
        (layout_a, layout_b, LayoutType.ColumnMajor)
        for layout_a in major_orders
        for layout_b in major_orders
    ]

    instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.f16, DataType.f16, accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add)
        for accumulator in (DataType.f32, DataType.f16)
    ]

    min_cc, max_cc = 75, 1024
    alignments = [8, 4, 2, 1]
    nhwc = (LayoutType.TensorNHWC,) * 3

    for inst in instructions:
        # (threadblock shape, warp count) pairs; every tile uses 2 stages.
        tiles = [
            TileDescription(shape, 2, warps, inst, min_cc, max_cc)
            for shape, warps in (
                ([256, 128, 32], [4, 2, 1]),
                ([128, 256, 32], [2, 4, 1]),
                ([128, 128, 32], [2, 2, 1]),
                ([64, 256, 32], [1, 4, 1]),
                ([256, 64, 32], [4, 1, 1]),
                ([64, 128, 32], [2, 2, 1]),
                ([128, 64, 32], [2, 2, 1]),
                ([64, 64, 32], [2, 2, 1]),
                ([64, 128, 64], [1, 2, 2]),
            )
        ]

        a_type = inst.element_a
        b_type = inst.element_b
        acc_type = inst.element_accumulator

        element_sets = [[a_type, b_type, acc_type, acc_type]]
        # Skip the C-in-input-type variant when it would be identical (e.g. f16 accumulation).
        if a_type != acc_type:
            element_sets.append([a_type, b_type, a_type, acc_type])

        for data_type in element_sets:
            CreateGemmOperator(manifest, gemm_layouts, tiles, data_type, alignments)
            CreateConv2dOperator(manifest, nhwc, tiles, data_type, alignments)

        # The 'few channels' specializations come from a separate generator.
        GenerateSM75_TensorOp_1688_FewChannels(manifest, cuda_version, inst)
1294
+
1295
+ #
1296
+
1297
+ #
1298
def GenerateSM75_PlanarComplexTensorOp_1688(manifest, cuda_version):
    """Emit SM75 16x8x8 f16 planar-complex GEMM kernels.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 10, 2).
    CreateGemmPlanarComplexOperator is called over the four (A, B)
    row/column-major layouts with column-major C and all four
    none/conj transform pairs.  A second call is made with C in the A
    element type when that differs from the accumulator.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 10, 2):
        return

    major_orders = (LayoutType.ColumnMajor, LayoutType.RowMajor)
    gemm_layouts = [
        (layout_a, layout_b, LayoutType.ColumnMajor)
        for layout_a in major_orders
        for layout_b in major_orders
    ]

    # Yields (none, none), (conj, none), (none, conj), (conj, conj).
    transform_kinds = (ComplexTransform.none, ComplexTransform.conj)
    transforms = [
        (transform_a, transform_b)
        for transform_b in transform_kinds
        for transform_a in transform_kinds
    ]

    instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.f16, DataType.f16, accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add)
        for accumulator in (DataType.f32, DataType.f16)
    ]

    min_cc, max_cc = 75, 1024
    alignments = [8, 2, 1]

    for inst in instructions:
        tiles = [
            TileDescription(shape, 2, warps, inst, min_cc, max_cc)
            for shape, warps in (
                ([64, 128, 32], [2, 4, 1]),
                ([128, 64, 32], [4, 2, 1]),
                ([64, 64, 32], [2, 2, 1]),
            )
        ]

        a_type = inst.element_a
        b_type = inst.element_b
        acc_type = inst.element_accumulator

        element_sets = [[a_type, b_type, acc_type, acc_type]]
        # Skip the C-in-input-type variant when it would be identical (e.g. f16 accumulation).
        if a_type != acc_type:
            element_sets.append([a_type, b_type, a_type, acc_type])

        for data_type in element_sets:
            CreateGemmPlanarComplexOperator(manifest, gemm_layouts, tiles,
                                            data_type, alignments, transforms)
1364
+
1365
+ #
1366
def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version):
    """Emit SM75 8x8x16 s8/u8 TN GEMM and NHWC Fprop Conv2d kernels.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 10, 2).
    For each instruction, it first emits kernels with data types
    [A, B, accumulator, s32] and EpilogueFunctor.LinearCombination.
    When the accumulator differs from A, it then emits [A, B, A, f32]
    kernels with EpilogueFunctor.LinearCombinationClamp.  Each of those
    gets C alignment 16 when threadblock_shape[1] >= 128 and 8 otherwise.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 10, 2):
        return

    gemm_layouts = [
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    ]

    instructions = [
        MathInstruction(
            [8, 8, 16],
            int8_type, int8_type, DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate)
        for int8_type in (DataType.s8, DataType.u8)
    ]

    min_cc, max_cc = 75, 1024
    alignments = [16]
    nhwc = (LayoutType.TensorNHWC,) * 3

    for inst in instructions:
        tiles = [
            TileDescription(shape, 2, warps, inst, min_cc, max_cc)
            for shape, warps in (
                ([256, 128, 64], [4, 2, 1]),
                ([128, 256, 64], [2, 4, 1]),
                ([128, 128, 64], [2, 2, 1]),
                ([64, 256, 64], [1, 4, 1]),
                ([256, 64, 64], [4, 1, 1]),
                ([64, 128, 64], [2, 2, 1]),
                ([128, 64, 64], [2, 2, 1]),
                ([64, 64, 64], [2, 2, 1]),
            )
        ]

        a_type = inst.element_a
        b_type = inst.element_b
        acc_type = inst.element_accumulator

        plain = [a_type, b_type, acc_type, DataType.s32]
        CreateGemmOperator(manifest, gemm_layouts, tiles, plain, alignments,
                           None, EpilogueFunctor.LinearCombination)
        CreateConv2dOperator(manifest, nhwc, tiles, plain, alignments,
                             [ConvKind.Fprop], EpilogueFunctor.LinearCombination)

        # No mixed variant when it would duplicate the kernels above.
        if a_type == acc_type:
            continue

        mixed = [a_type, b_type, a_type, DataType.f32]
        clamped_ops = [
            *CreateGemmOperator(manifest, gemm_layouts, tiles, mixed, alignments,
                                None, EpilogueFunctor.LinearCombinationClamp),
            *CreateConv2dOperator(manifest, nhwc, tiles, mixed, alignments,
                                  [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp),
        ]
        for op in clamped_ops:
            op.C.alignment = 16 if op.tile_description.threadblock_shape[1] >= 128 else 8
1442
+
1443
+ #
1444
+
1445
+ #
1446
def GenerateSM75_TensorOp_8816_Interleaved(manifest, cuda_version):
    """Emit SM75 8x8x16 s8/u8 interleaved-32 GEMM and Fprop Conv2d kernels.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 10, 2).
    Only the [A, B, A, f32] data-type set is generated, with
    EpilogueFunctor.LinearCombinationClamp.  Every resulting operation has
    its C alignment set to 8.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 10, 2):
        return

    gemm_layouts = [
        (LayoutType.ColumnMajorInterleaved32,
         LayoutType.RowMajorInterleaved32,
         LayoutType.ColumnMajorInterleaved32),
    ]
    interleaved_conv = (LayoutType.TensorNC32HW32,
                        LayoutType.TensorC32RSK32,
                        LayoutType.TensorNC32HW32)

    instructions = [
        MathInstruction(
            [8, 8, 16],
            int8_type, int8_type, DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate)
        for int8_type in (DataType.s8, DataType.u8)
    ]

    min_cc, max_cc = 75, 1024
    alignments = [16]

    for inst in instructions:
        tiles = [
            TileDescription(shape, 2, warps, inst, min_cc, max_cc)
            for shape, warps in (
                ([256, 128, 64], [4, 2, 1]),
                ([128, 256, 64], [2, 4, 1]),
                ([128, 128, 64], [2, 2, 1]),
                ([256, 64, 64], [4, 1, 1]),
                ([64, 256, 64], [1, 4, 1]),
                ([64, 128, 64], [2, 2, 1]),
                ([128, 64, 64], [2, 2, 1]),
                ([64, 64, 64], [2, 2, 1]),
            )
        ]

        mixed = [inst.element_a, inst.element_b, inst.element_a, DataType.f32]

        # Extend the list returned by CreateGemmOperator in place.
        clamped_ops = CreateGemmOperator(manifest, gemm_layouts, tiles, mixed, alignments,
                                         None, EpilogueFunctor.LinearCombinationClamp)
        clamped_ops += CreateConv2dOperator(manifest, interleaved_conv, tiles, mixed, alignments,
                                            [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)

        for op in clamped_ops:
            op.C.alignment = 8
1502
+ #
1503
+
1504
+ #
1505
def GenerateSM75_TensorOp_8832_TN(manifest, cuda_version):
    """Emit SM75 8x8x32 s4/u4 TN GEMM and NHWC Fprop Conv2d kernels.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 10, 2).
    For each instruction, it first emits kernels with data types
    [A, B, accumulator, s32] and EpilogueFunctor.LinearCombination.
    When the accumulator differs from A, it then emits [A, B, A, f32]
    kernels with EpilogueFunctor.LinearCombinationClamp.  Each of those
    gets C alignment 16 when threadblock_shape[1] >= 128 and 8 otherwise.
    This is the same rule GenerateSM75_TensorOp_8816_TN applies.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 10, 2):
        return

    layouts = [
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    ]

    math_instructions = [
        MathInstruction(
            [8, 8, 32],
            DataType.s4, DataType.s4, DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate),
        MathInstruction(
            [8, 8, 32],
            DataType.u4, DataType.u4, DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate),
    ]

    min_cc = 75
    max_cc = 1024
    alignment_constraints = [32]

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([256, 64, 128], 2, [4, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 256, 128], 2, [1, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
        ]

        data_type = [
            math_inst.element_a,
            math_inst.element_b,
            math_inst.element_accumulator,
            DataType.s32,
        ]

        CreateGemmOperator(manifest, layouts, tile_descriptions,
                           data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination)

        conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
        CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                             data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination)

        # Avoid emitting two kernels if the accumulator type does not differ from the input type.
        if math_inst.element_a != math_inst.element_accumulator:

            data_type_mixed = [
                math_inst.element_a,
                math_inst.element_b,
                math_inst.element_a,
                DataType.f32,
            ]

            operations = []

            operations += CreateGemmOperator(manifest, layouts, tile_descriptions,
                                             data_type_mixed, alignment_constraints, None,
                                             EpilogueFunctor.LinearCombinationClamp)

            operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                                               data_type_mixed, alignment_constraints, [ConvKind.Fprop],
                                               EpilogueFunctor.LinearCombinationClamp)

            for op in operations:
                # Only threadblock_shape[1] >= 128 gets 16; every other shape
                # (64 included) gets 8, so no separate == 64 case is needed.
                op.C.alignment = 16 if op.tile_description.threadblock_shape[1] >= 128 else 8
1582
+
1583
+ #
1584
+
1585
+ #
1586
def GenerateSM75_TensorOp_8832_Interleaved(manifest, cuda_version):
    """Emit SM75 8x8x32 s4/u4 interleaved-64 GEMM and Fprop Conv2d kernels.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 10, 2).
    Kernels are generated only when the accumulator type differs from the A
    element type.  They use the [A, B, A, f32] data-type set with
    EpilogueFunctor.LinearCombinationClamp, and every resulting operation
    has its C alignment set to 16.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 10, 2):
        return

    gemm_layouts = [
        (LayoutType.ColumnMajorInterleaved64,
         LayoutType.RowMajorInterleaved64,
         LayoutType.ColumnMajorInterleaved64),
    ]
    interleaved_conv = (LayoutType.TensorNC64HW64,
                        LayoutType.TensorC64RSK64,
                        LayoutType.TensorNC64HW64)

    instructions = [
        MathInstruction(
            [8, 8, 32],
            int4_type, int4_type, DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate)
        for int4_type in (DataType.s4, DataType.u4)
    ]

    min_cc, max_cc = 75, 1024
    alignments = [32]

    for inst in instructions:
        tiles = [
            TileDescription(shape, 2, warps, inst, min_cc, max_cc)
            for shape, warps in (
                ([256, 128, 128], [4, 2, 1]),
                ([128, 256, 128], [2, 4, 1]),
                ([128, 128, 128], [2, 2, 1]),
                ([256, 64, 128], [4, 1, 1]),
                ([64, 256, 128], [1, 4, 1]),
                ([64, 128, 128], [2, 2, 1]),
            )
        ]

        # Nothing to emit when the accumulator equals the input type.
        if inst.element_a == inst.element_accumulator:
            continue

        mixed = [inst.element_a, inst.element_b, inst.element_a, DataType.f32]

        # Extend the list returned by CreateGemmOperator in place.
        clamped_ops = CreateGemmOperator(manifest, gemm_layouts, tiles, mixed, alignments,
                                         None, EpilogueFunctor.LinearCombinationClamp)
        clamped_ops += CreateConv2dOperator(manifest, interleaved_conv, tiles, mixed, alignments,
                                            [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)

        for op in clamped_ops:
            op.C.alignment = 16
1642
+ #
1643
+
1644
+ #
1645
def GenerateSM75_TensorOp_88128(manifest, cuda_version):
    """Emit SM75 8x8x128 binary (b1, xor_popc) TN GEMM kernels.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 11, 0).
    A single instruction is used, with data types [b1, b1, s32, s32].
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    gemm_layouts = [
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    ]

    xor_popc_inst = MathInstruction(
        [8, 8, 128],
        DataType.b1, DataType.b1, DataType.s32,
        OpcodeClass.TensorOp,
        MathOperation.xor_popc)

    min_cc, max_cc = 75, 1024
    alignments = [128]

    tiles = [
        TileDescription(shape, 2, warps, xor_popc_inst, min_cc, max_cc)
        for shape, warps in (
            ([256, 128, 512], [4, 2, 1]),
            ([128, 256, 512], [2, 4, 1]),
            ([128, 128, 512], [2, 2, 1]),
            ([64, 256, 512], [1, 4, 1]),
            ([256, 64, 512], [4, 1, 1]),
            ([64, 128, 512], [2, 2, 1]),
            ([128, 64, 512], [2, 2, 1]),
            ([64, 64, 512], [2, 2, 1]),
        )
    ]

    binary_types = [DataType.b1, DataType.b1, DataType.s32, DataType.s32]
    CreateGemmOperator(manifest, gemm_layouts, tiles, binary_types, alignments)
1682
+
1683
+ #
1684
+
1685
+ #
1686
def GenerateSM75_WmmaTensorOp_161616(manifest, cuda_version):
    """Emit SM75 WMMA 16x16x16 s8 GEMM kernels.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 10, 0).
    CreateGemmOperator is called over the four (A, B) row/column-major
    layouts with column-major C, using [A, B, accumulator, f32].  A second
    call uses [A, B, A, f32] when the accumulator differs from A.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 10, 0):
        return

    major_orders = (LayoutType.ColumnMajor, LayoutType.RowMajor)
    gemm_layouts = [
        (layout_a, layout_b, LayoutType.ColumnMajor)
        for layout_a in major_orders
        for layout_b in major_orders
    ]

    instructions = [
        MathInstruction(
            [16, 16, 16],
            DataType.s8, DataType.s8, DataType.s32,
            OpcodeClass.WmmaTensorOp,
            MathOperation.multiply_add),
    ]

    min_cc, max_cc = 75, 1024
    alignments = [16]

    for inst in instructions:
        tiles = [
            TileDescription(shape, 2, warps, inst, min_cc, max_cc)
            for shape, warps in (
                ([128, 128, 32], [4, 2, 1]),
                ([64, 128, 32], [2, 2, 1]),
                ([128, 64, 32], [2, 2, 1]),
                ([64, 64, 32], [2, 2, 1]),
            )
        ]

        a_type = inst.element_a
        b_type = inst.element_b
        acc_type = inst.element_accumulator

        element_sets = [[a_type, b_type, acc_type, DataType.f32]]
        # Skip the C-in-input-type variant when it would be identical.
        if a_type != acc_type:
            element_sets.append([a_type, b_type, a_type, DataType.f32])

        for data_type in element_sets:
            CreateGemmOperator(manifest, gemm_layouts, tiles, data_type, alignments)
1741
+ #
1742
+
1743
+ #
1744
def GenerateSM75_Simt_complex(manifest, cuda_version):
    """Emit SM75 SIMT complex-f32 (cf32) NHWC Conv2d kernels.

    A single SIMT multiply_add_complex instruction is used with one
    128x128x8, 5-stage tile and cf32 for all four data types.
    *cuda_version* is not consulted here.

    NOTE(review): earlier versions built a list of complex transforms here
    that was never used; only Conv2d kernels are created.  If complex GEMMs
    were also intended, that call is missing.  Confirm with the generator
    owner before adding it, because it would change the emitted kernel set.
    """
    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32, DataType.f32, DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add_complex),
    ]

    min_cc = 75
    max_cc = 1024

    alignment_constraints = [1]

    conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc),
        ]
        data_type = [
            DataType.cf32,
            DataType.cf32,
            DataType.cf32,
            DataType.cf32,
        ]

        CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
1778
+ #
1779
+
1780
def GenerateSM75(manifest, cuda_version):
    """Run each enabled SM75 generator, in order, against *manifest*."""
    sm75_generators = (
        GenerateSM75_TensorOp_1688,
        GenerateSM75_PlanarComplexTensorOp_1688,
        GenerateSM75_TensorOp_8816_TN,
        GenerateSM75_TensorOp_8816_Interleaved,
        GenerateSM75_TensorOp_8832_TN,
        GenerateSM75_TensorOp_8832_Interleaved,
        GenerateSM75_TensorOp_88128,
        # GenerateSM75_WmmaTensorOp_161616 is intentionally not run.
        GenerateSM75_Simt_complex,
    )
    for generate in sm75_generators:
        generate(manifest, cuda_version)
1790
+
1791
+
1792
+ ###################################################################################################
1793
+ ###################################################################################################
1794
+
1795
+ #
1796
def GenerateSM80_TensorOp_16816(manifest, cuda_version):
    """Emit SM80 16x8x16 f16/bf16 tensor-op GEMM, grouped GEMM and conv kernels.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 11, 0).
    Each instruction first gets, with C in the accumulator type: GEMM,
    grouped GEMM, Conv2d, fixed-channel Conv2d and NDHWC Conv3d kernels.
    When the accumulator differs from A, it also gets GEMM, Conv2d,
    fixed-channel Conv2d and Conv3d kernels (no grouped GEMM) with C in the
    A element type.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    major_orders = (LayoutType.ColumnMajor, LayoutType.RowMajor)
    gemm_layouts = [
        (layout_a, layout_b, LayoutType.ColumnMajor)
        for layout_a in major_orders
        for layout_b in major_orders
    ]

    instructions = [
        MathInstruction(
            [16, 8, 16],
            element, element, accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add)
        for element, accumulator in (
            (DataType.f16, DataType.f32),
            (DataType.f16, DataType.f16),
            (DataType.bf16, DataType.f32),
        )
    ]

    min_cc, max_cc = 80, 1024
    alignments = [8, 4, 2]
    nhwc = (LayoutType.TensorNHWC,) * 3

    for inst in instructions:
        # (threadblock shape, stages, warp count)
        tiles = [
            TileDescription(shape, stages, warps, inst, min_cc, max_cc)
            for shape, stages, warps in (
                ([256, 128, 32], 3, [4, 2, 1]),
                ([128, 256, 32], 3, [2, 4, 1]),
                ([256, 64, 32], 3, [4, 1, 1]),
                ([256, 64, 32], 4, [4, 1, 1]),
                ([64, 256, 32], 4, [1, 4, 1]),
                ([128, 128, 32], 3, [2, 2, 1]),
                ([128, 128, 32], 4, [2, 2, 1]),
                ([128, 128, 32], 5, [2, 2, 1]),
                ([128, 64, 32], 6, [2, 2, 1]),
                ([64, 128, 32], 6, [2, 2, 1]),
                ([64, 64, 32], 10, [2, 2, 1]),
                ([256, 128, 64], 3, [4, 2, 1]),
                ([128, 256, 64], 3, [2, 4, 1]),
                ([256, 64, 64], 4, [4, 1, 1]),
                ([64, 256, 64], 4, [1, 4, 1]),
                ([128, 128, 64], 4, [2, 2, 1]),
                ([256, 64, 64], 3, [4, 1, 1]),
                ([64, 256, 64], 3, [1, 4, 1]),
                ([128, 128, 64], 3, [2, 2, 1]),
                ([128, 64, 64], 3, [2, 2, 1]),
                ([64, 128, 64], 3, [2, 2, 1]),
                ([64, 64, 64], 5, [2, 2, 1]),
            )
        ]

        a_type = inst.element_a
        b_type = inst.element_b
        acc_type = inst.element_accumulator

        accumulated = [a_type, b_type, acc_type, acc_type]
        CreateGemmOperator(manifest, gemm_layouts, tiles, accumulated, alignments)
        CreateGemmGroupedOperator(manifest, gemm_layouts, tiles, accumulated, alignments)
        CreateConv2dOperator(manifest, nhwc, tiles, accumulated, alignments)
        CreateConv2dFixedChannelsOperator(manifest, nhwc, tiles, accumulated, [4, 8])
        CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tiles, accumulated, 8)

        # The mixed variant would duplicate the above when accumulator == input type.
        if a_type == acc_type:
            continue

        mixed = [a_type, b_type, a_type, acc_type]
        CreateGemmOperator(manifest, gemm_layouts, tiles, mixed, alignments)
        CreateConv2dOperator(manifest, nhwc, tiles, mixed, alignments)
        CreateConv2dFixedChannelsOperator(manifest, nhwc, tiles, mixed, [4, 8])
        CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tiles, mixed, 8)
1890
+ #
1891
+
1892
+ #
1893
def GenerateSM80_SparseTensorOp_16832(manifest, cuda_version):
    """Emit SM80 16x8x32 f16/bf16 sparse tensor-op GEMM kernels.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 11, 1).
    CreateSparseGemmOperator is called over the four (A, B) row/column-major
    layouts with a row-major C.  A second call uses C in the A element type
    when that differs from the accumulator.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 1):
        return

    major_orders = (LayoutType.ColumnMajor, LayoutType.RowMajor)
    gemm_layouts = [
        (layout_a, layout_b, LayoutType.RowMajor)
        for layout_a in major_orders
        for layout_b in major_orders
    ]

    instructions = [
        MathInstruction(
            [16, 8, 32],
            element, element, accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add)
        for element, accumulator in (
            (DataType.f16, DataType.f32),
            (DataType.f16, DataType.f16),
            (DataType.bf16, DataType.f32),
        )
    ]

    min_cc, max_cc = 80, 1024
    alignments = [8]

    for inst in instructions:
        # (threadblock shape, stages, warp count)
        tiles = [
            TileDescription(shape, stages, warps, inst, min_cc, max_cc)
            for shape, stages, warps in (
                ([64, 128, 64], 6, [2, 2, 1]),
                ([256, 128, 64], 3, [4, 2, 1]),
                ([128, 256, 64], 3, [2, 4, 1]),
                ([128, 128, 64], 3, [2, 2, 1]),
                ([256, 64, 64], 3, [4, 1, 1]),
                ([64, 256, 64], 4, [1, 4, 1]),
                ([128, 64, 64], 3, [2, 2, 1]),
                ([64, 64, 64], 4, [2, 2, 1]),
                ([128, 128, 128], 3, [2, 4, 1]),
                ([256, 64, 128], 3, [4, 1, 1]),
                ([128, 64, 128], 4, [2, 2, 1]),
                ([64, 128, 128], 3, [2, 2, 1]),
                ([64, 64, 128], 3, [2, 2, 1]),
            )
        ]

        a_type = inst.element_a
        b_type = inst.element_b
        acc_type = inst.element_accumulator

        element_sets = [[a_type, b_type, acc_type, acc_type]]
        # Skip the C-in-input-type variant when it would be identical (e.g. f16 accumulation).
        if a_type != acc_type:
            element_sets.append([a_type, b_type, a_type, acc_type])

        for data_type in element_sets:
            CreateSparseGemmOperator(manifest, gemm_layouts, tiles, data_type, alignments)
1967
+
1968
+ #
1969
+
1970
+ #
1971
def GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version):
    """Request SM80 planar-complex GEMM kernels built on the 16x8x16 tensor-op MMA.

    Three math instructions are covered: f16 inputs with f32 accumulation,
    bf16 inputs with f32 accumulation, and f16 inputs with f16 accumulation.
    For each one, a planar-complex GEMM family is requested for:

    * all four column/row-major combinations of A and B (C is column-major);
    * all four conjugation combinations of A and B.

    A second family is requested whose C element type equals the input type.
    This only happens when the accumulator type differs from the input type,
    so that identical kernels are not emitted twice.

    Args:
        manifest: Forwarded to ``CreateGemmPlanarComplexOperator``, which
            presumably records the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col_major = LayoutType.ColumnMajor
    row_major = LayoutType.RowMajor
    # Order: (col, col), (col, row), (row, col), (row, row); C always column-major.
    layouts = [
        (a_layout, b_layout, col_major)
        for a_layout in (col_major, row_major)
        for b_layout in (col_major, row_major)
    ]

    no_xform = ComplexTransform.none
    conj_xform = ComplexTransform.conj
    # B is the outer loop so the order is (n, n), (c, n), (n, c), (c, c).
    complex_transforms = [
        (a_xform, b_xform)
        for b_xform in (no_xform, conj_xform)
        for a_xform in (no_xform, conj_xform)
    ]

    math_instructions = [
        MathInstruction([16, 8, 16], in_type, in_type, acc_type,
                        OpcodeClass.TensorOp, MathOperation.multiply_add)
        for in_type, acc_type in (
            (DataType.f16, DataType.f32),
            (DataType.bf16, DataType.f32),
            (DataType.f16, DataType.f16),
        )
    ]

    min_cc, max_cc = 80, 1024
    alignment_constraints = [8]

    # Positional TileDescription arguments per entry: threadblock shape,
    # stage count, warp arrangement (names assumed from CUTLASS conventions).
    tile_table = (
        (( 64, 128, 32), 3, (2, 4, 1)),
        ((128,  64, 32), 3, (4, 2, 1)),
        (( 64,  64, 32), 4, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(block), stages, list(warps), math_inst, min_cc, max_cc)
            for block, stages, warps in tile_table
        ]

        a_type = math_inst.element_a
        b_type = math_inst.element_b
        acc_type = math_inst.element_accumulator

        # C/D stored in the accumulator type.
        CreateGemmPlanarComplexOperator(
            manifest, layouts, tile_descriptions,
            [a_type, b_type, acc_type, acc_type],
            alignment_constraints, complex_transforms)

        # Same types everywhere (e.g. f16 accumulation): the variant below
        # would duplicate the kernels just requested.
        if a_type == acc_type:
            continue

        # C/D stored in the input type, computed in the accumulator type.
        CreateGemmPlanarComplexOperator(
            manifest, layouts, tile_descriptions,
            [a_type, b_type, a_type, acc_type],
            alignment_constraints, complex_transforms)
2042
+
2043
+ #
2044
def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
    """Request SM80 8-bit integer TN GEMM and Fprop conv kernels on the 16x8x32 MMA.

    For each math instruction (s8 x s8 and u8 x u8, s32 accumulation, saturating
    multiply-add), two variants are requested for the TN GEMM layout (row-major A,
    column-major B, column-major C) and for NHWC forward convolution:

    * ``data_type``: s32 output with a ``LinearCombination`` epilogue;
    * ``data_type_mixed``: output in the input type, f32 compute, with a
      ``LinearCombinationClamp`` epilogue.

    For the clamped operations only, the C operand alignment is then set to 16
    when the threadblock N extent is at least 128, and to 8 otherwise.

    Args:
        manifest: Forwarded to ``CreateGemmOperator`` / ``CreateConv2dOperator``,
            which presumably record the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    layouts = [
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    ]

    math_instructions = [
        MathInstruction(
            [16, 8, 32],
            DataType.s8, DataType.s8, DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate),
        MathInstruction(
            [16, 8, 32],
            DataType.u8, DataType.u8, DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate),
    ]

    min_cc = 80
    max_cc = 1024

    alignment_constraints = [16]

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([256, 128,  64],  3, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 256,  64],  3, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([256,  64,  64],  4, [4, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 256,  64],  4, [1, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128,  64],  5, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  64,  64],  6, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 128,  64],  6, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64,  64,  64], 10, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([256, 128, 128],  3, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 256, 128],  3, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([256,  64, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128, 128],  4, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  64, 128],  3, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 128, 128],  3, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64,  64, 128],  5, [2, 2, 1], math_inst, min_cc, max_cc),
        ]

        data_type = [math_inst.element_a, math_inst.element_b,
                     math_inst.element_accumulator, DataType.s32]
        data_type_mixed = [math_inst.element_a, math_inst.element_b,
                           math_inst.element_a, DataType.f32]

        CreateGemmOperator(manifest, layouts, tile_descriptions,
                           data_type, alignment_constraints, None,
                           EpilogueFunctor.LinearCombination)

        # Collect only the clamped (narrow-output) operations; their C
        # alignment is adjusted after all of them have been created.
        operations = []

        operations += CreateGemmOperator(manifest, layouts, tile_descriptions,
                                         data_type_mixed, alignment_constraints, None,
                                         EpilogueFunctor.LinearCombinationClamp)

        conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
        CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                             data_type, alignment_constraints, [ConvKind.Fprop],
                             EpilogueFunctor.LinearCombination)

        operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                                           data_type_mixed, alignment_constraints, [ConvKind.Fprop],
                                           EpilogueFunctor.LinearCombinationClamp)

        for op in operations:
            if op.tile_description.threadblock_shape[1] >= 128:
                op.C.alignment = 16
            else:
                op.C.alignment = 8
2115
+
2116
+ #
2117
+
2118
+ #
2119
def GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version):
    """Request SM80 sparse s8 TN GEMM kernels on the 16x8x64 tensor-op MMA.

    Layout: row-major A, column-major B, row-major C. Two families are requested:

    * s32 output with a ``LinearCombination`` epilogue;
    * s8 output with f32 compute and a ``LinearCombinationClamp`` epilogue.

    For the clamped operations, the C alignment is set to 16 when the
    threadblock N extent is >= 128, and to 8 otherwise.

    Args:
        manifest: Forwarded to ``CreateSparseGemmOperator``, which presumably
            records the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 1)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 1):
        return

    layouts = [(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor)]

    math_inst = MathInstruction(
        [16, 8, 64],
        DataType.s8, DataType.s8, DataType.s32,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_saturate,
    )

    min_cc, max_cc = 80, 1024
    alignment_constraints = [16]

    # Positional TileDescription arguments per entry: threadblock shape,
    # stage count, warp arrangement (names assumed from CUTLASS conventions).
    tile_table = (
        ((128,  64, 128), 3, (2, 2, 1)),
        ((256, 128, 128), 3, (4, 2, 1)),
        ((128, 256, 128), 3, (2, 4, 1)),
        ((128, 128, 128), 3, (2, 2, 1)),
        ((256,  64, 128), 3, (4, 1, 1)),
        (( 64, 256, 128), 4, (1, 4, 1)),
        (( 64, 128, 128), 6, (2, 2, 1)),
        (( 64,  64, 128), 4, (2, 2, 1)),
        ((128, 128, 256), 3, (2, 2, 1)),
        ((128,  64, 256), 4, (2, 2, 1)),
        (( 64, 128, 256), 3, (2, 2, 1)),
        (( 64,  64, 256), 3, (2, 2, 1)),
    )
    tile_descriptions = [
        TileDescription(list(block), stages, list(warps), math_inst, min_cc, max_cc)
        for block, stages, warps in tile_table
    ]

    s32_output_types = [DataType.s8, DataType.s8, DataType.s32, DataType.s32]
    s8_output_types = [DataType.s8, DataType.s8, DataType.s8, DataType.f32]

    CreateSparseGemmOperator(manifest, layouts, tile_descriptions,
                             s32_output_types, alignment_constraints, None,
                             EpilogueFunctor.LinearCombination)

    clamped_ops = list(CreateSparseGemmOperator(manifest, layouts, tile_descriptions,
                                                s8_output_types, alignment_constraints, None,
                                                EpilogueFunctor.LinearCombinationClamp))

    for op in clamped_ops:
        wide_n = op.tile_description.threadblock_shape[1] >= 128
        op.C.alignment = 16 if wide_n else 8
2171
+ #
2172
+
2173
+ #
2174
def GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version):
    """Request SM80 interleaved-32 8-bit GEMM and Fprop conv kernels (16x8x32 MMA).

    For s8 x s8 and u8 x u8 (s32 accumulation, saturating multiply-add), this
    requests only the clamped variant:

    * output in the input type, f32 compute, ``LinearCombinationClamp``
      epilogue;
    * as an interleaved-32 GEMM;
    * as an NC32HW32 / C32RSK32 forward convolution.

    Every returned operation then gets C alignment 8.

    Args:
        manifest: Forwarded to ``CreateGemmOperator`` / ``CreateConv2dOperator``,
            which presumably record the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    layouts = [(LayoutType.ColumnMajorInterleaved32,
                LayoutType.RowMajorInterleaved32,
                LayoutType.ColumnMajorInterleaved32)]

    math_instructions = [
        MathInstruction([16, 8, 32], elem, elem, DataType.s32,
                        OpcodeClass.TensorOp, MathOperation.multiply_add_saturate)
        for elem in (DataType.s8, DataType.u8)
    ]

    min_cc, max_cc = 80, 1024
    alignment_constraints = [16]

    # Positional TileDescription arguments per entry: threadblock shape,
    # stage count, warp arrangement (names assumed from CUTLASS conventions).
    tile_table = (
        ((256, 128, 64),  3, (4, 2, 1)),
        ((128, 256, 64),  3, (2, 4, 1)),
        ((256,  64, 64),  4, (4, 1, 1)),
        (( 64, 256, 64),  4, (1, 4, 1)),
        ((128, 128, 64),  5, (2, 2, 1)),
        ((128,  64, 64),  6, (2, 2, 1)),
        (( 64, 128, 64),  6, (2, 2, 1)),
        (( 64,  64, 64), 10, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(block), stages, list(warps), math_inst, min_cc, max_cc)
            for block, stages, warps in tile_table
        ]

        narrow_type = math_inst.element_a
        data_type_mixed = [narrow_type, math_inst.element_b, narrow_type, DataType.f32]

        # The conv operations are appended in place to the sequence returned
        # for the GEMMs.
        operations = CreateGemmOperator(manifest, layouts, tile_descriptions,
                                        data_type_mixed, alignment_constraints, None,
                                        EpilogueFunctor.LinearCombinationClamp)

        conv_layout = (LayoutType.TensorNC32HW32,
                       LayoutType.TensorC32RSK32,
                       LayoutType.TensorNC32HW32)
        operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                                           data_type_mixed, alignment_constraints, [ConvKind.Fprop],
                                           EpilogueFunctor.LinearCombinationClamp)

        for op in operations:
            op.C.alignment = 8
2225
+ #
2226
+
2227
+ #
2228
def GenerateSM80_TensorOp_16864_TN(manifest, cuda_version):
    """Request SM80 4-bit integer TN GEMM and Fprop conv kernels on the 16x8x64 MMA.

    For each math instruction (s4 x s4 and u4 x u4, s32 accumulation, saturating
    multiply-add), two variants are requested for the TN GEMM layout (row-major A,
    column-major B, column-major C) and for NHWC forward convolution:

    * ``data_type``: s32 output with a ``LinearCombination`` epilogue;
    * ``data_type_mixed``: output in the input type, f32 compute, with a
      ``LinearCombinationClamp`` epilogue.

    For the clamped operations only, the C operand alignment is then set to 16
    when the threadblock N extent is at least 128, and to 8 otherwise.

    Args:
        manifest: Forwarded to ``CreateGemmOperator`` / ``CreateConv2dOperator``,
            which presumably record the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    layouts = [
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    ]

    math_instructions = [
        MathInstruction(
            [16, 8, 64],
            DataType.s4, DataType.s4, DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate),
        MathInstruction(
            [16, 8, 64],
            DataType.u4, DataType.u4, DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate),
    ]

    min_cc = 80
    max_cc = 1024
    alignment_constraints = [32]

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([256, 128, 128],  3, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 256, 128],  3, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([256,  64, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128, 128],  5, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  64, 128],  6, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 128, 128],  6, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64,  64, 128], 10, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([256, 128, 256],  3, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 256, 256],  3, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([256,  64, 256],  4, [4, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 256, 256],  4, [1, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128, 256],  4, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128, 256],  3, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  64, 256],  3, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 128, 256],  3, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64,  64, 256],  5, [2, 2, 1], math_inst, min_cc, max_cc),
        ]

        data_type = [math_inst.element_a, math_inst.element_b,
                     math_inst.element_accumulator, DataType.s32]
        data_type_mixed = [math_inst.element_a, math_inst.element_b,
                           math_inst.element_a, DataType.f32]

        CreateGemmOperator(manifest, layouts, tile_descriptions,
                           data_type, alignment_constraints, None,
                           EpilogueFunctor.LinearCombination)

        # Collect only the clamped (narrow-output) operations; their C
        # alignment is adjusted after all of them have been created.
        operations = []

        operations += CreateGemmOperator(manifest, layouts, tile_descriptions,
                                         data_type_mixed, alignment_constraints, None,
                                         EpilogueFunctor.LinearCombinationClamp)

        conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
        CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                             data_type, alignment_constraints, [ConvKind.Fprop],
                             EpilogueFunctor.LinearCombination)

        operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                                           data_type_mixed, alignment_constraints, [ConvKind.Fprop],
                                           EpilogueFunctor.LinearCombinationClamp)

        # The previous `elif N == 64` branch assigned the same value as `else`,
        # so there are only two cases: N >= 128 and everything else.
        for op in operations:
            if op.tile_description.threadblock_shape[1] >= 128:
                op.C.alignment = 16
            else:
                op.C.alignment = 8
2300
+ #
2301
+
2302
+ #
2303
def GenerateSM80_SparseTensorOp_168128_TN(manifest, cuda_version):
    """Request SM80 sparse s4 TN GEMM kernels on the 16x8x128 tensor-op MMA.

    Layout: row-major A, column-major B, row-major C. Two families are requested:

    * s32 output with a ``LinearCombination`` epilogue;
    * s4 output with f32 compute and a ``LinearCombinationClamp`` epilogue.

    For the clamped operations, the C alignment is set to 16 only when the
    threadblock N extent is strictly greater than 128, and to 8 otherwise.

    Args:
        manifest: Forwarded to ``CreateSparseGemmOperator``, which presumably
            records the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 1)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 1):
        return

    layouts = [(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor)]

    math_inst = MathInstruction(
        [16, 8, 128],
        DataType.s4, DataType.s4, DataType.s32,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_saturate,
    )

    min_cc, max_cc = 80, 1024
    alignment_constraints = [32]

    # Positional TileDescription arguments per entry: threadblock shape,
    # stage count, warp arrangement (names assumed from CUTLASS conventions).
    tile_table = (
        (( 64,  64, 256), 4, (2, 2, 1)),
        ((256,  64, 256), 3, (4, 1, 1)),
        ((256, 128, 256), 3, (4, 2, 1)),
        ((128, 256, 256), 3, (2, 4, 1)),
        ((128, 128, 256), 3, (2, 2, 1)),
        (( 64, 256, 256), 4, (1, 4, 1)),
        ((128,  64, 256), 3, (2, 2, 1)),
        (( 64, 128, 256), 6, (2, 2, 1)),
        ((128, 128, 512), 3, (2, 4, 1)),
        ((128,  64, 512), 4, (2, 2, 1)),
        (( 64, 128, 512), 3, (2, 2, 1)),
        (( 64,  64, 512), 3, (2, 2, 1)),
    )
    tile_descriptions = [
        TileDescription(list(block), stages, list(warps), math_inst, min_cc, max_cc)
        for block, stages, warps in tile_table
    ]

    s32_output_types = [DataType.s4, DataType.s4, DataType.s32, DataType.s32]
    s4_output_types = [DataType.s4, DataType.s4, DataType.s4, DataType.f32]

    CreateSparseGemmOperator(manifest, layouts, tile_descriptions,
                             s32_output_types, alignment_constraints, None,
                             EpilogueFunctor.LinearCombination)

    clamped_ops = list(CreateSparseGemmOperator(manifest, layouts, tile_descriptions,
                                                s4_output_types, alignment_constraints, None,
                                                EpilogueFunctor.LinearCombinationClamp))

    for op in clamped_ops:
        # NOTE(review): strict `> 128`, unlike the `>= 128` used by the s8
        # sparse generator; preserved as-is. Confirm it is intentional.
        op.C.alignment = 16 if op.tile_description.threadblock_shape[1] > 128 else 8
2354
+ #
2355
+
2356
+ #
2357
def GenerateSM80_TensorOp_16864_Interleaved(manifest, cuda_version):
    """Request SM80 interleaved-64 4-bit GEMM and Fprop conv kernels (16x8x64 MMA).

    For s4 x s4 and u4 x u4 (s32 accumulation, saturating multiply-add), this
    requests only the clamped variant:

    * output in the input type, f32 compute, ``LinearCombinationClamp``
      epilogue;
    * as an interleaved-64 GEMM;
    * as an NC64HW64 / C64RSK64 forward convolution.

    Every returned operation then gets C alignment 16.

    Args:
        manifest: Forwarded to ``CreateGemmOperator`` / ``CreateConv2dOperator``,
            which presumably record the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    layouts = [(LayoutType.ColumnMajorInterleaved64,
                LayoutType.RowMajorInterleaved64,
                LayoutType.ColumnMajorInterleaved64)]

    math_instructions = [
        MathInstruction([16, 8, 64], elem, elem, DataType.s32,
                        OpcodeClass.TensorOp, MathOperation.multiply_add_saturate)
        for elem in (DataType.s4, DataType.u4)
    ]

    min_cc, max_cc = 80, 1024
    alignment_constraints = [32]

    # Positional TileDescription arguments per entry: threadblock shape,
    # stage count, warp arrangement (names assumed from CUTLASS conventions).
    tile_table = (
        ((256, 128, 128), 3, (4, 2, 1)),
        ((128, 256, 128), 3, (2, 4, 1)),
        ((256,  64, 128), 4, (4, 1, 1)),
        (( 64, 256, 128), 4, (1, 4, 1)),
        ((128, 128, 128), 5, (2, 2, 1)),
        (( 64, 128, 128), 6, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(block), stages, list(warps), math_inst, min_cc, max_cc)
            for block, stages, warps in tile_table
        ]

        narrow_type = math_inst.element_a
        data_type_mixed = [narrow_type, math_inst.element_b, narrow_type, DataType.f32]

        operations = list(CreateGemmOperator(manifest, layouts, tile_descriptions,
                                             data_type_mixed, alignment_constraints, None,
                                             EpilogueFunctor.LinearCombinationClamp))

        conv_layout = (LayoutType.TensorNC64HW64,
                       LayoutType.TensorC64RSK64,
                       LayoutType.TensorNC64HW64)
        operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                                           data_type_mixed, alignment_constraints, [ConvKind.Fprop],
                                           EpilogueFunctor.LinearCombinationClamp)

        for op in operations:
            op.C.alignment = 16
2407
+ #
2408
+
2409
+ #
2410
def GenerateSM80_TensorOp_168256(manifest, cuda_version):
    """Request SM80 binary (b1) TN GEMM kernels on the 16x8x256 XOR-popcount MMA.

    Layout: row-major A, column-major B, column-major C. Data types are b1 for A
    and B and s32 for C and compute. The maximum compute capability for the
    tiles is looked up per math operation.

    Args:
        manifest: Forwarded to ``CreateGemmOperator``, which presumably records
            the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    layouts = [(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor)]

    math_instructions = [
        MathInstruction([16, 8, 256],
                        DataType.b1, DataType.b1, DataType.s32,
                        OpcodeClass.TensorOp,
                        MathOperation.xor_popc),
    ]

    min_cc = 80
    max_cc_by_op = {MathOperation.xor_popc: 1024}

    alignment_constraints = [128]

    # Positional TileDescription arguments per entry: threadblock shape,
    # stage count, warp arrangement (names assumed from CUTLASS conventions).
    tile_table = (
        ((256, 128,  512),  3, (4, 2, 1)),
        ((128, 256,  512),  3, (2, 4, 1)),
        ((256,  64,  512),  4, (4, 1, 1)),
        (( 64, 256,  512),  4, (1, 4, 1)),
        ((128, 128,  512),  5, (2, 2, 1)),
        ((128,  64,  512),  6, (2, 2, 1)),
        (( 64, 128,  512),  6, (2, 2, 1)),
        (( 64,  64,  512), 10, (2, 2, 1)),
        ((256, 128, 1024),  3, (4, 2, 1)),
        ((128, 256, 1024),  3, (2, 4, 1)),
        ((256,  64, 1024),  4, (4, 1, 1)),
        (( 64, 256, 1024),  4, (1, 4, 1)),
        ((128, 128, 1024),  4, (2, 2, 1)),
        ((128,  64, 1024),  3, (2, 2, 1)),
        (( 64, 128, 1024),  3, (2, 2, 1)),
        (( 64,  64, 1024),  5, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        inst_max_cc = max_cc_by_op[math_inst.math_operation]
        tile_descriptions = [
            TileDescription(list(block), stages, list(warps), math_inst, min_cc, inst_max_cc)
            for block, stages, warps in tile_table
        ]

        data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32]

        CreateGemmOperator(manifest, layouts, tile_descriptions,
                           data_type, alignment_constraints)
2458
+
2459
+ #
2460
+
2461
+ #
2462
def GenerateSM80_TensorOp_1688(manifest, cuda_version):
    """Request SM80 TF32 GEMM and NHWC conv kernels on the 16x8x8 tensor-op MMA.

    The math instruction is tf32 x tf32 with f32 accumulation. For all four
    column/row-major combinations of A and B (C column-major), two GEMM
    variants are requested:

    * C/D in the accumulator type;
    * C in the input type, D in the accumulator type.

    The same two variants are also requested as NHWC convolutions, using the
    default conv kinds and epilogue of ``CreateConv2dOperator``.

    Args:
        manifest: Forwarded to ``CreateGemmOperator`` / ``CreateConv2dOperator``,
            which presumably record the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col_major = LayoutType.ColumnMajor
    row_major = LayoutType.RowMajor
    layouts = [
        (a_layout, b_layout, col_major)
        for a_layout in (col_major, row_major)
        for b_layout in (col_major, row_major)
    ]

    math_instructions = [
        MathInstruction([16, 8, 8],
                        DataType.tf32, DataType.tf32, DataType.f32,
                        OpcodeClass.TensorOp,
                        MathOperation.multiply_add),
    ]

    min_cc, max_cc = 80, 1024
    alignment_constraints = [4, 2, 1]

    # Positional TileDescription arguments per entry: threadblock shape,
    # stage count, warp arrangement (names assumed from CUTLASS conventions).
    tile_table = (
        ((256, 128, 16),  3, (4, 2, 1)),
        ((128, 256, 16),  3, (2, 4, 1)),
        ((256,  64, 16),  4, (4, 1, 1)),
        (( 64, 256, 16),  4, (1, 4, 1)),
        ((128, 128, 16),  5, (2, 2, 1)),
        ((128, 128, 16),  4, (2, 2, 1)),
        ((128, 128, 16),  3, (2, 2, 1)),
        ((128,  64, 16),  6, (2, 2, 1)),
        (( 64, 128, 16),  6, (2, 2, 1)),
        (( 64,  64, 16), 10, (2, 2, 1)),
        ((256, 128, 32),  3, (4, 2, 1)),
        ((128, 256, 32),  3, (2, 4, 1)),
        ((256,  64, 32),  4, (4, 1, 1)),
        (( 64, 256, 32),  4, (1, 4, 1)),
        ((128, 128, 32),  4, (2, 2, 1)),
        ((128, 128, 32),  3, (2, 2, 1)),
        ((128,  64, 32),  3, (2, 2, 1)),
        (( 64, 128, 32),  3, (2, 2, 1)),
        (( 64,  64, 32),  5, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(block), stages, list(warps), math_inst, min_cc, max_cc)
            for block, stages, warps in tile_table
        ]

        a_type = math_inst.element_a
        b_type = math_inst.element_b
        acc_type = math_inst.element_accumulator

        data_type = [a_type, b_type, acc_type, acc_type]
        data_type_mixed = [a_type, b_type, a_type, acc_type]

        for types in (data_type, data_type_mixed):
            CreateGemmOperator(manifest, layouts, tile_descriptions,
                               types, alignment_constraints)

        conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)

        for types in (data_type, data_type_mixed):
            CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                                 types, alignment_constraints)
2534
+ #
2535
+
2536
+ #
2537
def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
    """Request SM80 f32 GEMM/conv kernels that compute on reduced-precision MMAs.

    All operands are f32 (A, B, C and compute). Three 16x8x8 tensor-op math
    instructions are each used to produce one family:

    * tf32 with ``multiply_add``;
    * f16 with ``multiply_add_fast_f16``;
    * bf16 with ``multiply_add_fast_bf16``.

    Each family covers the four column/row-major combinations of A and B (C
    column-major) as a GEMM, plus an NHWC convolution using the default conv
    kinds and epilogue of ``CreateConv2dOperator``.

    Args:
        manifest: Forwarded to ``CreateGemmOperator`` / ``CreateConv2dOperator``,
            which presumably record the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col_major = LayoutType.ColumnMajor
    row_major = LayoutType.RowMajor
    layouts = [
        (a_layout, b_layout, col_major)
        for a_layout in (col_major, row_major)
        for b_layout in (col_major, row_major)
    ]

    math_instructions = [
        MathInstruction([16, 8, 8], elem, elem, DataType.f32,
                        OpcodeClass.TensorOp, math_op)
        for elem, math_op in (
            (DataType.tf32, MathOperation.multiply_add),
            (DataType.f16, MathOperation.multiply_add_fast_f16),
            (DataType.bf16, MathOperation.multiply_add_fast_bf16),
        )
    ]

    min_cc, max_cc = 80, 1024
    alignment_constraints = [4, 2, 1]

    # Positional TileDescription arguments per entry: threadblock shape,
    # stage count, warp arrangement (names assumed from CUTLASS conventions).
    tile_table = (
        ((256, 128, 16),  3, (4, 2, 1)),
        ((128, 256, 16),  3, (2, 4, 1)),
        ((256,  64, 16),  4, (4, 1, 1)),
        (( 64, 256, 16),  4, (1, 4, 1)),
        ((128, 128, 16),  5, (2, 2, 1)),
        ((128, 128, 16),  4, (2, 2, 1)),
        ((128, 128, 16),  3, (2, 2, 1)),
        ((128,  64, 16),  6, (2, 2, 1)),
        (( 64, 128, 16),  6, (2, 2, 1)),
        (( 64,  64, 16), 10, (2, 2, 1)),
        ((256, 128, 32),  3, (4, 2, 1)),
        ((128, 256, 32),  3, (2, 4, 1)),
        ((256,  64, 32),  4, (4, 1, 1)),
        (( 64, 256, 32),  4, (1, 4, 1)),
        ((128, 128, 32),  4, (2, 2, 1)),
        ((128, 128, 32),  3, (2, 2, 1)),
        ((128,  64, 32),  3, (2, 2, 1)),
        (( 64, 128, 32),  3, (2, 2, 1)),
        (( 64,  64, 32),  5, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(block), stages, list(warps), math_inst, min_cc, max_cc)
            for block, stages, warps in tile_table
        ]

        # Operands are f32 regardless of the MMA input type.
        data_type = [DataType.f32] * 4

        CreateGemmOperator(manifest, layouts, tile_descriptions,
                           data_type, alignment_constraints)

        conv_layout = (LayoutType.TensorNHWC,) * 3
        CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                             data_type, alignment_constraints)
2602
+ #
2603
+
2604
+ #
2605
def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
    """Request SM80 f32 GEMM/conv kernels using the ``multiply_add_fast_f32`` MMA path.

    The math instruction is 16x8x8, with f32 for A, B and accumulation. All
    operands are f32. GEMMs are requested for the four column/row-major
    combinations of A and B (C column-major). An NHWC convolution is also
    requested, using the default conv kinds and epilogue of
    ``CreateConv2dOperator``.

    Args:
        manifest: Forwarded to ``CreateGemmOperator`` / ``CreateConv2dOperator``,
            which presumably record the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col_major = LayoutType.ColumnMajor
    row_major = LayoutType.RowMajor
    layouts = [
        (a_layout, b_layout, col_major)
        for a_layout in (col_major, row_major)
        for b_layout in (col_major, row_major)
    ]

    math_instructions = [
        MathInstruction([16, 8, 8],
                        DataType.f32, DataType.f32, DataType.f32,
                        OpcodeClass.TensorOp,
                        MathOperation.multiply_add_fast_f32),
    ]

    min_cc, max_cc = 80, 1024
    alignment_constraints = [4, 2, 1]

    # Positional TileDescription arguments per entry: threadblock shape,
    # stage count, warp arrangement (names assumed from CUTLASS conventions).
    tile_table = (
        ((128, 128, 16), 4, (4, 2, 1)),
        ((128, 128, 16), 3, (4, 2, 1)),
        ((256,  64, 16), 3, (4, 2, 1)),
        (( 64, 256, 16), 3, (2, 4, 1)),
        ((128,  64, 16), 4, (2, 2, 1)),
        (( 64, 128, 16), 4, (2, 2, 1)),
        (( 64,  64, 16), 3, (2, 2, 1)),
        ((128, 128, 32), 3, (4, 2, 1)),
        ((256,  64, 32), 3, (4, 2, 1)),
        (( 64, 256, 32), 3, (2, 4, 1)),
        ((128,  64, 32), 3, (2, 2, 1)),
        (( 64, 128, 32), 3, (2, 2, 1)),
        (( 64,  64, 32), 3, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(block), stages, list(warps), math_inst, min_cc, max_cc)
            for block, stages, warps in tile_table
        ]

        data_type = [DataType.f32] * 4

        CreateGemmOperator(manifest, layouts, tile_descriptions,
                           data_type, alignment_constraints)

        conv_layout = (LayoutType.TensorNHWC,) * 3
        CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                             data_type, alignment_constraints)
2654
+ #
2655
+
2656
def GenerateSM80_TensorOp_1688_fast_fp32_math_complex(manifest, cuda_version):
    """Request SM80 complex-f32 (cf32) GEMM kernels via ``multiply_add_complex_fast_f32``.

    The math instruction is 16x8x8 with f32 inputs and accumulation. All
    operands are cf32 and the alignment is 1. Kernels are requested for:

    * the four column/row-major combinations of A and B (C column-major);
    * the four conjugation combinations of A and B.

    Args:
        manifest: Forwarded to ``CreateGemmOperator``, which presumably records
            the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col_major = LayoutType.ColumnMajor
    row_major = LayoutType.RowMajor
    layouts = [
        (a_layout, b_layout, col_major)
        for a_layout in (col_major, row_major)
        for b_layout in (col_major, row_major)
    ]

    math_inst = MathInstruction([16, 8, 8],
                                DataType.f32, DataType.f32, DataType.f32,
                                OpcodeClass.TensorOp,
                                MathOperation.multiply_add_complex_fast_f32)

    min_cc, max_cc = 80, 1024

    # Positional TileDescription arguments per entry: threadblock shape,
    # stage count, warp arrangement (names assumed from CUTLASS conventions).
    tile_table = (
        ((128, 64, 16), 3, (4, 2, 1)),
        (( 64, 128, 16), 3, (2, 4, 1)),
        (( 64, 64, 16), 4, (2, 2, 1)),
        (( 64, 32, 16), 4, (2, 2, 1)),
        (( 32, 64, 16), 4, (2, 2, 1)),
        (( 32, 32, 16), 3, (2, 2, 1)),
    )
    tile_descriptions = [
        TileDescription(list(block), stages, list(warps), math_inst, min_cc, max_cc)
        for block, stages, warps in tile_table
    ]

    data_type = [DataType.cf32] * 4

    alignment_constraints = [1]

    no_xform = ComplexTransform.none
    conj_xform = ComplexTransform.conj
    # B is the outer loop so the order is (n, n), (c, n), (n, c), (c, c).
    complex_transforms = [
        (a_xform, b_xform)
        for b_xform in (no_xform, conj_xform)
        for a_xform in (no_xform, conj_xform)
    ]

    CreateGemmOperator(manifest, layouts, tile_descriptions,
                       data_type, alignment_constraints, complex_transforms)
2701
+
2702
+
2703
+ #
2704
def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, cuda_version):
    """Request SM80 sparse f32 GEMM kernels computed on the TF32 16x8x16 MMA.

    All operands are f32 and the alignment is 4. Kernels are requested for the
    four column/row-major combinations of A and B, with C always row-major.

    Args:
        manifest: Forwarded to ``CreateSparseGemmOperator``, which presumably
            records the generated operations in it.
        cuda_version: CUDA toolkit version. Nothing is generated unless
            ``CudaToolkitVersionSatisfies(cuda_version, 11, 1)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 1):
        return

    col_major = LayoutType.ColumnMajor
    row_major = LayoutType.RowMajor
    layouts = [
        (a_layout, b_layout, row_major)
        for a_layout in (col_major, row_major)
        for b_layout in (col_major, row_major)
    ]

    math_instructions = [
        MathInstruction([16, 8, 16],
                        DataType.tf32, DataType.tf32, DataType.f32,
                        OpcodeClass.TensorOp,
                        MathOperation.multiply_add),
    ]

    min_cc, max_cc = 80, 1024
    alignment_constraints = [4]

    # Positional TileDescription arguments per entry: threadblock shape,
    # stage count, warp arrangement (names assumed from CUTLASS conventions).
    tile_table = (
        ((128,  64, 32), 3, (2, 2, 1)),
        ((128, 128, 32), 3, (2, 2, 1)),
        ((256, 128, 32), 3, (4, 2, 1)),
        ((128, 256, 32), 3, (2, 4, 1)),
        ((256,  64, 32), 3, (4, 1, 1)),
        (( 64, 256, 32), 4, (1, 4, 1)),
        (( 64, 128, 32), 6, (2, 2, 1)),
        (( 64,  64, 32), 6, (2, 2, 1)),
        ((128, 128, 64), 3, (2, 4, 1)),
        ((256,  64, 64), 3, (4, 1, 1)),
        ((128,  64, 64), 4, (2, 2, 1)),
        (( 64, 128, 64), 3, (2, 2, 1)),
        (( 64,  64, 64), 3, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(block), stages, list(warps), math_inst, min_cc, max_cc)
            for block, stages, warps in tile_table
        ]

        data_type = [DataType.f32] * 4

        CreateSparseGemmOperator(manifest, layouts, tile_descriptions,
                                 data_type, alignment_constraints)
2750
+ #
2751
+
2752
+ #
2753
def GenerateSM80_TensorOp_1688_complex(manifest, cuda_version):
    """Register complex-fp32 GEMM kernels on the tf32 16x8x8 tensor-op.

    Uses a tf32 x tf32 -> f32 ``multiply_add_complex`` instruction with cf32
    element types, alignment 1, and all four (A, B) conjugation pairs, passed
    to a single ``CreateGemmOperator`` call on *manifest*. Registers nothing
    unless ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # (A, B, C) layouts: every A/B combination, C always column-major.
    layouts = [(a, b, col) for a in (col, row) for b in (col, row)]

    math_inst = MathInstruction(
        [16, 8, 8],
        DataType.tf32, DataType.tf32, DataType.f32,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex,
    )
    min_cc, max_cc = 80, 1024

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    tile_table = (
        ((128, 128, 16), 4, (2, 4, 1)),
        ((128, 64, 16), 4, (4, 2, 1)),
        ((64, 128, 16), 4, (2, 4, 1)),
        ((64, 64, 16), 4, (2, 2, 1)),
        ((64, 32, 16), 4, (2, 1, 1)),
        ((32, 64, 16), 4, (1, 2, 1)),
        ((32, 32, 16), 4, (2, 2, 1)),
    )
    tile_descriptions = [
        TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_table
    ]

    data_type = [DataType.cf32] * 4
    alignment_constraints = [1]

    plain, conj = ComplexTransform.none, ComplexTransform.conj
    # (A, B) conjugation pairs, ordered with the A transform varying fastest.
    complex_transforms = [(a, b) for b in (plain, conj) for a in (plain, conj)]

    CreateGemmOperator(manifest, layouts, tile_descriptions,
                       data_type, alignment_constraints, complex_transforms)
2799
+ #
2800
+
2801
+ #
2802
def GenerateSM80_TensorOp_1688_rank_k(manifest, cuda_version):
    """Register symmetric rank-k (SYRK) f32 kernels on 16x8x8 tensor-ops.

    Makes one ``CreateRankKOperator(..., BlasMode.symmetric)`` call per math
    instruction: a tf32 ``multiply_add`` variant and an f32
    ``multiply_add_fast_f32`` variant, both with f32 element types. Registers
    nothing unless ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # Two-element layout tuples: first entry varies, second is column-major.
    layouts = [(a, col) for a in (col, row)]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.tf32, DataType.tf32, DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [16, 8, 8],
            DataType.f32, DataType.f32, DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_fast_f32,
        ),
    ]
    min_cc, max_cc = 80, 1024

    alignment_constraints = [1, 2, 4]  # Alignment only applies to A in SYRK

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    # Commented-out rows are disabled configurations.
    tile_table = (
        ((256, 128, 16), 3, (4, 2, 1)),
        ((128, 256, 16), 3, (2, 4, 1)),
        # ((256, 64, 16), 4, (4, 1, 1)),
        # ((64, 256, 16), 4, (1, 4, 1)),
        ((128, 128, 16), 5, (2, 2, 1)),
        # ((128, 64, 16), 6, (2, 2, 1)),
        # ((64, 128, 16), 6, (2, 2, 1)),
        # ((64, 64, 16), 10, (2, 2, 1)),
        ((256, 128, 32), 3, (4, 2, 1)),
        ((128, 256, 32), 3, (2, 4, 1)),
        # ((256, 64, 32), 4, (4, 1, 1)),
        # ((64, 256, 32), 4, (1, 4, 1)),
        ((128, 128, 32), 4, (2, 2, 1)),
        # ((128, 64, 32), 3, (2, 2, 1)),
        # ((64, 128, 32), 3, (2, 2, 1)),
        # ((64, 64, 32), 5, (2, 2, 1)),
    )

    for inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(shape), stages, list(warps), inst, min_cc, max_cc)
            for shape, stages, warps in tile_table
        ]
        data_type = [DataType.f32] * 3

        CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
                            data_type, alignment_constraints, BlasMode.symmetric)
2858
+ #
2859
+
2860
+ #
2861
def GenerateSM80_TensorOp_1688_rank_k_complex(manifest, cuda_version):
    """Register complex-fp32 SYRK and HERK kernels on 16x8x8 tensor-ops.

    For each math instruction (tf32 ``multiply_add_complex`` and f32
    ``multiply_add_complex_fast_f32``), calls ``CreateRankKOperator`` twice
    with cf32 element types: first with ``BlasMode.symmetric`` (SYRK), then
    with ``BlasMode.hermitian`` (HERK). Registers nothing unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # Two-element layout tuples: first entry varies, second is column-major.
    layouts = [(a, col) for a in (col, row)]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.tf32, DataType.tf32, DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_complex,
        ),
        MathInstruction(
            [16, 8, 8],
            DataType.f32, DataType.f32, DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_complex_fast_f32,
        ),
    ]
    min_cc, max_cc = 80, 1024

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    # Commented-out rows are disabled configurations.
    tile_table = (
        ((128, 64, 16), 4, (4, 2, 1)),
        ((64, 128, 16), 4, (2, 4, 1)),
        ((64, 64, 16), 4, (2, 2, 1)),
        # ((64, 32, 16), 4, (2, 1, 1)),
        # ((32, 32, 16), 4, (2, 2, 1)),
    )

    for inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(shape), stages, list(warps), inst, min_cc, max_cc)
            for shape, stages, warps in tile_table
        ]
        data_type = [DataType.cf32] * 3
        alignment_constraints = [1]

        # SYRK
        CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
                            data_type, alignment_constraints, BlasMode.symmetric)
        # HERK
        CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
                            data_type, alignment_constraints, BlasMode.hermitian)
2913
+ #
2914
+
2915
+ #
2916
def GenerateSM80_TensorOp_1688_trmm(manifest, cuda_version):
    """Register f32 TRMM kernels on 16x8x8 tensor-ops.

    For each math instruction (tf32 ``multiply_add`` and f32
    ``multiply_add_fast_f32``), calls ``CreateTrmmOperator`` once with both
    side modes, both fill modes, both diagonal types, f32 element types and
    alignments 1/2/4. Registers nothing unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # Three-element layout tuples: only the first entry varies.
    layouts = [(a, col, col) for a in (col, row)]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]
    diag_types = [DiagType.NonUnit, DiagType.Unit]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.tf32, DataType.tf32, DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [16, 8, 8],
            DataType.f32, DataType.f32, DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_fast_f32,
        ),
    ]
    min_cc, max_cc = 80, 1024
    alignment_constraints = [1, 2, 4]

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    # Commented-out rows are disabled configurations.
    tile_table = (
        ((256, 128, 16), 3, (4, 2, 1)),
        ((128, 256, 16), 3, (2, 4, 1)),
        ((256, 64, 16), 4, (4, 1, 1)),
        ((64, 256, 16), 4, (1, 4, 1)),
        ((128, 128, 16), 5, (2, 2, 1)),
        ((128, 64, 16), 6, (2, 2, 1)),
        # ((64, 128, 16), 6, (2, 2, 1)),
        ((64, 64, 16), 10, (2, 2, 1)),
        ((256, 128, 32), 3, (4, 2, 1)),
        ((128, 256, 32), 3, (2, 4, 1)),
        # ((256, 64, 32), 4, (4, 1, 1)),
        # ((64, 256, 32), 4, (1, 4, 1)),
        ((128, 128, 32), 4, (2, 2, 1)),
        # ((128, 64, 32), 3, (2, 2, 1)),
        # ((64, 128, 32), 3, (2, 2, 1)),
        # ((64, 64, 32), 5, (2, 2, 1)),
    )

    for inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(shape), stages, list(warps), inst, min_cc, max_cc)
            for shape, stages, warps in tile_table
        ]
        data_type = [DataType.f32] * 4

        CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types,
                           tile_descriptions, data_type, alignment_constraints)
2980
+ #
2981
+
2982
+ #
2983
def GenerateSM80_TensorOp_1688_trmm_complex(manifest, cuda_version):
    """Register complex-fp32 TRMM kernels on 16x8x8 tensor-ops.

    For each math instruction (tf32 ``multiply_add_complex`` and f32
    ``multiply_add_complex_fast_f32``), calls ``CreateTrmmOperator`` once with
    cf32 element types, alignment 1 and the transforms ``none``/``conj``.
    Registers nothing unless ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)``
    holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # Three-element layout tuples: only the first entry varies.
    layouts = [(a, col, col) for a in (col, row)]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]
    diag_types = [DiagType.NonUnit, DiagType.Unit]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.tf32, DataType.tf32, DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_complex,
        ),
        MathInstruction(
            [16, 8, 8],
            DataType.f32, DataType.f32, DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_complex_fast_f32,
        ),
    ]
    min_cc, max_cc = 80, 1024

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    tile_table = (
        ((128, 64, 16), 4, (4, 2, 1)),
        ((64, 128, 16), 4, (2, 4, 1)),
        ((64, 64, 16), 4, (2, 2, 1)),
        ((64, 32, 16), 4, (2, 1, 1)),
        ((32, 32, 16), 4, (2, 2, 1)),
    )

    for inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(shape), stages, list(warps), inst, min_cc, max_cc)
            for shape, stages, warps in tile_table
        ]
        data_type = [DataType.cf32] * 4
        alignment_constraints = [1]
        # Single (not paired) transforms for TRMM.
        complex_transforms = [ComplexTransform.none, ComplexTransform.conj]

        CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types,
                           tile_descriptions, data_type, alignment_constraints,
                           complex_transforms)
3042
+ #
3043
+
3044
+ #
3045
def GenerateSM80_TensorOp_1688_symm(manifest, cuda_version):
    """Register f32 SYMM kernels on 16x8x8 tensor-ops.

    For each math instruction (tf32 ``multiply_add`` and f32
    ``multiply_add_fast_f32``), calls ``CreateSymmOperator`` once with
    ``BlasMode.symmetric``, both side modes, both fill modes, f32 element
    types and alignments 1/2/4. Registers nothing unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    # A and B have same layouts
    layouts = [(LayoutType.ColumnMajor, LayoutType.ColumnMajor)]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.tf32, DataType.tf32, DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [16, 8, 8],
            DataType.f32, DataType.f32, DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_fast_f32,
        ),
    ]
    min_cc, max_cc = 80, 1024
    alignment_constraints = [1, 2, 4]

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    # Commented-out rows are disabled configurations.
    tile_table = (
        ((256, 128, 16), 3, (4, 2, 1)),
        ((128, 256, 16), 3, (2, 4, 1)),
        # ((256, 64, 16), 4, (4, 1, 1)),
        # ((64, 256, 16), 4, (1, 4, 1)),
        ((128, 128, 16), 5, (2, 2, 1)),
        # ((128, 64, 16), 6, (2, 2, 1)),
        # ((64, 128, 16), 6, (2, 2, 1)),
        # ((64, 64, 16), 10, (2, 2, 1)),
        ((256, 128, 32), 3, (4, 2, 1)),
        ((128, 256, 32), 3, (2, 4, 1)),
        # ((256, 64, 32), 4, (4, 1, 1)),
        # ((64, 256, 32), 4, (1, 4, 1)),
        ((128, 128, 32), 4, (2, 2, 1)),
        # ((128, 64, 32), 3, (2, 2, 1)),
        # ((64, 128, 32), 3, (2, 2, 1)),
        # ((64, 64, 32), 5, (2, 2, 1)),
    )

    for inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(shape), stages, list(warps), inst, min_cc, max_cc)
            for shape, stages, warps in tile_table
        ]
        data_type = [DataType.f32] * 4

        CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions,
                           data_type, alignment_constraints, BlasMode.symmetric)
3107
+ #
3108
+
3109
+ #
3110
def GenerateSM80_TensorOp_1688_symm_complex(manifest, cuda_version):
    """Register complex-fp32 SYMM and HEMM kernels on 16x8x8 tensor-ops.

    For each math instruction (tf32 ``multiply_add_complex`` and f32
    ``multiply_add_complex_fast_f32``), calls ``CreateSymmOperator`` twice with
    cf32 element types and alignment 1: first ``BlasMode.symmetric`` (SYMM),
    then ``BlasMode.hermitian`` (HEMM). Registers nothing unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    layouts = [(LayoutType.ColumnMajor, LayoutType.ColumnMajor)]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.tf32, DataType.tf32, DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_complex,
        ),
        MathInstruction(
            [16, 8, 8],
            DataType.f32, DataType.f32, DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_complex_fast_f32,
        ),
    ]
    min_cc, max_cc = 80, 1024

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    # Commented-out rows are disabled configurations.
    tile_table = (
        ((128, 64, 16), 4, (4, 2, 1)),
        ((64, 128, 16), 4, (2, 4, 1)),
        ((64, 64, 16), 4, (2, 2, 1)),
        # ((64, 32, 16), 4, (2, 1, 1)),
        # ((32, 32, 16), 4, (2, 2, 1)),
    )

    for inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(shape), stages, list(warps), inst, min_cc, max_cc)
            for shape, stages, warps in tile_table
        ]
        data_type = [DataType.cf32] * 4
        alignment_constraints = [1]

        # SYMM
        CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions,
                           data_type, alignment_constraints, BlasMode.symmetric)
        # HEMM
        CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions,
                           data_type, alignment_constraints, BlasMode.hermitian)
3165
+ #
3166
+
3167
+ #
3168
def GenerateSM80_TensorOp_884(manifest, cuda_version):
    """Register f64 GEMM kernels on the 8x8x4 tensor-op instruction.

    Uses an f64 ``multiply_add`` instruction with f64 element types and
    alignment 1, passed to a single ``CreateGemmOperator`` call on
    *manifest*. Registers nothing unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # (A, B, C) layouts: every A/B combination, C always column-major.
    layouts = [(a, b, col) for a in (col, row) for b in (col, row)]

    math_inst = MathInstruction(
        [8, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add,
    )
    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    tile_table = (
        ((128, 128, 16), 3, (4, 2, 1)),
        ((256, 64, 16), 3, (4, 2, 1)),
        ((64, 256, 16), 3, (2, 4, 1)),
        ((256, 32, 16), 3, (4, 1, 1)),
        ((32, 256, 16), 3, (1, 4, 1)),
        ((128, 64, 16), 3, (2, 2, 1)),
        ((64, 128, 16), 3, (2, 2, 1)),
        ((64, 64, 16), 4, (2, 2, 1)),
        ((64, 32, 16), 4, (2, 2, 1)),
        ((32, 64, 16), 4, (2, 2, 1)),
        ((32, 32, 16), 5, (2, 2, 1)),
        ((16, 32, 16), 5, (1, 2, 1)),
        ((32, 16, 16), 5, (2, 1, 1)),
    )
    tile_descriptions = [
        TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_table
    ]

    data_type = [DataType.f64] * 4

    CreateGemmOperator(manifest, layouts, tile_descriptions,
                       data_type, alignment_constraints)
3212
+ #
3213
+
3214
+ #
3215
def GenerateSM80_TensorOp_884_complex(manifest, cuda_version):
    """Register complex-fp64 GEMM kernels on the 8x8x4 tensor-op instruction.

    Uses an f64 ``multiply_add_complex`` instruction with cf64 element types,
    alignment 1 and all four (A, B) conjugation pairs, passed to a single
    ``CreateGemmOperator`` call on *manifest*. Registers nothing unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # (A, B, C) layouts: every A/B combination, C always column-major.
    layouts = [(a, b, col) for a in (col, row) for b in (col, row)]

    math_inst = MathInstruction(
        [8, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex,
    )
    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    tile_table = (
        ((128, 64, 8), 3, (4, 2, 1)),
        ((64, 128, 8), 3, (2, 4, 1)),
        ((64, 64, 8), 3, (2, 2, 1)),
        ((64, 32, 8), 4, (2, 2, 1)),
        ((32, 64, 8), 4, (2, 2, 1)),
        ((32, 32, 8), 4, (2, 2, 1)),
        ((16, 32, 8), 4, (1, 2, 1)),
        ((32, 16, 8), 4, (2, 1, 1)),
        ((128, 64, 16), 3, (4, 2, 1)),
        ((64, 128, 16), 3, (2, 4, 1)),
        ((64, 64, 16), 3, (2, 2, 1)),
        ((64, 32, 16), 3, (2, 2, 1)),
        ((32, 64, 16), 3, (2, 2, 1)),
        ((32, 32, 16), 4, (2, 2, 1)),
        ((16, 32, 16), 4, (1, 2, 1)),
        ((32, 16, 16), 3, (2, 1, 1)),
    )
    tile_descriptions = [
        TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_table
    ]

    data_type = [DataType.cf64] * 4

    plain, conj = ComplexTransform.none, ComplexTransform.conj
    # (A, B) conjugation pairs, ordered with the A transform varying fastest.
    complex_transforms = [(a, b) for b in (plain, conj) for a in (plain, conj)]

    CreateGemmOperator(manifest, layouts, tile_descriptions,
                       data_type, alignment_constraints, complex_transforms)
3269
+
3270
+ #
3271
def GenerateSM80_TensorOp_884_complex_gaussian(manifest, cuda_version):
    """Register complex-fp64 GEMM kernels using the Gaussian 8x8x4 math path.

    Uses an f64 ``multiply_add_complex_gaussian`` instruction with cf64
    element types, alignment 1 and all four (A, B) conjugation pairs, passed
    to a single ``CreateGemmOperator`` call on *manifest*. Registers nothing
    unless ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # (A, B, C) layouts: every A/B combination, C always column-major.
    layouts = [(a, b, col) for a in (col, row) for b in (col, row)]

    math_inst = MathInstruction(
        [8, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex_gaussian,
    )
    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    tile_table = (
        ((64, 64, 8), 3, (4, 2, 1)),
        ((64, 32, 8), 4, (2, 2, 1)),
        ((32, 64, 8), 4, (2, 2, 1)),
        ((32, 32, 8), 4, (2, 2, 1)),
        ((16, 32, 8), 4, (1, 2, 1)),
        ((32, 16, 8), 4, (2, 1, 1)),
    )
    tile_descriptions = [
        TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_table
    ]

    data_type = [DataType.cf64] * 4

    plain, conj = ComplexTransform.none, ComplexTransform.conj
    # (A, B) conjugation pairs, ordered with the A transform varying fastest.
    complex_transforms = [(a, b) for b in (plain, conj) for a in (plain, conj)]

    CreateGemmOperator(manifest, layouts, tile_descriptions,
                       data_type, alignment_constraints, complex_transforms)
3315
+ #
3316
+
3317
+ #
3318
def GenerateSM80_TensorOp_884_rank_k(manifest, cuda_version):
    """Register f64 symmetric rank-k (SYRK) kernels on 8x8x4 tensor-ops.

    Uses an f64 ``multiply_add`` instruction with f64 element types and
    alignment 1, passed to one ``CreateRankKOperator(..., BlasMode.symmetric)``
    call on *manifest*. Registers nothing unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # Two-element layout tuples: first entry varies, second is column-major.
    layouts = [(a, col) for a in (col, row)]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_inst = MathInstruction(
        [8, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add,
    )
    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    tile_table = (
        ((128, 128, 16), 3, (4, 2, 1)),
        ((64, 128, 16), 3, (2, 2, 1)),
        ((128, 64, 16), 3, (2, 2, 1)),
        ((64, 64, 16), 4, (2, 2, 1)),
        ((64, 32, 16), 4, (2, 2, 1)),
        ((32, 64, 16), 4, (2, 2, 1)),
        ((32, 32, 16), 5, (2, 2, 1)),
        ((16, 32, 16), 5, (1, 2, 1)),
        ((32, 16, 16), 5, (2, 1, 1)),
    )
    tile_descriptions = [
        TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_table
    ]

    data_type = [DataType.f64] * 3

    CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
                        data_type, alignment_constraints, BlasMode.symmetric)
3360
+ #
3361
+
3362
+ #
3363
def GenerateSM80_TensorOp_884_rank_k_complex(manifest, cuda_version):
    """Register complex-fp64 SYRK and HERK kernels on 8x8x4 tensor-ops.

    Uses an f64 ``multiply_add_complex`` instruction with cf64 element types
    and alignment 1, and calls ``CreateRankKOperator`` twice on *manifest*:
    first ``BlasMode.symmetric`` (SYRK), then ``BlasMode.hermitian`` (HERK).
    Registers nothing unless ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)``
    holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # Two-element layout tuples: first entry varies, second is column-major.
    layouts = [(a, col) for a in (col, row)]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_inst = MathInstruction(
        [8, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex,
    )
    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    # Commented-out rows are disabled configurations.
    tile_table = (
        ((128, 64, 8), 3, (4, 2, 1)),
        ((64, 128, 8), 3, (2, 4, 1)),
        ((64, 64, 8), 3, (2, 2, 1)),
        # ((64, 32, 8), 4, (2, 2, 1)),
        # ((32, 64, 8), 4, (2, 2, 1)),
        # ((32, 32, 8), 4, (2, 2, 1)),
        # ((16, 32, 8), 4, (1, 2, 1)),
        # ((32, 16, 8), 4, (2, 1, 1)),
    )
    tile_descriptions = [
        TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_table
    ]

    data_type = [DataType.cf64] * 3

    # SYRK computation
    CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
                        data_type, alignment_constraints, BlasMode.symmetric)
    # HERK computation
    CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
                        data_type, alignment_constraints, BlasMode.hermitian)
3409
+
3410
+ #
3411
+
3412
+ #
3413
def GenerateSM80_TensorOp_884_rank_k_complex_gaussian(manifest, cuda_version):
    """Register complex-fp64 SYRK and HERK kernels using the Gaussian 8x8x4 path.

    Uses an f64 ``multiply_add_complex_gaussian`` instruction with cf64
    element types and alignment 1, and calls ``CreateRankKOperator`` twice on
    *manifest*: first ``BlasMode.symmetric`` (SYRK), then
    ``BlasMode.hermitian`` (HERK). Registers nothing unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # Two-element layout tuples: first entry varies, second is column-major.
    layouts = [(a, col) for a in (col, row)]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_inst = MathInstruction(
        [8, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex_gaussian,
    )
    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    # Commented-out rows are disabled configurations.
    tile_table = (
        ((64, 64, 8), 3, (4, 2, 1)),
        ((64, 32, 8), 4, (2, 2, 1)),
        ((32, 64, 8), 4, (2, 2, 1)),
        # ((32, 32, 8), 4, (2, 2, 1)),
        # ((16, 32, 8), 4, (1, 2, 1)),
        # ((32, 16, 8), 4, (2, 1, 1)),
    )
    tile_descriptions = [
        TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_table
    ]

    data_type = [DataType.cf64] * 3

    # No complex-transform list is built here. Every CreateRankKOperator call
    # in this generator ends with a BlasMode argument and none passes a
    # transform list, so a local one would be dead code.

    # SYRK computation
    CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
                        data_type, alignment_constraints, BlasMode.symmetric)
    # HERK computation
    CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
                        data_type, alignment_constraints, BlasMode.hermitian)
3459
+ #
3460
+
3461
+ #
3462
def GenerateSM80_TensorOp_884_trmm(manifest, cuda_version):
    """Register f64 TRMM kernels on the 8x8x4 tensor-op instruction.

    Uses an f64 ``multiply_add`` instruction with f64 element types and
    alignment 1, and calls ``CreateTrmmOperator`` once with both side modes,
    both fill modes and both diagonal types. Registers nothing unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # Three-element layout tuples: only the first entry varies.
    layouts = [(a, col, col) for a in (col, row)]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]
    diag_types = [DiagType.NonUnit, DiagType.Unit]

    math_inst = MathInstruction(
        [8, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add,
    )
    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    tile_table = (
        ((128, 128, 16), 3, (4, 2, 1)),
        ((64, 128, 16), 3, (2, 2, 1)),
        ((128, 64, 16), 3, (2, 2, 1)),
        ((64, 64, 16), 4, (2, 2, 1)),
    )
    tile_descriptions = [
        TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_table
    ]

    data_type = [DataType.f64] * 4

    CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types,
                       tile_descriptions, data_type, alignment_constraints)
3507
+ #
3508
+
3509
+ #
3510
def GenerateSM80_TensorOp_884_trmm_complex(manifest, cuda_version):
    """Register complex-fp64 TRMM kernels on the 8x8x4 tensor-op instruction.

    Uses an f64 ``multiply_add_complex`` instruction with cf64 element types,
    alignment 1 and the transforms ``none``/``conj``, passed to a single
    ``CreateTrmmOperator`` call on *manifest*. Registers nothing unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    col, row = LayoutType.ColumnMajor, LayoutType.RowMajor
    # Three-element layout tuples: only the first entry varies.
    layouts = [(a, col, col) for a in (col, row)]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]
    diag_types = [DiagType.NonUnit, DiagType.Unit]

    math_inst = MathInstruction(
        [8, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex,
    )
    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    # (shape, stages, warps) rows, passed positionally to TileDescription.
    tile_table = (
        ((128, 64, 8), 3, (4, 2, 1)),
        ((64, 128, 8), 3, (2, 4, 1)),
        ((64, 64, 8), 3, (2, 2, 1)),
        ((64, 32, 8), 4, (2, 2, 1)),
        ((32, 64, 8), 4, (2, 2, 1)),
    )
    tile_descriptions = [
        TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_table
    ]

    data_type = [DataType.cf64] * 4
    # Single (not paired) transforms for TRMM.
    complex_transforms = [ComplexTransform.none, ComplexTransform.conj]

    CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types,
                       tile_descriptions, data_type, alignment_constraints,
                       complex_transforms)
3560
+ #
3561
+
3562
+
3563
+ #
3564
def GenerateSM80_TensorOp_884_trmm_complex_gaussian(manifest, cuda_version):
    """Create complex-f64 TRMM operators using Gaussian complex multiply-add.

    Built on the [8, 8, 4] f64 tensor-op instruction with
    ``MathOperation.multiply_add_complex_gaussian``. Returns without creating
    anything unless ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)``
    holds. The lists below are forwarded to ``CreateTrmmOperator``.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    ]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]
    diag_types = [DiagType.NonUnit, DiagType.Unit]

    math_inst = MathInstruction(
        [8, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex_gaussian,
    )

    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count).
    tile_configs = [
        ([64, 64, 8], 3, [4, 2, 1]),
        ([64, 32, 8], 4, [2, 2, 1]),
        ([32, 64, 8], 4, [2, 2, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.cf64] * 4
    complex_transforms = [ComplexTransform.none, ComplexTransform.conj]

    CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types,
                       tile_descriptions, data_type, alignment_constraints,
                       complex_transforms)
3612
+ #
3613
+
3614
+ #
3615
def GenerateSM80_TensorOp_884_symm(manifest, cuda_version):
    """Create real-f64 symmetric (SYMM) operators from the [8, 8, 4] f64 MMA.

    Returns without creating anything unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)`` holds. Otherwise the
    configuration below is passed to ``CreateSymmOperator`` with
    ``BlasMode.symmetric``.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    layouts = [(LayoutType.ColumnMajor, LayoutType.ColumnMajor)]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_inst = MathInstruction(
        [8, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add,
    )

    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count).
    tile_configs = [
        ([128, 128, 16], 3, [4, 2, 1]),
        ([64, 128, 16], 3, [2, 2, 1]),
        ([128, 64, 16], 3, [2, 2, 1]),
        ([64, 64, 16], 4, [2, 2, 1]),
        ([64, 32, 16], 4, [2, 2, 1]),
        ([32, 64, 16], 4, [2, 2, 1]),
        ([32, 32, 16], 5, [2, 2, 1]),
        ([16, 32, 16], 5, [1, 2, 1]),
        ([32, 16, 16], 5, [2, 1, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.f64] * 4

    CreateSymmOperator(manifest, layouts, side_modes, fill_modes,
                       tile_descriptions, data_type, alignment_constraints,
                       BlasMode.symmetric)
3660
+ #
3661
+
3662
+ #
3663
def GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version):
    """Create complex-f64 SYMM and HEMM operators from the [8, 8, 4] f64 MMA.

    Uses ``MathOperation.multiply_add_complex``. Returns without creating
    anything unless ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)``
    holds. ``CreateSymmOperator`` is called twice with the same
    configuration: first with ``BlasMode.symmetric``, then with
    ``BlasMode.hermitian``.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    layouts = [(LayoutType.ColumnMajor, LayoutType.ColumnMajor)]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_inst = MathInstruction(
        [8, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex,
    )

    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count). The smaller k=8 tiles (64x32, 32x64, 32x32, 16x32,
    # 32x16) are disabled for this operator family.
    tile_configs = [
        ([128, 64, 8], 3, [4, 2, 1]),
        ([64, 128, 8], 3, [2, 4, 1]),
        ([64, 64, 8], 3, [2, 2, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.cf64] * 4

    # SYMM first, then HEMM.
    for blas_mode in (BlasMode.symmetric, BlasMode.hermitian):
        CreateSymmOperator(manifest, layouts, side_modes, fill_modes,
                           tile_descriptions, data_type, alignment_constraints,
                           blas_mode)
3712
+ #
3713
+
3714
+ #
3715
def GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version):
    """Create complex-f64 SYMM and HEMM operators using Gaussian complex MMA.

    Built on the [8, 8, 4] f64 tensor-op instruction with
    ``MathOperation.multiply_add_complex_gaussian``. Returns without creating
    anything unless ``CudaToolkitVersionSatisfies(cuda_version, 11, 0)``
    holds. ``CreateSymmOperator`` is called with ``BlasMode.symmetric`` and
    then with ``BlasMode.hermitian``.

    Neither call takes a complex-transform list, so none is built here. The
    previous ``complex_transforms`` local was assigned but never used.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    layouts = [(LayoutType.ColumnMajor, LayoutType.ColumnMajor)]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_inst = MathInstruction(
        [8, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex_gaussian,
    )

    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count). The 32x32, 16x32 and 32x16 k=8 tiles are disabled.
    tile_configs = [
        ([64, 64, 8], 3, [4, 2, 1]),
        ([64, 32, 8], 4, [2, 2, 1]),
        ([32, 64, 8], 4, [2, 2, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.cf64] * 4

    # SYMM first, then HEMM.
    for blas_mode in (BlasMode.symmetric, BlasMode.hermitian):
        CreateSymmOperator(manifest, layouts, side_modes, fill_modes,
                           tile_descriptions, data_type, alignment_constraints,
                           blas_mode)
3764
+ #
3765
+
3766
+ ###################################################################################################
3767
+
3768
+ #
3769
def GenerateSM80_Simt_f32(manifest, cuda_version):
    """Create SIMT f32 GEMM and NHWC Conv2d operators tagged for SM80.

    For each SIMT math instruction below, the tile list is passed to
    ``CreateGemmOperator`` (four A/B layout combinations) and then to
    ``CreateConv2dOperator`` with an all-NHWC layout. ``cuda_version`` is not
    consulted by this generator.
    """
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
    ]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32, DataType.f32, DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        ),
    ]

    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    for math_inst in math_instructions:
        # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
        # stages, warp_count).
        tile_configs = [
            ([256, 128, 8], 5, [4, 2, 1]),
            ([128, 256, 8], 5, [2, 4, 1]),
            ([128, 128, 8], 5, [4, 2, 1]),
            ([256, 128, 8], 4, [4, 2, 1]),
            ([128, 256, 8], 4, [2, 4, 1]),
            ([128, 128, 8], 4, [4, 2, 1]),
            ([128, 64, 8], 5, [2, 2, 1]),
            ([64, 128, 8], 5, [2, 2, 1]),
            ([64, 64, 8], 5, [2, 1, 1]),
            ([128, 32, 8], 5, [2, 1, 1]),
            ([32, 128, 8], 5, [1, 2, 1]),
        ]
        tile_descriptions = [
            TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
            for shape, stages, warps in tile_configs
        ]

        # A and B element types, followed by the accumulator type twice.
        data_type = [math_inst.element_a, math_inst.element_b] \
            + [math_inst.element_accumulator] * 2

        CreateGemmOperator(manifest, layouts, tile_descriptions,
                           data_type, alignment_constraints)

        conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
        CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                             data_type, alignment_constraints)
3817
+ #
3818
+
3819
+
3820
+ #
3821
def GenerateSM80_Simt_f64(manifest, cuda_version):
    """Create SIMT f64 GEMM operators tagged for SM80.

    For each SIMT math instruction below, the tile list is passed to
    ``CreateGemmOperator`` together with four A/B layout combinations.
    ``cuda_version`` is not consulted by this generator.
    """
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
    ]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f64, DataType.f64, DataType.f64,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        ),
    ]

    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    for math_inst in math_instructions:
        # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
        # stages, warp_count).
        tile_configs = [
            ([128, 128, 8], 3, [4, 2, 1]),
            ([128, 64, 8], 4, [2, 2, 1]),
            ([64, 128, 8], 4, [2, 2, 1]),
            ([64, 64, 8], 5, [2, 1, 1]),
            ([128, 32, 8], 5, [2, 1, 1]),
            ([32, 128, 8], 5, [1, 2, 1]),
        ]
        tile_descriptions = [
            TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
            for shape, stages, warps in tile_configs
        ]

        # A and B element types, followed by the accumulator type twice.
        data_type = [math_inst.element_a, math_inst.element_b] \
            + [math_inst.element_accumulator] * 2

        CreateGemmOperator(manifest, layouts, tile_descriptions,
                           data_type, alignment_constraints)
3861
+ #
3862
+
3863
+
3864
+ ##################################################################################################
3865
+ #
3866
def GenerateSM80_Simt_complex(manifest, cuda_version):
    """Create SIMT complex-f32 GEMM and NHWC Conv2d operators tagged for SM80.

    The operands are all ``DataType.cf32``. The GEMMs are created for every
    (A, B) conjugation pair listed below. ``cuda_version`` is not consulted
    by this generator.
    """
    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32, DataType.f32, DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add_complex,
        ),
    ]

    min_cc, max_cc = 80, 1024
    alignment_constraints = [1]

    data_type = [DataType.cf32] * 4

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
    ]

    # Yields, in order: (none, none), (conj, none), (none, conj), (conj, conj).
    transform_choices = (ComplexTransform.none, ComplexTransform.conj)
    complex_transforms = [
        (transform_a, transform_b)
        for transform_b in transform_choices
        for transform_a in transform_choices
    ]

    for math_inst in math_instructions:
        # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
        # stages, warp_count).
        tile_configs = [
            ([128, 128, 8], 5, [4, 2, 1]),
            ([128, 128, 8], 4, [4, 2, 1]),
            ([64, 64, 8], 3, [4, 2, 1]),
            ([64, 128, 16], 6, [2, 2, 1]),
            ([128, 64, 16], 6, [2, 2, 1]),
            ([64, 32, 16], 4, [2, 2, 1]),
            ([32, 64, 16], 4, [2, 2, 1]),
            ([32, 32, 16], 5, [2, 2, 1]),
        ]
        tile_descriptions = [
            TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
            for shape, stages, warps in tile_configs
        ]

        CreateGemmOperator(manifest, layouts, tile_descriptions, data_type,
                           alignment_constraints, complex_transforms)

        conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
        CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
                             data_type, alignment_constraints)
3918
+ #
3919
+
3920
+ ###################################################################################################
3921
+
3922
+ #
3923
def GenerateSM80(manifest, cuda_version):
    """Run every SM80 operator generator against ``manifest``.

    Each entry is invoked as ``generator(manifest, cuda_version)``, strictly in
    the order listed below.
    """
    generators = (
        GenerateSM80_TensorOp_16816,
        GenerateSM80_SparseTensorOp_16832,
        GenerateSM80_PlanarComplexTensorOp_16816,
        GenerateSM80_TensorOp_1688,
        GenerateSM80_TensorOp_1688_fast_math,
        GenerateSM80_SparseTensorOp_16816_fast_math,
        GenerateSM80_TensorOp_1688_complex,
        # 3xTF32
        GenerateSM80_TensorOp_1688_fast_fp32_math,
        GenerateSM80_TensorOp_1688_fast_fp32_math_complex,
        GenerateSM80_TensorOp_1688_rank_k,
        GenerateSM80_TensorOp_1688_rank_k_complex,
        GenerateSM80_TensorOp_1688_trmm,
        GenerateSM80_TensorOp_1688_trmm_complex,
        GenerateSM80_TensorOp_1688_symm,
        GenerateSM80_TensorOp_1688_symm_complex,
        GenerateSM80_TensorOp_884,
        GenerateSM80_TensorOp_884_complex,
        GenerateSM80_TensorOp_884_complex_gaussian,
        GenerateSM80_TensorOp_884_rank_k,
        GenerateSM80_TensorOp_884_rank_k_complex,
        GenerateSM80_TensorOp_884_rank_k_complex_gaussian,
        GenerateSM80_TensorOp_884_trmm,
        GenerateSM80_TensorOp_884_trmm_complex,
        GenerateSM80_TensorOp_884_trmm_complex_gaussian,
        GenerateSM80_TensorOp_884_symm,
        GenerateSM80_TensorOp_884_symm_complex,
        GenerateSM80_TensorOp_884_symm_complex_gaussian,
        GenerateSM80_TensorOp_16832_TN,
        GenerateSM80_SparseTensorOp_16864_TN,
        GenerateSM80_TensorOp_16832_Interleaved,
        GenerateSM80_TensorOp_16864_TN,
        GenerateSM80_SparseTensorOp_168128_TN,
        GenerateSM80_TensorOp_16864_Interleaved,
        GenerateSM80_TensorOp_168256,
        GenerateSM80_Simt_f32,
        GenerateSM80_Simt_f64,
        GenerateSM80_Simt_complex,
    )
    for generator in generators:
        generator(manifest, cuda_version)
3962
+
3963
+ ###################################################################################################
3964
+
3965
+ #
3966
def GenerateSM90_TensorOp_1684(manifest, cuda_version):
    """Create real-f64 GEMM operators from the [16, 8, 4] f64 tensor-op MMA.

    Tiles are tagged with min_cc=90. Returns without creating anything unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 8)`` holds. Otherwise the
    tile list is passed to ``CreateGemmOperator`` with four A/B layout
    combinations.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
        return

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
    ]

    math_inst = MathInstruction(
        [16, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add,
    )

    min_cc, max_cc = 90, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count).
    tile_configs = [
        ([128, 128, 16], 3, [4, 2, 1]),
        ([256, 64, 16], 3, [4, 2, 1]),
        ([64, 256, 16], 3, [2, 4, 1]),
        ([256, 32, 16], 3, [4, 1, 1]),
        ([32, 256, 16], 3, [1, 4, 1]),
        ([128, 64, 16], 3, [2, 2, 1]),
        ([64, 128, 16], 3, [2, 2, 1]),
        ([64, 64, 16], 4, [2, 2, 1]),
        ([64, 32, 16], 4, [2, 2, 1]),
        ([32, 64, 16], 4, [2, 2, 1]),
        ([32, 32, 16], 5, [2, 2, 1]),
        ([16, 32, 16], 5, [1, 2, 1]),
        ([32, 16, 16], 5, [2, 1, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.f64] * 4

    CreateGemmOperator(manifest, layouts, tile_descriptions,
                       data_type, alignment_constraints)
4010
+
4011
+ #
4012
+
4013
+ #
4014
def GenerateSM90_TensorOp_1684_complex(manifest, cuda_version):
    """Create complex-f64 GEMM operators from the [16, 8, 4] f64 tensor-op MMA.

    Uses ``MathOperation.multiply_add_complex``, with tiles tagged min_cc=90.
    Returns without creating anything unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 8)`` holds. The GEMMs are
    created for every (A, B) conjugation pair listed below.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
        return

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
    ]

    math_inst = MathInstruction(
        [16, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex,
    )

    min_cc, max_cc = 90, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count). The k=8 tiles come first, then the k=16 tiles.
    tile_configs = [
        ([128, 64, 8], 3, [4, 2, 1]),
        ([64, 128, 8], 3, [2, 4, 1]),
        ([64, 64, 8], 3, [2, 2, 1]),
        ([64, 32, 8], 4, [2, 2, 1]),
        ([32, 64, 8], 4, [2, 2, 1]),
        ([32, 32, 8], 4, [2, 2, 1]),
        ([16, 32, 8], 4, [1, 2, 1]),
        ([32, 16, 8], 4, [2, 1, 1]),
        ([128, 64, 16], 3, [4, 2, 1]),
        ([64, 128, 16], 3, [2, 4, 1]),
        ([64, 64, 16], 3, [2, 2, 1]),
        ([64, 32, 16], 3, [2, 2, 1]),
        ([32, 64, 16], 3, [2, 2, 1]),
        ([32, 32, 16], 4, [2, 2, 1]),
        ([16, 32, 16], 4, [1, 2, 1]),
        ([32, 16, 16], 3, [2, 1, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.cf64] * 4

    # Yields, in order: (none, none), (conj, none), (none, conj), (conj, conj).
    transform_choices = (ComplexTransform.none, ComplexTransform.conj)
    complex_transforms = [
        (transform_a, transform_b)
        for transform_b in transform_choices
        for transform_a in transform_choices
    ]

    CreateGemmOperator(manifest, layouts, tile_descriptions,
                       data_type, alignment_constraints, complex_transforms)
4068
+ #
4069
+
4070
+ #
4071
def GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version):
    """Create complex-f64 GEMM operators using Gaussian complex multiply-add.

    Built on the [16, 8, 4] f64 tensor-op instruction with
    ``MathOperation.multiply_add_complex_gaussian``, with tiles tagged
    min_cc=90. Returns without creating anything unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 8)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
        return

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
    ]

    math_inst = MathInstruction(
        [16, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex_gaussian,
    )

    min_cc, max_cc = 90, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count).
    tile_configs = [
        ([64, 64, 8], 3, [4, 2, 1]),
        ([64, 32, 8], 4, [2, 2, 1]),
        ([32, 64, 8], 4, [2, 2, 1]),
        ([32, 32, 8], 4, [2, 2, 1]),
        ([16, 32, 8], 4, [1, 2, 1]),
        ([32, 16, 8], 4, [2, 1, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.cf64] * 4

    # Yields, in order: (none, none), (conj, none), (none, conj), (conj, conj).
    transform_choices = (ComplexTransform.none, ComplexTransform.conj)
    complex_transforms = [
        (transform_a, transform_b)
        for transform_b in transform_choices
        for transform_a in transform_choices
    ]

    CreateGemmOperator(manifest, layouts, tile_descriptions,
                       data_type, alignment_constraints, complex_transforms)
4115
+ #
4116
+
4117
+ #
4118
def GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version):
    """Create real-f64 rank-k operators from the [16, 8, 4] f64 tensor-op MMA.

    Tiles are tagged min_cc=90. Returns without creating anything unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 8)`` holds. Otherwise the
    configuration is passed to ``CreateRankKOperator`` with
    ``BlasMode.symmetric``.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
        return

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor),
    ]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_inst = MathInstruction(
        [16, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add,
    )

    min_cc, max_cc = 90, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count).
    tile_configs = [
        ([128, 128, 16], 3, [4, 2, 1]),
        ([64, 128, 16], 3, [2, 2, 1]),
        ([128, 64, 16], 3, [2, 2, 1]),
        ([64, 64, 16], 4, [2, 2, 1]),
        ([64, 32, 16], 4, [2, 2, 1]),
        ([32, 64, 16], 4, [2, 2, 1]),
        ([32, 32, 16], 5, [2, 2, 1]),
        ([16, 32, 16], 5, [1, 2, 1]),
        ([32, 16, 16], 5, [2, 1, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    # Rank-k operators take a three-entry data-type list here.
    data_type = [DataType.f64] * 3

    CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
                        data_type, alignment_constraints, BlasMode.symmetric)
4160
+ #
4161
+
4162
+ #
4163
def GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version):
    """Create complex-f64 SYRK and HERK operators from the [16, 8, 4] f64 MMA.

    Uses ``MathOperation.multiply_add_complex``, with tiles tagged min_cc=90.
    Returns without creating anything unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 8)`` holds.
    ``CreateRankKOperator`` is called with ``BlasMode.symmetric`` and then
    with ``BlasMode.hermitian``.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
        return

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor),
    ]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_inst = MathInstruction(
        [16, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex,
    )

    min_cc, max_cc = 90, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count). The smaller k=8 tiles (64x32, 32x64, 32x32, 16x32,
    # 32x16) are disabled for this operator family.
    tile_configs = [
        ([128, 64, 8], 3, [4, 2, 1]),
        ([64, 128, 8], 3, [2, 4, 1]),
        ([64, 64, 8], 3, [2, 2, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.cf64] * 3

    # SYRK first, then HERK.
    for blas_mode in (BlasMode.symmetric, BlasMode.hermitian):
        CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
                            data_type, alignment_constraints, blas_mode)
4209
+
4210
+ #
4211
+
4212
+ #
4213
def GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version):
    """Create complex-f64 SYRK and HERK operators using Gaussian complex MMA.

    Built on the [16, 8, 4] f64 tensor-op instruction with
    ``MathOperation.multiply_add_complex_gaussian``, with tiles tagged
    min_cc=90. Returns without creating anything unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 8)`` holds.
    ``CreateRankKOperator`` is called with ``BlasMode.symmetric`` and then
    with ``BlasMode.hermitian``.

    Neither call takes a complex-transform list, so none is built here. The
    previous ``complex_transforms`` local was assigned but never used.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
        return

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor),
    ]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_inst = MathInstruction(
        [16, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex_gaussian,
    )

    min_cc, max_cc = 90, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count). The 32x32, 16x32 and 32x16 k=8 tiles are disabled.
    tile_configs = [
        ([64, 64, 8], 3, [4, 2, 1]),
        ([64, 32, 8], 4, [2, 2, 1]),
        ([32, 64, 8], 4, [2, 2, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.cf64] * 3

    # SYRK first, then HERK.
    for blas_mode in (BlasMode.symmetric, BlasMode.hermitian):
        CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
                            data_type, alignment_constraints, blas_mode)
4259
+ #
4260
+
4261
+ #
4262
def GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version):
    """Create real-f64 TRMM operators from the [16, 8, 4] f64 tensor-op MMA.

    Tiles are tagged min_cc=90. Returns without creating anything unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 8)`` holds. Otherwise the
    layout, side-mode, fill-mode and diagonal-type lists, plus the tile list,
    are passed to ``CreateTrmmOperator``.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
        return

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    ]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]
    diag_types = [DiagType.NonUnit, DiagType.Unit]

    math_inst = MathInstruction(
        [16, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add,
    )

    min_cc, max_cc = 90, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count).
    tile_configs = [
        ([128, 128, 16], 3, [4, 2, 1]),
        ([64, 128, 16], 3, [2, 2, 1]),
        ([128, 64, 16], 3, [2, 2, 1]),
        ([64, 64, 16], 4, [2, 2, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.f64] * 4

    CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types,
                       tile_descriptions, data_type, alignment_constraints)
4307
+ #
4308
+
4309
+ #
4310
def GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version):
    """Create complex-f64 TRMM operators from the [16, 8, 4] f64 tensor-op MMA.

    Uses ``MathOperation.multiply_add_complex``, with tiles tagged min_cc=90.
    Returns without creating anything unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 8)`` holds. The lists
    below, including the complex transforms, are forwarded to
    ``CreateTrmmOperator``.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
        return

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    ]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]
    diag_types = [DiagType.NonUnit, DiagType.Unit]

    math_inst = MathInstruction(
        [16, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex,
    )

    min_cc, max_cc = 90, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count).
    tile_configs = [
        ([128, 64, 8], 3, [4, 2, 1]),
        ([64, 128, 8], 3, [2, 4, 1]),
        ([64, 64, 8], 3, [2, 2, 1]),
        ([64, 32, 8], 4, [2, 2, 1]),
        ([32, 64, 8], 4, [2, 2, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.cf64] * 4
    complex_transforms = [ComplexTransform.none, ComplexTransform.conj]

    CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types,
                       tile_descriptions, data_type, alignment_constraints,
                       complex_transforms)
4360
+ #
4361
+
4362
+
4363
+ #
4364
def GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version):
    """Create complex-f64 TRMM operators using Gaussian complex multiply-add.

    Built on the [16, 8, 4] f64 tensor-op instruction with
    ``MathOperation.multiply_add_complex_gaussian``, with tiles tagged
    min_cc=90. Returns without creating anything unless
    ``CudaToolkitVersionSatisfies(cuda_version, 11, 8)`` holds.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
        return

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    ]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]
    diag_types = [DiagType.NonUnit, DiagType.Unit]

    math_inst = MathInstruction(
        [16, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex_gaussian,
    )

    min_cc, max_cc = 90, 1024
    alignment_constraints = [1]

    # Positional arguments for TileDescription (CUTLASS: threadblock_shape,
    # stages, warp_count).
    tile_configs = [
        ([64, 64, 8], 3, [4, 2, 1]),
        ([64, 32, 8], 4, [2, 2, 1]),
        ([32, 64, 8], 4, [2, 2, 1]),
    ]
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.cf64] * 4
    complex_transforms = [ComplexTransform.none, ComplexTransform.conj]

    CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types,
                       tile_descriptions, data_type, alignment_constraints,
                       complex_transforms)
4412
+ #
4413
+
4414
+ #
4415
def GenerateSM90_TensorOp_1684_symm(manifest, cuda_version):
    """Register SM90 real f64 symmetric (BlasMode.symmetric) kernels built on
    the 16x8x4 FP64 tensor-op instruction with plain ``multiply_add``.

    Nothing is registered when the CUDA toolkit is older than 11.8.

    Args:
        manifest: kernel manifest handed to CreateSymmOperator.
        cuda_version: toolkit version string checked by CudaToolkitVersionSatisfies.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
        return

    f64 = DataType.f64
    math_inst = MathInstruction(
        [16, 8, 4],
        f64, f64, f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add,
    )
    min_cc, max_cc = 90, 1024

    # Leading positional TileDescription arguments: (tile shape, stages, warp
    # layout) -- naming assumed from CUTLASS conventions.
    tile_configs = (
        ([128, 128, 16], 3, [4, 2, 1]),
        ([64, 128, 16], 3, [2, 2, 1]),
        ([128, 64, 16], 3, [2, 2, 1]),
        ([64, 64, 16], 4, [2, 2, 1]),
        ([64, 32, 16], 4, [2, 2, 1]),
        ([32, 64, 16], 4, [2, 2, 1]),
        ([32, 32, 16], 5, [2, 2, 1]),
        ([16, 32, 16], 5, [1, 2, 1]),
        ([32, 16, 16], 5, [2, 1, 1]),
    )
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    CreateSymmOperator(
        manifest,
        [(LayoutType.ColumnMajor, LayoutType.ColumnMajor)],
        [SideMode.Left, SideMode.Right],
        [FillMode.Lower, FillMode.Upper],
        tile_descriptions,
        [f64] * 4,      # element types
        [1],            # alignment constraints
        BlasMode.symmetric,
    )
4460
+ #
4461
+
4462
+ #
4463
def GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version):
    """Register SM90 cf64 SYMM and HEMM kernels built on the 16x8x4 FP64
    tensor-op instruction with the ``multiply_add_complex`` math operation.

    The same configuration is registered twice: first with BlasMode.symmetric
    (SYMM), then with BlasMode.hermitian (HEMM). Nothing is registered when
    the CUDA toolkit is older than 11.8.

    Args:
        manifest: kernel manifest handed to CreateSymmOperator.
        cuda_version: toolkit version string checked by CudaToolkitVersionSatisfies.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
        return

    layouts = [(LayoutType.ColumnMajor, LayoutType.ColumnMajor)]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    f64 = DataType.f64
    math_inst = MathInstruction(
        [16, 8, 4],
        f64, f64, f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex,
    )
    min_cc, max_cc = 90, 1024

    # Leading positional TileDescription arguments: (tile shape, stages, warp
    # layout) -- naming assumed from CUTLASS conventions.
    # Smaller tiles (64x32x8, 32x64x8, 32x32x8, 16x32x8, 32x16x8) are disabled.
    tile_configs = (
        ([128, 64, 8], 3, [4, 2, 1]),
        ([64, 128, 8], 3, [2, 4, 1]),
        ([64, 64, 8], 3, [2, 2, 1]),
    )
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.cf64] * 4
    alignment_constraints = [1]

    # SYMM first, then HEMM, with identical argument objects.
    for blas_mode in (BlasMode.symmetric, BlasMode.hermitian):
        CreateSymmOperator(
            manifest, layouts, side_modes, fill_modes, tile_descriptions,
            data_type, alignment_constraints, blas_mode,
        )
4512
+ #
4513
+
4514
+ #
4515
def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version):
    """Register SM90 cf64 SYMM and HEMM kernels built on the 16x8x4 FP64
    tensor-op instruction with the ``multiply_add_complex_gaussian`` math
    operation.

    The same configuration is registered twice: first with BlasMode.symmetric
    (SYMM), then with BlasMode.hermitian (HEMM). Nothing is registered when
    the CUDA toolkit is older than 11.8.

    Args:
        manifest: kernel manifest handed to CreateSymmOperator.
        cuda_version: toolkit version string checked by CudaToolkitVersionSatisfies.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
        return

    layouts = [(LayoutType.ColumnMajor, LayoutType.ColumnMajor)]
    side_modes = [SideMode.Left, SideMode.Right]
    fill_modes = [FillMode.Lower, FillMode.Upper]

    math_inst = MathInstruction(
        [16, 8, 4],
        DataType.f64, DataType.f64, DataType.f64,
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_complex_gaussian,
    )
    min_cc, max_cc = 90, 1024

    # Leading positional TileDescription arguments: (tile shape, stages, warp
    # layout) -- naming assumed from CUTLASS conventions.
    # Smaller tiles (32x32x8, 16x32x8, 32x16x8) are disabled.
    tile_configs = (
        ([64, 64, 8], 3, [4, 2, 1]),
        ([64, 32, 8], 4, [2, 2, 1]),
        ([32, 64, 8], 4, [2, 2, 1]),
    )
    tile_descriptions = [
        TileDescription(shape, stages, warps, math_inst, min_cc, max_cc)
        for shape, stages, warps in tile_configs
    ]

    data_type = [DataType.cf64] * 4
    alignment_constraints = [1]

    # CreateSymmOperator takes no complex-transform list, so the unused
    # `complex_transforms` local the original built has been dropped.
    for blas_mode in (BlasMode.symmetric, BlasMode.hermitian):
        CreateSymmOperator(
            manifest, layouts, side_modes, fill_modes, tile_descriptions,
            data_type, alignment_constraints, blas_mode,
        )
4564
+ #
4565
+
4566
+ ###################################################################################################
4567
+
4568
+ #
4569
def GenerateSM90(manifest, cuda_version):
    """Run every SM90 16x8x4 tensor-op kernel generator against *manifest*.

    Generators are invoked in a fixed order, each with the same
    ``(manifest, cuda_version)`` arguments.
    """
    generators = (
        GenerateSM90_TensorOp_1684,
        GenerateSM90_TensorOp_1684_complex,
        GenerateSM90_TensorOp_1684_complex_gaussian,
        # rank_k
        GenerateSM90_TensorOp_1684_rank_k,
        GenerateSM90_TensorOp_1684_rank_k_complex,
        GenerateSM90_TensorOp_1684_rank_k_complex_gaussian,
        # trmm
        GenerateSM90_TensorOp_1684_trmm,
        GenerateSM90_TensorOp_1684_trmm_complex,
        GenerateSM90_TensorOp_1684_trmm_complex_gaussian,
        # symm
        GenerateSM90_TensorOp_1684_symm,
        GenerateSM90_TensorOp_1684_symm_complex,
        GenerateSM90_TensorOp_1684_symm_complex_gaussian,
    )
    for generate in generators:
        generate(manifest, cuda_version)
4584
+
4585
+ ###################################################################################################
4586
+
4587
if __name__ == "__main__":

    # Command-line interface for the kernel registration generator.
    parser = argparse.ArgumentParser(description="Generates device kernel registration code for CUTLASS Kernels")
    parser.add_argument("--operations", default="all", help="Specifies the operation to generate (gemm, all)")
    parser.add_argument("--build-dir", default=".", required=False, help="CUTLASS top-level build directory")
    parser.add_argument("--curr-build-dir", default=".", help="CUTLASS current build directory. cmake files will be emitted in this directory")
    parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.")
    # NOTE(review): this default omits 90 even though GenerateSM90 runs below;
    # if Manifest filters by --architectures/--filter-by-cc, SM90 kernels are
    # presumably dropped unless 90 is passed explicitly -- confirm.
    parser.add_argument("--architectures", default='53;60;61;70;75;80', help="Target compute architectures")
    parser.add_argument("--kernels", default='', help='Comma delimited list to filter kernels by name.')
    parser.add_argument("--ignore-kernels", default='', help='Comma delimited list of kernels to exclude from build.')
    parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.')
    parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit")
    parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file')
    parser.add_argument('--selected-kernel-list', type=str, default=None, required=False,
                        help='Specify the output log file containing all enabled kernels in this build')
    parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels")

    args = parser.parse_args()

    manifest = Manifest(args)

    # Populate the manifest, one generator per architecture, in ascending order.
    GenerateSM50(manifest, args.cuda_version)
    GenerateSM60(manifest, args.cuda_version)
    GenerateSM61(manifest, args.cuda_version)
    GenerateSM70(manifest, args.cuda_version)
    GenerateSM75(manifest, args.cuda_version)
    GenerateSM80(manifest, args.cuda_version)
    GenerateSM90(manifest, args.cuda_version)

    # --generator-target is a comma-separated list; only 'library' is handled here.
    if 'library' in args.generator_target.split(','):
        manifest.emit(GeneratorTarget.Library)

    # Optionally log the selected kernel names, one per line. No file is
    # written when nothing was selected.
    if args.selected_kernel_list is not None and manifest.selected_kernels:
        with open(args.selected_kernel_list, 'w', encoding='utf-8') as file_writer:
            file_writer.writelines(f"{name}\n" for name in manifest.selected_kernels)
4624
+ #
4625
+ ###################################################################################################