warp-lang 0.9.0-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of warp-lang has been flagged as potentially problematic.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py (new file)
@@ -0,0 +1,646 @@
+ #################################################################################################
+ #
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+ #
+ # 1. Redistributions of source code must retain the above copyright notice, this
+ # list of conditions and the following disclaimer.
+ #
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
+ # this list of conditions and the following disclaimer in the documentation
+ # and/or other materials provided with the distribution.
+ #
+ # 3. Neither the name of the copyright holder nor the names of its
+ # contributors may be used to endorse or promote products derived from
+ # this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ #
+ #################################################################################################
+
+ import pycutlass
+ from pycutlass import *
+ from pycutlass.test import *
+ from time import sleep
+ from bfloat16 import bfloat16
+ import subprocess
+ from typeguard import typechecked
+ import re
+
+
+
+ def getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand):
+     ptr = tensor.__array_interface__['data'][0]
+     if operand == "a":
+         tensor_coord = cutlass.conv.implicit_gemm_tensor_a_extent(conv_kind, problem_size)
+     elif operand == "b":
+         tensor_coord = cutlass.conv.implicit_gemm_tensor_b_extent(conv_kind, problem_size)
+     elif operand in ["c", "d"]:
+         tensor_coord = cutlass.conv.implicit_gemm_tensor_c_extent(conv_kind, problem_size)
+     else:
+         raise ValueError("unknown operand: " + operand)
+
+     layout = tensor_layout.packed(tensor_coord)
+
+     if tensor.dtype == np.float64:
+         return cutlass.TensorRefF64NHWC(ptr, layout)
+     elif tensor.dtype == np.float32:
+         return cutlass.TensorRefF32NHWC(ptr, layout)
+     elif tensor.dtype == np.float16:
+         return cutlass.TensorRefF16NHWC(ptr, layout)
+     if tensor.dtype == bfloat16:
+         return cutlass.TensorRefBF16NHWC(ptr, layout)
+     elif tensor.dtype == np.int32:
+         return cutlass.TensorRefS32NHWC(ptr, layout)
+     elif tensor.dtype == np.int8:
+         if tensor_layout == cutlass.TensorNC32HW32:
+             return cutlass.TensorRefS8NC32HW32(ptr, layout)
+         elif tensor_layout == cutlass.TensorC32RSK32:
+             return cutlass.TensorRefS8C32RSK32(ptr, layout)
+         else:
+             return cutlass.TensorRefS8NHWC(ptr, layout)
+     else:
+         raise ValueError("unsupported data type")
+
+ def getTensorView(tensor, tensor_layout, conv_kind, problem_size, operand):
+     tensor_ref = getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand)
+
+     if operand == "a":
+         tensor_coord = cutlass.conv.implicit_gemm_tensor_a_extent(conv_kind, problem_size)
+     elif operand == "b":
+         tensor_coord = cutlass.conv.implicit_gemm_tensor_b_extent(conv_kind, problem_size)
+     elif operand in ["c", "d"]:
+         tensor_coord = cutlass.conv.implicit_gemm_tensor_c_extent(conv_kind, problem_size)
+     else:
+         raise ValueError("unknown operand: " + operand)
+
+     if tensor.dtype == np.float64:
+         return cutlass.TensorViewF64NHWC(tensor_ref, tensor_coord)
+     elif tensor.dtype == np.float32:
+         return cutlass.TensorViewF32NHWC(tensor_ref, tensor_coord)
+     elif tensor.dtype == np.float16:
+         return cutlass.TensorViewF16NHWC(tensor_ref, tensor_coord)
+     elif tensor.dtype == bfloat16:
+         return cutlass.TensorViewBF16NHWC(tensor_ref, tensor_coord)
+     elif tensor.dtype == np.int32:
+         return cutlass.TensorViewS32NHWC(tensor_ref, tensor_coord)
+     elif tensor.dtype == np.int8:
+         if tensor_layout == cutlass.TensorNC32HW32:
+             return cutlass.TensorViewS8NC32HW32(tensor_ref, tensor_coord)
+         elif tensor_layout == cutlass.TensorC32RSK32:
+             return cutlass.TensorViewS8C32RSK32(tensor_ref, tensor_coord)
+         else:
+             return cutlass.TensorViewS8NHWC(tensor_ref, tensor_coord)
+
+     else:
+         raise ValueError("unsupported data type")
+
+
+
+ # @typechecked
+ class Conv2dLauncher:
+     """
+     Launcher that runs the operation on given problem size
+     """
+     def __init__(self, operation: 'Conv2dOperation', seed: int=2080, interleaved=False,
+                  verification=True, profiling=False, warmup_iterations=500, iterations=500, **kwargs) -> None:
+
+         self.enable_cached_results = True
+         self.interleaved = interleaved
+
+         # create the reduction kernel
+         self.reduction_operation = ReductionOperation(
+             shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment),
+             C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator,
+             element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor,
+             count=operation.C.alignment
+         )
+
+         #: verify the output result
+         self.verification = verification
+         #: profile the kernel's runtime
+         self.profiling = profiling
+
+         self.timer = GpuTimer()
+
+         self.warmup_iterations = warmup_iterations
+         self.iterations = iterations
+
+         if "sleep" in kwargs.keys():
+             self.sleep_time = kwargs["sleep"]
+         else:
+             self.sleep_time = 0
+
+         #
+         # Compile the operator
+         #
+
+         pycutlass.compiler.add_module([operation, self.reduction_operation])
+
+         self.operation = operation
+
+         self.dtype_A = Conv2dLauncher.numpy_type(operation.A.element)
+         self.layout_A = operation.A.layout
+         self.dtype_B = Conv2dLauncher.numpy_type(operation.B.element)
+         self.layout_B = operation.B.layout
+         self.dtype_C = Conv2dLauncher.numpy_type(operation.C.element)
+         self.layout_C = operation.C.layout
+         self.dtype_D = Conv2dLauncher.numpy_type(operation.C.element)
+         self.layout_D = operation.C.layout
+
+         accumulator_size = DataTypeSize[operation.tile_description.math_instruction.element_accumulator]
+         element_size = DataTypeSize[operation.A.element]
+
+         if element_size <= 8:
+             self.scope = 1
+         elif element_size == 16:
+             if accumulator_size <= 16:
+                 self.scope = 2
+             else:
+                 self.scope = 4
+         else:
+             self.scope = 7
+
+         # Seed
+         self.seed = seed
+
+         self.conv_kind = operation.conv_kind
+
+
+         #
+         # Get the host reference function
+         #
+
+         self.element_compute = operation.epilogue_functor.element_epilogue
+
+         self.host_conv2d = cutlass.test.conv.host.conv2d
+
+         self.timer = GpuTimer()
+
+     @staticmethod
+     def numpy_type(type):
+         if type == cutlass.float64:
+             return np.float64
+         elif type == cutlass.float32:
+             return np.float32
+         elif type == cutlass.float16:
+             return np.float16
+         elif type == cutlass.bfloat16:
+             return bfloat16
+         elif type == cutlass.int32:
+             return np.int32
+         elif type == cutlass.int8:
+             return np.int8
+         else:
+             raise ValueError("unsupported type: %s" % ShortDataTypeNames[type])
+
+     def print_problem_size(self, p, split_k_mode=1):
+         print("nhwc_%dx%dx%dx%d_krsc_%dx%dx%dx%d_padding_%dx%d_stride_%dx%d_dilation_%dx%d_splitkslices_%d_splitkmode_%d"
+               % (p.N, p.H, p.W, p.C, p.K, p.R, p.S, p.C, p.pad_h,
+                  p.pad_w, p.stride_h, p.stride_w, p.dilation_h, p.dilation_w, p.split_k_slices, split_k_mode))
+
+     def uniform_init(self, size, dtype):
+         if dtype in [np.float32, np.float16, bfloat16, np.float64]:
+             return np.ceil(
+                 np.random.uniform(
+                     low=-self.scope - 0.5, high=self.scope - 0.5,
+                     size=size).astype(dtype)
+             )
+         else:
+             return np.random.uniform(
+                 low=-self.scope - 1, high=self.scope + 1,
+                 size=size).astype(dtype)
+
+     def eq_gemm_size(self, problem_size):
+         n = problem_size.N
+         p = problem_size.P
+         q = problem_size.Q
+         k = problem_size.K
+         r = problem_size.R
+         s = problem_size.S
+         c = problem_size.C
+         h = problem_size.H
+         w = problem_size.W
+         if self.conv_kind == cutlass.conv.Operator.fprop:
+             return cutlass.gemm.GemmCoord(n * p * q, k, r * s * c)
+         elif self.conv_kind == cutlass.conv.Operator.dgrad:
+             return cutlass.gemm.GemmCoord(n * h * w, c, k * r * s)
+         else:
+             return cutlass.gemm.GemmCoord(k, r * s * c, n * p * q)
+
+     def bytes(self, problem_size, alpha, beta):
+         mnk = self.eq_gemm_size(problem_size)
+
+         bytes_ = \
+             (DataTypeSize[self.operation.A.element] * mnk.m() // 8) * mnk.k() + \
+             (DataTypeSize[self.operation.B.element] * mnk.n() // 8) * mnk.k() + \
+             (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n()
+
+         if beta != 0:
+             bytes_ += (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n()
+
+         return bytes_
+
+     def flops(self, problem_size):
+         mnk = self.eq_gemm_size(problem_size)
+
+         flops_mainloop_ = mnk.m() * mnk.n() * mnk.k() * 2
+         flops_epilogue_ = mnk.m() * mnk.n() * 2
+
+         # Adjust mainloop flop for dgrad stride
+         if self.conv_kind == cutlass.conv.Operator.dgrad:
+             flops_mainloop_ = flops_mainloop_ // (problem_size.stride_h * problem_size.stride_w)
+
+         flops_total_ = flops_mainloop_ + flops_epilogue_
+
+         # TODO complex-value support
+         # switch (operation_desc.tile_description.math_instruction.math_operation) {
+         # case library::MathOperationID::kMultiplyAddComplex:
+         # flops_total_ *=4;
+         # break;
+
+         # default: break;
+         # }
+
+         return flops_total_
+
+
+
+     def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
+         if self.element_compute == cutlass.float16:
+             alpha = cutlass.float16(alpha)
+             beta = cutlass.float16(beta)
+         elif self.element_compute == cutlass.int32:
+             alpha = int(alpha)
+             beta = int(beta)
+         else:
+             alpha = alpha
+             beta = beta
+
+         # if cached result is loaded
+         cached_result_loaded = False
+
+         if self.enable_cached_results:
+             # get problem key
+             cached_test_key = cutlass.test.conv.host.CreateCachedConv2dTestKey(
+                 self.conv_kind, problem_size, alpha, beta,
+                 getTensorView(tensor_A, self.layout_A, self.conv_kind, problem_size, "a"),
+                 getTensorView(tensor_B, self.layout_B, self.conv_kind, problem_size, "b"),
+                 getTensorView(tensor_C, self.layout_C, self.conv_kind, problem_size, "c"),
+             )
+
+             cached_test_result = cutlass.test.conv.host.CachedTestResult()
+
+             conv2d_result_cache_name = "cached_results_SM%d_%d.txt" % (self.operation.arch, self.seed)
+
+             cached_results = cutlass.test.conv.host.CachedTestResultListing(conv2d_result_cache_name)
+             # CachedTestResultListing cached_results(conv2d_result_cache_name);
+             cached = cached_results.find(cached_test_key)
+             cached_result_loaded = cached[0]
+             if cached_result_loaded :
+                 cached_test_result = cached[1]
+
+         if not cached_result_loaded:
+             # compute the conv2d on host
+             tensor_D_ref = np.ones_like(tensor_C)
+             tensor_ref_A = getTensorRef(tensor_A, self.layout_A, self.conv_kind, problem_size, "a")
+             tensor_ref_B = getTensorRef(tensor_B, self.layout_B, self.conv_kind, problem_size, "b")
+             tensor_ref_C = getTensorRef(tensor_C, self.layout_C, self.conv_kind, problem_size, "c")
+             tensor_ref_D_ref = getTensorRef(tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d")
+
+             self.host_conv2d(
+                 self.conv_kind, problem_size,
+                 tensor_ref_A, tensor_ref_B, tensor_ref_C, tensor_ref_D_ref,
+                 alpha, beta
+             )
+
+             tensor_view_D_ref = getTensorView(tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d")
+
+             if self.enable_cached_results:
+                 cached_test_result.D = cutlass.test.conv.host.TensorHash(tensor_view_D_ref)
+                 cached_results = cutlass.test.conv.host.CachedTestResultListing(conv2d_result_cache_name)
+                 cached_results.append(cached_test_key, cached_test_result)
+                 cached_results.write(conv2d_result_cache_name)
+             else:
+                 return tensor_D_ref
+
+         return cached_test_result.D
+
+     def equal(self, tensor_D, tensor_D_ref, problem_size):
+         if self.enable_cached_results:
+             tensor_view_D = getTensorView(tensor_D, self.layout_D, self.conv_kind, problem_size, "d")
+             tensor_D_hash = cutlass.test.conv.host.TensorHash(tensor_view_D)
+
+             return tensor_D_hash == tensor_D_ref
+         else:
+             tensor_view_D = getTensorView(tensor_D, self.layout_D, self.conv_kind, problem_size, "d")
+             tensor_view_D_ref = getTensorView(tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d")
+             return cutlass.test.conv.host.equals(tensor_view_D, tensor_view_D_ref)
+
+     def run_cutlass_profiler(self, problem_size, split_k_mode=cutlass.conv.SplitKMode.Serial, alpha=1.0, beta=0.0):
+
+         if split_k_mode == cutlass.conv.SplitKMode.Serial:
+             split_k_mode_ = "serial"
+         else:
+             split_k_mode_ = "parallel"
+
+         cutlass_path = os.getenv('CUTLASS_PATH')
+         assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined."
+
+         values = {
+             "profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler",
+             "kernel_name": self.operation.procedural_name(),
+             "verification_providers": "device",
+             "provider": "cutlass",
+             'n': str(problem_size.N),
+             'h': str(problem_size.H),
+             'w': str(problem_size.W),
+             'c': str(problem_size.C),
+             'k': str(problem_size.K),
+             'r': str(problem_size.R),
+             's': str(problem_size.S),
+             'p': str(problem_size.P),
+             'q': str(problem_size.Q),
+             'pad_h': str(problem_size.pad_h),
+             'pad_w': str(problem_size.pad_w),
+             'stride_h': str(problem_size.stride_h),
+             'stride_w': str(problem_size.stride_w),
+             'dilation_h': str(problem_size.dilation_h),
+             'dilation_w': str(problem_size.dilation_w),
+             'split_k_slices': str(problem_size.split_k_slices),
+             'split_k_mode': split_k_mode_,
+             'alpha': str(alpha),
+             'beta': str(beta),
+             'warmup': str(self.warmup_iterations),
+             'profile': str(self.iterations)
+         }
+
+         cmd_template = \
+             "${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}" \
+             " --providers=${provider} --n=${n} --h=${h} --w=${w} --c=${c} --k=${k} --r=${r} --s=${s} --p=${p}" \
+             " --q=${q} --pad_h=${pad_h} --pad_w=${pad_w} --stride_h={stride_h} --stride_w=${stride_w}" \
+             " --dilation_h=${dilation_h} --dilation_w=${dilation_w} --warmup-iterations=${warmup} --profiling-iterations=${profile}" \
+             " --split_k_slices=${split_k_slices} --alpha=${alpha} --beta=${beta} --split_k_mode=${split_k_mode}"
+
+         cmd = SubstituteTemplate(cmd_template, values)
+         result = subprocess.getoutput(cmd)
+
+         m = re.search(r"Runtime:\s+(?P<runtime>\d+.\d+)", result)
+         runtime = float(m.group('runtime'))
+
+         m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
+         bytes = int(m.group('bytes'))
+
+         m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
+         flops = int(m.group('flops'))
+
+         # check if the problem size matches
+         assert bytes == self.bytes(problem_size, alpha, beta)
+         assert flops == self.flops(problem_size)
+
+         return runtime
+
+
+
+     def run(self, problem_size, split_k_mode=cutlass.conv.SplitKMode.Serial,
+             alpha=1.0, beta=0.0):
+
+         assert get_allocated_size() == 0, "%d byte of pool memory is not released in previous run" % get_allocated_size()
+
+         #
+         # Initialize input and output tensors
+         #
+         tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(self.conv_kind, problem_size)
+         tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(self.conv_kind, problem_size)
+         tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(self.conv_kind, problem_size)
+
+         np.random.seed(self.seed)
+
+         tensor_A = self.uniform_init(size=(tensor_A_size,), dtype=self.dtype_A)
+         tensor_B = self.uniform_init(size=(tensor_B_size,), dtype=self.dtype_B)
+         tensor_C = self.uniform_init(size=(tensor_C_size,), dtype=self.dtype_C)
+         tensor_D = np.zeros(shape=(tensor_C_size,), dtype=self.dtype_D)
+
+
+         #
+         # Launch kernel
+         #
+
+         arguments = Conv2dArguments(
+             operation=self.operation, problem_size=problem_size, A=tensor_A,
+             B=tensor_B, C=tensor_C, D=tensor_D,
+             output_op = self.operation.epilogue_type(alpha, beta),
+             split_k_slices=problem_size.split_k_slices,
+             split_k_mode=split_k_mode
+         )
+
+         if split_k_mode == cutlass.conv.SplitKMode.Parallel:
+             implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(self.operation.conv_kind, arguments.problem_size)
+             reduction_arguments = ReductionArguments(
+                 self.reduction_operation,
+                 problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()], partitions=problem_size.split_k_slices,
+                 workspace=arguments.ptr_D,
+                 destination=tensor_D,
+                 source=tensor_C,
+                 output_op = self.reduction_operation.epilogue_type(alpha, beta)
+             )
+
+         self.operation.run(arguments)
+         if split_k_mode == cutlass.conv.SplitKMode.Parallel:
+             self.reduction_operation.run(reduction_arguments)
+
+         passed = True
+         if self.verification:
+             if split_k_mode == cutlass.conv.SplitKMode.Parallel:
+                 reduction_arguments.sync()
+             else:
+                 arguments.sync()
+
+             tensor_D_ref = self.host_reference(problem_size, tensor_A, tensor_B, tensor_C, alpha, beta)
+
+             passed = self.equal(tensor_D, tensor_D_ref, problem_size)
+
+             try:
+                 assert passed
+             except AssertionError:
+                 self.print_problem_size(problem_size, split_k_mode)
+
+         if self.profiling:
+             sleep(self.sleep_time)
+             for _ in range(self.warmup_iterations):
+                 self.operation.run(arguments)
+                 if split_k_mode == cutlass.conv.SplitKMode.Parallel:
+                     self.reduction_operation.run(reduction_arguments)
+
+             self.timer.start()
+             for _ in range(self.warmup_iterations):
+                 self.operation.run(arguments)
+                 if split_k_mode == cutlass.conv.SplitKMode.Parallel:
+                     self.reduction_operation.run(reduction_arguments)
+             self.timer.stop_and_wait()
+             runtime = self.timer.duration(self.iterations)
+
+         # free memory
+         del arguments
+         if split_k_mode == cutlass.conv.SplitKMode.Parallel:
+             del reduction_arguments
+
+         assert get_allocated_size() == 0, "%d byte of pool memory is not released after current run" % get_allocated_size()
+         if self.profiling:
+             return runtime
+         return passed
+
+
+
+ ########################################################################################################
+ # TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
+ # TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
+ # Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
+ # (conv_blacklist_sizes)
+ ############################################################################################################
+
+ def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes = [], interleaved=False): # TODO: conv_test_sizes and conv_blacklist_sizes
+     passed = True
+
+     #
+     # Testbed object
+     #
+
+     testbed = Conv2dLauncher(operation, interleaved=interleaved)
+
+     #
+     # Get conv problem sizes to run conv operator
+     #
+
+     conv_problems = cutlass.test.conv.TestbedConv2dProblemSizes(64)
+
+     # Vector of conv2d problem sizes to avoid duplicate runs
+     conv_tested_sizes = []
+
+     # TODO: include resnet 50 sizes, user sepecified sizes, and rigorous sizes
+
+     # Flatten 2D problem_vectors into a 1D problem sizes
+     problem_sizes = conv_problems.conv2d_default_sizes
+
+     problem_sizes = [conv_problem for conv_problem in problem_sizes] + conv_test_sizes
+
+     # Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slices=1, alpha=1.0, beta=0.0)
+     for conv_problem in problem_sizes:
+
+         # TODO: skip blacklist problem sizes
+         if conv_problem in conv_tested_sizes:
+             continue
+
+         # skip channel dimension % 32 != 0 for interleaved case
+         if interleaved:
+             if conv_problem.K % 32 != 0 or conv_problem.C % 32 != 0:
+                 continue
+
+         #
+         # Procedurally disable certain cases
+         #
+
+         # CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
+         if operation.conv_kind == cutlass.conv.Operator.dgrad and operation.stride_support == StrideSupport.Unity:
+             if not ((conv_problem.stride_h == 1) and (conv_problem.stride_w == 1)):
+                 continue
+
+         if not interleaved:
+             # Fixed channels algorithm requires channel count to match access size
+             if operation.iterator_algorithm == cutlass.conv.IteratorAlgorithm.fixed_channels:
+                 if conv_problem.C != operation.A.alignment:
+                     continue
+
+             # Few channels algorithm requires channel count to match access size
+             if operation.iterator_algorithm == cutlass.conv.IteratorAlgorithm.few_channels:
+                 if conv_problem.C % operation.A.alignment:
+                     continue
+
+         # CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w}
+         # Although strided dgrad works for all stride combinations, we are only going
+         # to run strided dgrad for non-unity strides
+
+         if operation.conv_kind == cutlass.conv.Operator.dgrad and operation.stride_support == StrideSupport.Strided:
+             if (conv_problem.stride_h == 1) and (conv_problem.stride_w == 1):
+                 continue
+
+         #
+         # Test
+         #
+
+         # push back tested problem size to avoid re-running duplicates
+         conv_tested_sizes.append(conv_problem)
+
+         passed = testbed.run(conv_problem)
+
+         # if not passed: return False
+
+     # TODO: If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the the number of tested problem counts
+
+     if interleaved:
+         return True
+     #
+     # filter the cases for split K
+     #
+
+     # Small-channels convolution can't run here.
+     if operation.iterator_algorithm in [cutlass.conv.IteratorAlgorithm.fixed_channels, cutlass.conv.IteratorAlgorithm.few_channels]:
+         return True
+
+     # CUTLASS DGRAD's *stride* specialization does not support split-k mode
+     if operation.conv_kind == cutlass.conv.Operator.dgrad and operation.stride_support == StrideSupport.Strided:
+         conv_problem = cutlass.conv.Conv2dProblemSize(
+             cutlass.Tensor4DCoord(1, 56, 56, 8),
+             cutlass.Tensor4DCoord(8, 1, 1, 8),
+             cutlass.Tensor4DCoord(0, 0, 0, 0),
+             cutlass.MatrixCoord(2, 2),
+             cutlass.MatrixCoord(1, 1),
+             cutlass.conv.Mode.cross_correlation,
+             1, 1
+         )
+         passed = testbed.run(conv_problem)
+
+         return passed
+
+     # Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
+     # a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
+     # which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep
+     # alpha and beta for local testing, but only runs one value for alpha and beta.
+
+     conv2d_split_k_test_size = cutlass.conv.Conv2dProblemSize(
+         cutlass.Tensor4DCoord(1, 17, 11, 288),
+         cutlass.Tensor4DCoord(160, 3, 3, 288),
+         cutlass.Tensor4DCoord(1, 1, 1, 1),
+         cutlass.MatrixCoord(1, 1),
+         cutlass.MatrixCoord(1, 1),
+         cutlass.conv.Mode.cross_correlation,
+         1, 1
+     )
+
+     split_k_modes = [cutlass.conv.SplitKMode.Parallel, cutlass.conv.SplitKMode.Serial]
+
+     split_k_slices = [1, 2, 3, 4, 201]
+     problem_alpha = [2.0,]
+     problem_beta = [2.0,]
+
+     for split_k_mode in split_k_modes:
+         for split_k_slice in split_k_slices:
+             for alpha in problem_alpha:
+                 for beta in problem_beta:
+                     passed = testbed.run(conv2d_split_k_test_size.reset_split_k_slices(split_k_slice),
+                                          split_k_mode,
+                                          alpha, beta)
+
+     return passed
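
The TestAllConv banner in the hunk above summarizes what test_all_conv2d does: it builds a Conv2dLauncher for a given Conv2dOperation, sweeps the default problem sizes, and then runs a split-k sweep. For orientation, the sketch below shows roughly how the sibling conv2d unit tests added in this same directory (for example conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py) drive this testbed. It is an illustrative sketch only, not part of the diff; the constructor argument order, the memory-pool sizes, and the epilogue and swizzling choices are assumptions that can vary between pycutlass versions.

import unittest
import pycutlass
from pycutlass import *
from pycutlass.test import *  # provides test_all_conv2d from the testbed above


class Conv2dFpropSm80Sketch(unittest.TestCase):
    def test_fprop_f16_tensor_op(self):
        # Device memory pool used by pycutlass arguments (sizes are illustrative).
        pycutlass.get_memory_pool(2**26, 2**26)

        # Tensor-op math: 16x8x16 MMA, fp16 inputs, fp16 accumulation (assumed config).
        math_inst = MathInstruction(
            [16, 8, 16], cutlass.float16, cutlass.float16, cutlass.float16,
            cutlass.OpClass.TensorOp, MathOperation.multiply_add,
        )
        tile = TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst)

        # NHWC activations, filters, and output with 8-element aligned accesses.
        A = TensorDescription(cutlass.float16, cutlass.TensorNHWC, 8)
        B = TensorDescription(cutlass.float16, cutlass.TensorNHWC, 8)
        C = TensorDescription(cutlass.float16, cutlass.TensorNHWC, 8)

        epilogue = LinearCombination(
            C.element, C.alignment, math_inst.element_accumulator, cutlass.float16
        )

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.fprop,
            iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=80, tile_description=tile, A=A, B=B, C=C,
            stride_support=StrideSupport.Strided,
            epilogue_functor=epilogue,
            swizzling_functor=cutlass.IdentitySwizzle1,
        )

        # Compiles the kernel, sweeps the default problem sizes, then the split-k sweep.
        self.assertTrue(test_all_conv2d(operation))


if __name__ == "__main__":
    unittest.main()

A single problem size can also be exercised directly through Conv2dLauncher(operation).run(problem_size), which is what both the split-k sweep and run_cutlass_profiler in the testbed rely on.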