warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py (new file, +631 lines)
@@ -0,0 +1,631 @@
+ ################################################################################
+ #
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved
+ # SPDX-License-Identifier: BSD-3-Clause
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+ #
+ # 1. Redistributions of source code must retain the above copyright notice, this
+ #    list of conditions and the following disclaimer.
+ #
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
+ #    this list of conditions and the following disclaimer in the documentation
+ #    and/or other materials provided with the distribution.
+ #
+ # 3. Neither the name of the copyright holder nor the names of its
+ #    contributors may be used to endorse or promote products derived from
+ #    this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ #
+ ################################################################################
+ from typeguard import typechecked
+ from cuda import cuda
+ from typing import Union
+ import numpy as np
+
+ from pycutlass import *
+
+
+ # @typechecked
+ class Conv2dArguments(ArgumentBase):
+     """
+     Argument wrapper for Conv2d. It encodes problem information and
+     user-provided tensors into the kernel's arguments.
+
+     :param operation: the Conv2d operation that will consume these arguments
+     :type operation: :class:`pycutlass.Conv2dOperation`
+
+     :param problem_size: the Conv2d problem size
+     :type problem_size: :class:`cutlass.conv.Conv2dProblemSize`
+
+     :param A: tensor A
+     :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
+
+     :param B: tensor B
+     :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
+
+     :param C: tensor C
+     :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
+
+     :param D: tensor D
+     :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
+
+     :param split_k_mode: conv2d split K mode, defaults to
+         cutlass.conv.SplitKMode.Serial
+     :type split_k_mode: cutlass.conv.SplitKMode, optional
+
+     :param output_op: output operator, optional
+     :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments`
+     """
+
+     def __init__(self, operation: 'Conv2dOperation',
+                  problem_size: 'cutlass.conv.Conv2dProblemSize',
+                  A: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]',
+                  B: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]',
+                  C: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]',
+                  D: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]',
+                  split_k_mode: 'cutlass.conv.SplitKMode'
+                  = cutlass.conv.SplitKMode.Serial, **kwargs) -> None:
+
+         self.operation = operation
+         #: convolution kind
+         self.conv_kind: cutlass.conv.Operator = operation.conv_kind
+         self.layout_A: cutlass.layout = operation.A.layout
+         self.layout_B: cutlass.layout = operation.B.layout
+         self.layout_C: cutlass.layout = operation.C.layout
+
+         self.element_A = operation.A.element
+         self.element_B = operation.B.element
+         self.element_C = operation.C.element
+
+         if self.layout_C == cutlass.TensorNC32HW32:
+             B = self.reorder_tensor_B(B, problem_size)
+
+         super().__init__(A, B, C, D, **kwargs)
+         # preprocessing output ops
+
+         if 'output_op' in kwargs.keys() and \
+                 split_k_mode != cutlass.conv.SplitKMode.Parallel:
+             self.output_op = kwargs['output_op']
+         else:
+             self.output_op = self.operation.epilogue_type(1.0, 0.0)
+
+         if "split_k_slices" in kwargs.keys():
+             self.split_k_mode = split_k_mode
+             self.split_k_slices = kwargs["split_k_slices"]
+         else:
+             self.split_k_mode = cutlass.conv.SplitKMode.Serial
+             self.split_k_slices = 1
+
+         #: problem_size
+         self.problem_size: cutlass.conv.Conv2dProblemSize = problem_size
+         self.problem_size.split_k_slices = self.split_k_slices
+
+         if hasattr(self, "tensor_c_numel"):
+             c_coord = cutlass.conv.implicit_gemm_tensor_c_extent(
+                 self.conv_kind, problem_size)
+             if (self.tensor_c_numel == c_coord.at(3) and
+                     self.tensor_c_numel < c_coord.size()):
+                 self.bias = True
+
+         #
+         # initialize the argument
+         #
+         self.initialize()
+
+     # @typechecked
+     def reorder_tensor_B(self, tensor_B: 'np.ndarray',
+                          problem_size: 'cutlass.conv.Conv2dProblemSize'):
+         """
+         Reorder tensor_B for interleaved layout
+
+         :param tensor_B: input tensor B
+         :type tensor_B: numpy.ndarray
+         :param problem_size: Conv2d problem size
+         :type problem_size: :class:`cutlass.conv.Conv2dProblemSize`
+
+         :return: reordered tensor B
+         :rtype: numpy.ndarray
+         """
+         reordered_tensor_B = np.empty_like(tensor_B)
+         tensor_ref_B = self.get_tensor_ref(
+             tensor_B, self.element_B, self.layout_B, problem_size, "b")
+         reordered_tensor_ref_B = self.get_tensor_ref(
+             reordered_tensor_B, self.element_B,
+             self.layout_B, problem_size, "b")
+         cutlass.conv.host.reorder_convK(
+             reordered_tensor_ref_B, tensor_ref_B, self.conv_kind, problem_size)
+
+         return reordered_tensor_B
+
+     def get_tensor_ref(
+             self, tensor, dtype, tensor_layout, problem_size, operand):
+         if operand == "a":
+             tensor_coord = cutlass.conv.implicit_gemm_tensor_a_extent(
+                 self.conv_kind, problem_size)
+         elif operand == "b":
+             tensor_coord = cutlass.conv.implicit_gemm_tensor_b_extent(
+                 self.conv_kind, problem_size)
+         elif operand in ["c", "d"]:
+             tensor_coord = cutlass.conv.implicit_gemm_tensor_c_extent(
+                 self.conv_kind, problem_size)
+         else:
+             raise ValueError("unknown operand: " + operand)
+         # Zero stride trick
+         if operand == "c" and self.bias:
+             tensor_coord = cutlass.Tensor4DCoord(0, 0, 0, 0)
+
+         layout = tensor_layout.packed(tensor_coord)
+
+         return TensorRef(tensor, dtype, layout).tensor_ref
+
+     def get_arguments(self, semaphore):
+         ref_A = TensorRef_(self.get_tensor_ref(
+             self.ptr_A, self.element_A, self.layout_A, self.problem_size, "a"))
+         ref_B = TensorRef_(self.get_tensor_ref(
+             self.ptr_B, self.element_B, self.layout_B, self.problem_size, "b"))
+         ref_C = TensorRef_(self.get_tensor_ref(
+             self.ptr_C, self.element_C, self.layout_C, self.problem_size, "c"))
+         ref_D = TensorRef_(self.get_tensor_ref(
+             self.ptr_D, self.element_C, self.layout_C, self.problem_size, "d"))
+
+         self.c_arguments = self.operation.argument_type(
+             Conv2DProblemSize(self.problem_size),
+             ref_A, ref_B, ref_C, ref_D, self.output_op, self.split_k_mode
+         )
+
+         self.semaphore = semaphore
+
+     def initialize(self):
+         """
+         Initialize the kernel arguments. This handles the following steps:
+         1. get kernel launch configuration including grid, cta size,
+            and dynamic shared memory capacity
+         2. allocate and initialize device workspace
+         3. get kernel params as bytearray for NVRTC input
+         """
+         # get launch configuration
+         self.launch_config = self.operation.rt_module.plan(self)
+
+         # allocate and initialize device workspace
+         device_workspace_size = \
+             self.operation.rt_module.get_device_workspace_size(self)
+
+         if device_workspace_size > 0:
+             self.workspace_buffer = device_mem_alloc(device_workspace_size)
+             workspace_ptr = self.workspace_buffer.ptr
+             err, = cuda.cuMemsetD32(
+                 workspace_ptr, 0, device_workspace_size // 4)
+         else:
+             workspace_ptr = None
+
+         # get kernel params as bytearray
+         semaphore = 0
+         if workspace_ptr is not None and \
+                 self.split_k_mode == cutlass.conv.SplitKMode.Parallel:
+             self.ptr_D = workspace_ptr
+         elif workspace_ptr is not None and \
+                 self.split_k_mode == cutlass.conv.SplitKMode.Serial:
+             semaphore = workspace_ptr
+
+         self.get_arguments(semaphore)
+
+         params_ = self.operation.rt_module.get_args(ctypes.byref(
+             self.c_arguments), ctypes.c_void_p(int(self.semaphore)))
+         self.host_workspace = bytearray(params_.contents)
+         self.device_workspace = None
+
+     def sync(self):
+         """
+         Synchronize the arguments. If the tensors were provided in host
+         memory, copy the results back from device to host.
+         """
+         return super().sync()
+
+
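The pycutlass conv2d test scripts shipped in this same release (see the test/conv files listed above) construct and run these arguments roughly as follows. This is a sketch, not part of the diff: the `operation` object is a compiled Conv2dOperation (defined later in this file), and the Conv2dProblemSize constructor signature is assumed from those test scripts.

    import numpy as np
    import cutlass

    # Fprop on an NHWC input (N=1, H=16, W=16, C=8) with a KRSC filter
    # (K=8, R=3, S=3, C=8), unit stride/dilation, one-pixel padding.
    problem_size = cutlass.conv.Conv2dProblemSize(
        cutlass.Tensor4DCoord(1, 16, 16, 8),   # input size
        cutlass.Tensor4DCoord(8, 3, 3, 8),     # filter size
        cutlass.Tensor4DCoord(1, 1, 1, 1),     # padding
        cutlass.MatrixCoord(1, 1),             # stride
        cutlass.MatrixCoord(1, 1),             # dilation
        cutlass.conv.Mode.cross_correlation,
        1,                                     # split_k_slices
        1,                                     # groups
    )

    # Host-side tensors; ArgumentBase copies host arrays to the device.
    tensor_A = np.random.uniform(-1, 1, size=(1 * 16 * 16 * 8,)).astype(np.float32)
    tensor_B = np.random.uniform(-1, 1, size=(8 * 3 * 3 * 8,)).astype(np.float32)
    tensor_C = np.zeros(1 * 16 * 16 * 8, dtype=np.float32)  # output is 1x16x16x8 here
    tensor_D = np.zeros_like(tensor_C)

    arguments = Conv2dArguments(
        operation=operation, problem_size=problem_size,
        A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
        output_op=operation.epilogue_type(1.0, 0.0),
        split_k_mode=cutlass.conv.SplitKMode.Serial, split_k_slices=1,
    )
    operation.run(arguments)
    arguments.sync()   # copy the result back into tensor_D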
+ # @typechecked
+ class Conv2dRT(ExecutableOperation):
+     """
+     Conv2dRT manages the CUTLASS runtime components
+     """
+     KernelTemplate = r'''
+ extern "C"
+ __global__ void
+ ${operation_name}(${operation_name}${operation_suffix}::Params params) {
+
+   // Dynamic shared memory base pointer
+   extern __shared__ int SharedStorageBase[];
+
+   // Declare pointer to dynamic shared memory.
+   ${operation_name}${operation_suffix}::SharedStorage *shared_storage =
+       reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
+
+   ${operation_name}${operation_suffix} op;
+
+   op(params, *shared_storage);
+ }
+ '''
+
+     HostTemplate = r'''
+ extern "C" {
+   // Get the size of params in bytes
+   int ${operation_name}_get_param_size(){
+     return sizeof(${operation_name}${operation_suffix}::Params);
+   }
+
+   // Get the size of dynamic shared memory in bytes
+   int ${operation_name}_shared_memory_size() {
+     return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
+   }
+
+   // Get the params as byte array
+   char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Arguments* arguments, int *semaphore=nullptr){
+     typename ${operation_name}${operation_suffix}::Params* params;
+     params = new ${operation_name}${operation_suffix}::Params(*arguments, semaphore);
+
+     char *bytes = ((char*)(params));
+     char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
+     for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
+       output[i] = bytes[i];
+
+     return output;
+   }
+ }
+
+ '''
+
+     def __init__(self, operation: 'Conv2dOperation'):
+         super().__init__(operation)
+         self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor)
+         self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p]
+         self.conv_kind = operation.conv_kind
+
+         self.operation: Conv2dOperation = operation
+
+         self.emitter = EmitConv2dInstance('_type')
+
+         self.threads: int = operation.tile_description.num_threads
+
+         self.swizzle_functor = operation.swizzling_functor
+
+     def emit(self):
+         return self.emitter.emit(self.operation)
+
+     # @typechecked
+     def get_device_workspace_size(self, arguments: Conv2dArguments):
+         workspace_bytes = 0
+
+         launch_config = arguments.launch_config
+
+         self.conv_kind = self.operation.conv_kind
+
+         if arguments.split_k_mode == cutlass.conv.SplitKMode.Parallel:
+             problem_size = arguments.problem_size
+             workspace_bytes = DataTypeSize[self.operation.C.element] \
+                 * launch_config.grid[2] * cutlass.conv.implicit_gemm_tensor_c_size(
+                     self.conv_kind, problem_size
+                 ) // 8
+         elif arguments.split_k_mode == cutlass.conv.SplitKMode.Serial and \
+                 arguments.split_k_slices > 1:
+             workspace_bytes = launch_config.grid[0] * launch_config.grid[1] * 4
+
+         return workspace_bytes
+
+     # @typechecked
+     def plan(self, arguments: Conv2dArguments):
+         tile_size = cutlass.gemm.GemmCoord(
+             self.operation.tile_description.threadblock_shape[0],
+             self.operation.tile_description.threadblock_shape[1],
+             self.operation.tile_description.threadblock_shape[2]
+         )
+
+         grid = self.swizzle_functor.get_grid_shape(
+             self.swizzle_functor.get_tiled_shape(
+                 self.conv_kind, arguments.problem_size,
+                 tile_size, arguments.split_k_slices
+             )
+         )
+         return LaunchConfiguration(
+             [grid.x, grid.y, grid.z], [self.threads, 1, 1],
+             self.shared_memory_capacity)
+
+     def initialize(self):
+         err, = cuda.cuFuncSetAttribute(
+             self.kernel,
+             attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+             value=self.shared_memory_capacity)
+         if err != cuda.CUresult.CUDA_SUCCESS:
+             raise RuntimeError('CUDA Error: {}'.format(err))
+
+ #
+
+
+ class Conv2dOperation:
+     """
+     CUTLASS Conv2d operation description.
+
+     :param conv_kind: convolution operator
+     :type conv_kind: :class:`cutlass.conv.Operator`
+
+     :param iterator_algorithm: Selects among several implementation
+         variants trading off performance with simplicity
+     :type iterator_algorithm: :class:`cutlass.conv.IteratorAlgorithm`
+
+     :param arch: GPU compute capability (sm_xx)
+     :type arch: int
+
+     :param tile_description: tile description
+     :type tile_description: :class:`pycutlass.TileDescription`
+
+     :param A: tensor A description
+     :type A: :class:`pycutlass.TensorDescription`
+
+     :param B: tensor B description
+     :type B: :class:`pycutlass.TensorDescription`
+
+     :param C: tensor C description
+     :type C: :class:`pycutlass.TensorDescription`
+
+     :param stride_support: distinguish among partial specializations that
+         accelerate certain problems where convolution stride is unit
+     :type stride_support: :class:`cutlass.conv.StrideSupport`
+
+     :param epilogue_functor: convolution epilogue functor
+     :type epilogue_functor: :class:`EpilogueFunctor`
+
+     :param swizzling_functor: threadblock swizzling functor
+     """
+     #
+
+     def __init__(self,
+                  conv_kind: cutlass.conv.Operator,
+                  iterator_algorithm: cutlass.conv.IteratorAlgorithm,
+                  arch: int, tile_description: TileDescription,
+                  A: TensorDescription, B: TensorDescription, C: TensorDescription,
+                  stride_support, epilogue_functor,
+                  swizzling_functor=cutlass.IdentitySwizzle1):
+
+         self.operation_kind: OperationKind = OperationKind.Conv2d
+         self.arch: int = arch
+         self.tile_description: TileDescription = tile_description
+         self.conv_kind = conv_kind
+         self.A: TensorDescription = A
+         self.B: TensorDescription = B
+         self.C: TensorDescription = C
+         self.epilogue_functor = epilogue_functor
+         self.iterator_algorithm = iterator_algorithm
+         self.stride_support = stride_support
+         self.swizzling_functor = swizzling_functor()
+
+         self.rt_module: Conv2dRT = Conv2dRT(self)
+         self.argument_type = self.rt_module.argument_type
+         self.epilogue_type = self.rt_module.epilogue_type
+
+     def run(self, arguments: Conv2dArguments) -> cuda.CUresult:
+         """
+         Launch the CUDA kernel with input arguments
+
+         :param arguments: conv2d arguments
+         :type arguments: :class:`pycutlass.Conv2dArguments`
+         """
+
+         # launch the kernel
+         err = self.rt_module.run(
+             arguments.host_workspace,
+             arguments.device_workspace,
+             arguments.launch_config)
+
+         if err != cuda.CUresult.CUDA_SUCCESS:
+             raise RuntimeError('CUDA Error %s' % str(err))
+
+         return err
+
+     #
+     # Get function name
+     #
+
+     def procedural_name(self):
+         ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
+         return self.configuration_name()
+
+     #
+     def configuration_name(self):
+         ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
+
+         opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
+
+         threadblock = "%dx%d_%dx%d" % (
+             self.tile_description.threadblock_shape[0],
+             self.tile_description.threadblock_shape[1],
+             self.tile_description.threadblock_shape[2],
+             self.tile_description.stages
+         )
+
+         if self.stride_support == StrideSupport.Unity:
+             configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}"
+         else:
+             configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"
+
+         return SubstituteTemplate(
+             configuration_name,
+             {
+                 'opcode_class': opcode_class_name,
+                 'extended_name': self.extended_name(),
+                 'threadblock': threadblock,
+                 'layout': self.layout_name(),
+                 'alignment': "%d" % self.A.alignment,
+             }
+         )
+
+     #
+     def extended_name(self):
+         ''' Append data types if they differ from compute type. '''
+         if self.C.element != self.tile_description.math_instruction.element_accumulator and \
+                 self.A.element != self.tile_description.math_instruction.element_accumulator:
+             extended_name = "${element_c}_${core_name}_${element_a}"
+         elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
+                 self.A.element != self.tile_description.math_instruction.element_accumulator:
+             extended_name = "${core_name}_${element_a}"
+         else:
+             extended_name = "${core_name}"
+
+         extended_name = SubstituteTemplate(extended_name, {
+             'element_a': DataTypeNames[self.A.element],
+             'element_c': DataTypeNames[self.C.element],
+             'core_name': self.core_name()
+         })
+
+         return extended_name
+
+     #
+     def layout_name(self):
+         return "%s" % (ShortLayoutTypeNames[self.A.layout])
+
+     #
+     def core_name(self):
+         ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
+
+         intermediate_type = ''
+
+         if self.tile_description.math_instruction.opcode_class == cutlass.OpClass.TensorOp:
+             inst_shape = "%d%d%d" % tuple(
+                 self.tile_description.math_instruction.instruction_shape)
+             if self.tile_description.math_instruction.element_a != self.A.element and \
+                     self.tile_description.math_instruction.element_a != self.accumulator_type():
+                 intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
+         else:
+             inst_shape = ''
+
+         return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()],
+                                 inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm])
+
+     #
+     def is_complex(self):
+         complex_operators = [
+             MathOperation.multiply_add_complex,
+             MathOperation.multiply_add_complex_gaussian
+         ]
+         return self.tile_description.math_instruction.math_operation in complex_operators
+
+     #
+     def accumulator_type(self):
+         accum = self.tile_description.math_instruction.element_accumulator
+
+         if self.is_complex():
+             return get_complex_from_real(accum)
+
+         return accum
+
+
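Before any Conv2dArguments can be built, an operation like the one above has to be described and compiled. A sketch following the bundled test scripts: the tile sizes, alignments, and the pycutlass.get_memory_pool / pycutlass.compiler.add_module helpers are taken from those scripts and are illustrative, not prescriptive.

    import pycutlass
    from pycutlass import *
    import cutlass

    # device memory pool used by pycutlass for workspaces and tensors
    pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)

    math_inst = MathInstruction(
        [1, 1, 1],                          # instruction shape (SIMT)
        cutlass.float32, cutlass.float32, cutlass.float32,
        cutlass.OpClass.Simt, MathOperation.multiply_add,
    )
    tile = TileDescription([128, 128, 8], 4, [2, 4, 1], math_inst)

    A = TensorDescription(cutlass.float32, cutlass.TensorNHWC, 1)
    B = TensorDescription(cutlass.float32, cutlass.TensorNHWC, 1)
    C = TensorDescription(cutlass.float32, cutlass.TensorNHWC, 1)

    epilogue = LinearCombination(C.element, C.alignment,
                                 math_inst.element_accumulator, cutlass.float32)

    operation = Conv2dOperation(
        conv_kind=cutlass.conv.Operator.fprop,
        iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
        arch=80, tile_description=tile, A=A, B=B, C=C,
        stride_support=StrideSupport.Strided,
        epilogue_functor=epilogue,
        swizzling_functor=cutlass.IdentitySwizzle1,
    )
    pycutlass.compiler.add_module([operation])   # JIT-compiles the kernel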
+ ###################################################################################################
+ #
+ # Emits single instances of a CUTLASS device-wide operator
+ #
+ ###################################################################################################
+
+ class EmitConv2dInstance:
+     def __init__(self, operation_suffix=''):
+         self.operation_suffix = operation_suffix
+         self.includes = [
+             "cutlass/cutlass.h",
+             "cutlass/conv/kernel/default_conv2d_fprop.h",
+             "cutlass/conv/kernel/default_conv2d_dgrad.h",
+             "cutlass/conv/kernel/default_conv2d_wgrad.h"
+         ]
+         self.template = """
+ // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
+ using ${operation_name}_base =
+ typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
+   ${element_a},
+   ${layout_a},
+   ${element_b},
+   ${layout_b},
+   ${element_c},
+   ${layout_c},
+   ${element_accumulator},
+   ${opcode_class},
+   ${arch},
+   cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
+   cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
+   cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
+   ${epilogue_functor},
+   ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
+   ${stages},
+   ${math_operator},
+   ${iterator_algorithm},
+   ${stride_support},
+   ${align_a},
+   ${align_b}
+ >::Kernel;
+
+ struct ${operation_name}${operation_suffix}:
+   public ${operation_name}_base { };
+
+ """
+
+     def emit(self, operation):
+
+         warp_shape = [int(operation.tile_description.threadblock_shape[idx] /
+                           operation.tile_description.warp_count[idx]) for idx in range(3)]
+
+         epilogue_vector_length = int(min(
+             operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
+
+         values = {
+             'operation_name': operation.procedural_name(),
+             'operation_suffix': self.operation_suffix,
+             'conv_kind': ConvKindTag[operation.conv_kind],
+             'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
+             'element_a': DataTypeTag[operation.A.element],
+             'layout_a': LayoutTag[operation.A.layout],
+             'element_b': DataTypeTag[operation.B.element],
+             'layout_b': LayoutTag[operation.B.layout],
+             'element_c': DataTypeTag[operation.C.element],
+             'layout_c': LayoutTag[operation.C.layout],
+             'element_accumulator': DataTypeTag[operation.accumulator_type()],
+             'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
+             'arch': "cutlass::arch::Sm%d" % operation.arch,
+             'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
+             'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
+             'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
+             'warp_shape_m': str(warp_shape[0]),
+             'warp_shape_n': str(warp_shape[1]),
+             'warp_shape_k': str(warp_shape[2]),
+             'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
+             'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
+             'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
+             'epilogue_vector_length': str(epilogue_vector_length),
+             'epilogue_functor': operation.epilogue_functor.emit(),
+             'swizzling_functor': operation.swizzling_functor.tag(),
+             'stages': str(operation.tile_description.stages),
+             'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
+             'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
+             'stride_support': StrideSupportTag[operation.stride_support],
+             'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else
+                 MathOperationTag[operation.tile_description.math_instruction.math_operation],
+             'align_a': str(operation.A.alignment),
+             'align_b': str(operation.B.alignment),
+         }
+
+         return SubstituteTemplate(self.template, values)
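The emitter is what ties the Python description back to C++ source at JIT-compile time: Conv2dRT instantiates EmitConv2dInstance('_type'), and the emitted text is compiled alongside KernelTemplate and HostTemplate. To inspect the generated instantiation for a given `operation` (a sketch, assuming an operation built as in the example above):

    print(EmitConv2dInstance(operation_suffix='_type').emit(operation))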