warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,402 @@
1
+ #
2
+ # \file manifest.py
3
+ #
4
+ # \brief Generates the CUTLASS Library's instances
5
+ #
6
+
7
+ import enum
8
+ import os.path
9
+ import shutil
10
+
11
+ from library import *
12
+ from gemm_operation import *
13
+ from rank_k_operation import *
14
+ from rank_2k_operation import *
15
+ from trmm_operation import *
16
+ from symm_operation import *
17
+ from conv2d_operation import *
18
+ from conv3d_operation import *
19
+
20
+ ###################################################################################################
21
+
22
class EmitOperationKindLibrary:
    """Context manager that emits the top-level source file for one operation kind.

    On entry it creates a sub-directory of ``generated_path`` named after the
    operation kind and opens ``all_<kind>_operations.cu`` inside it.  Each call
    to :meth:`emit` delegates one configuration to the per-kind configuration
    emitter and writes a prototype for its initializer.  On exit the entry
    point that calls every recorded initializer is written and the file is
    closed.

    NOTE(review): this listing was recovered from a rendering that dropped
    leading whitespace, so whitespace inside the template strings below is
    reproduced as displayed -- confirm against the packaged file.
    """

    def __init__(self, generated_path, kind, args):
        """Store the output location, the operation kind and the raw arguments.

        Args:
            generated_path: directory under which the per-kind directory is created.
            kind: key into ``OperationKindNames`` and into ``self.emitters``.
            args: stored as-is; not read by this class.
        """
        self.generated_path = generated_path
        self.kind = kind
        self.args = args

        # Emitter class used for the configurations of each operation kind.
        self.emitters = {
            OperationKind.Gemm: EmitGemmConfigurationLibrary,
            OperationKind.Conv2d: EmitConv2dConfigurationLibrary,
            OperationKind.Conv3d: EmitConv3dConfigurationLibrary,
            OperationKind.RankK: EmitRankKConfigurationLibrary,
            OperationKind.Rank2K: EmitRank2KConfigurationLibrary,
            OperationKind.Trmm: EmitTrmmConfigurationLibrary,
            OperationKind.Symm: EmitSymmConfigurationLibrary,
        }

        # Configuration names in the order they were emitted.
        self.configurations = []

        self.header_template = """
/*
Generated by manifest.py - Do not edit.
*/

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

"""
        self.entry_template = """

//
// Entry point to construct operations
//
void initialize_all_${operation_name}_operations(Manifest &manifest) {
"""
        self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n"
        self.configuration_template = " initialize_${configuration_name}(manifest);\n"

        self.epilogue_template = """

}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

"""

    def __enter__(self):
        """Create the per-kind directory, open the top-level file and write its header."""
        kind_name = OperationKindNames[self.kind]

        self.operation_path = os.path.join(self.generated_path, kind_name)
        os.mkdir(self.operation_path)

        self.top_level_path = os.path.join(self.operation_path, "all_%s_operations.cu" % kind_name)

        self.top_level_file = open(self.top_level_path, "w")
        self.top_level_file.write(self.header_template)

        # The top-level file is always the first generated source.
        self.source_files = [self.top_level_path]

        return self

    def emit(self, configuration_name, operations):
        """Emit one configuration and record a prototype for its initializer.

        Args:
            configuration_name: name substituted into the generated initializer.
            operations: iterable of operations handed to the configuration emitter.
        """
        emitter_class = self.emitters[self.kind]

        with emitter_class(self.operation_path, configuration_name) as configuration_emitter:
            for operation in operations:
                configuration_emitter.emit(operation)

            self.source_files.append(configuration_emitter.configuration_path)

        self.configurations.append(configuration_name)
        self.top_level_file.write(SubstituteTemplate(
            self.configuration_prototype_template,
            {'configuration_name': configuration_name}))

    def __exit__(self, exception_type, exception_value, traceback):
        """Write the entry point and epilogue, then close the top-level file.

        Returns None, so an exception raised inside the ``with`` body propagates.
        """
        try:
            self.top_level_file.write(SubstituteTemplate(
                self.entry_template,
                {'operation_name': OperationKindNames[self.kind]}))

            for configuration_name in self.configurations:
                self.top_level_file.write(SubstituteTemplate(
                    self.configuration_template,
                    {'configuration_name': configuration_name}))

            self.top_level_file.write(self.epilogue_template)
        finally:
            # Close the handle even when one of the writes above fails.
            self.top_level_file.close()
110
+
111
class EmitInterfaceLibrary:
    """Context manager that emits ``initialize_all.cpp`` for the whole library.

    On entry the file is opened under ``generated_path`` and a header comment
    is written.  Each :meth:`emit` call records a prototype and a call for one
    operation kind.  On exit the collected prototypes and calls are substituted
    into the prologue template, written out, and the file is closed.

    NOTE(review): this listing was recovered from a rendering that dropped
    leading whitespace, so whitespace inside the template strings below is
    reproduced as displayed -- confirm against the packaged file.
    """

    def __init__(self, generated_path, operation_count, args):
        """Store the output location and the number of operations to reserve.

        Args:
            generated_path: directory in which ``initialize_all.cpp`` is written.
            operation_count: converted with ``str`` and substituted into the
                generated ``manifest.reserve(...)`` call.
            args: stored as-is; not read by this class.
        """
        self.generated_path = generated_path
        self.args = args

        self.prototypes = []
        self.fn_calls = []
        self.operation_count = str(operation_count)

        self.top_level_hdr_template = '''
/*
Generated by manifest.py - Do not edit.
*/
'''
        self.top_level_prologue = '''

#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

namespace cutlass {
\tnamespace library {

${prototypes}

\t\tvoid initialize_all(Manifest &manifest) {
\t\t\tmanifest.reserve(${operation_count});\n\n
${fn_calls}
\t\t\t}

\t} // namespace library
} // namespace cutlass

'''

    def __enter__(self):
        """Open ``initialize_all.cpp`` and write the header comment."""
        self.top_level_path = os.path.join(self.generated_path, 'initialize_all.cpp')

        self.top_level_file = open(self.top_level_path, "w")
        self.top_level_file.write(self.top_level_hdr_template)

        self.source_files = [self.top_level_path]

        return self

    def emit(self, operation_name):
        """Record the prototype and the call for one operation kind's initializer.

        Args:
            operation_name: substituted for ``${operation_kind}`` in both lines.
        """
        self.prototypes.append(SubstituteTemplate(
            "\t\tvoid initialize_all_${operation_kind}_operations(Manifest &manifest);",
            {'operation_kind': operation_name}))
        self.fn_calls.append(SubstituteTemplate(
            "\t\t\tinitialize_all_${operation_kind}_operations(manifest);",
            {'operation_kind': operation_name}))

    def __exit__(self, exception_type, exception_value, traceback):
        """Write the prologue with all recorded prototypes and calls, then close.

        Returns None, so an exception raised inside the ``with`` body propagates.
        """
        try:
            self.top_level_file.write(SubstituteTemplate(
                self.top_level_prologue,
                {
                    'prototypes': "\n".join(self.prototypes),
                    'fn_calls': "\n".join(self.fn_calls),
                    'operation_count': self.operation_count,
                }))
        finally:
            # Close the handle even when the substitution or the write fails.
            self.top_level_file.close()
174
+
175
+ ###################################################################################################
176
+ ###################################################################################################
177
+
178
class Options:
    """Empty attribute container; the constructor sets nothing."""

    def __init__(self):
        """Create an instance with no attributes."""
        return None
181
+
182
+ ###################################################################################################
183
+
184
+ #
185
class Manifest:
    """Collects the operations that pass the configured filters and emits them.

    Operations are added with :meth:`append`, which keeps only those accepted
    by :meth:`filter`, and are grouped as
    ``operation_kind -> configuration_name -> [operations]``.  :meth:`emit`
    then writes the generated sources and a ``manifest.cmake`` listing them.

    NOTE(review): this listing was recovered from a rendering that dropped
    leading whitespace, so nesting was reconstructed from the code's logic and
    whitespace inside string literals is reproduced as displayed -- confirm
    against the packaged file.
    """

    def __init__(self, args = None):
        """Initialise empty state and, when ``args`` is given, the filters.

        Args:
            args: optional object providing ``kernels``, ``curr_build_dir``,
                ``architectures``, ``filter_by_cc``, ``operations``,
                ``ignore_kernels`` and ``kernel_filter_file``.  When it is
                falsy every setting keeps the default assigned below.
        """
        self.operations = {}
        self.args = args
        self.operation_count = 0
        self.operations_by_name = {}

        self.kernel_filter = ''
        self.kernel_filter_list = []
        self.kernel_names = []
        self.operations_enabled = []
        self.selected_kernels = []
        self.ignore_kernel_names = []
        self.compute_capabilities = [50]
        self.curr_build_dir = '.'
        self.filter_by_cc = True

        # Everything below reads attributes of ``args``; keeping it under this
        # guard lets the documented default ``args=None`` construct cleanly.
        if self.args:
            self.kernel_filter = args.kernels
            self.curr_build_dir = args.curr_build_dir

            # ';'-separated list; an empty string falls back to '50'.
            architectures = args.architectures.split(';') if len(args.architectures) else ['50']
            self.compute_capabilities = [int(x) for x in architectures]

            if args.filter_by_cc in ['false', 'False', '0']:
                self.filter_by_cc = False

            if args.operations == 'all':
                # An empty list means "no restriction" in filter().
                self.operations_enabled = []
            else:
                # Every kind that has an emitter can be selected by name;
                # Rank2K was previously missing from this list.
                operations_list = [
                    OperationKind.Gemm,
                    OperationKind.Conv2d,
                    OperationKind.Conv3d,
                    OperationKind.RankK,
                    OperationKind.Rank2K,
                    OperationKind.Trmm,
                    OperationKind.Symm,
                ]
                requested = args.operations.split(',')
                self.operations_enabled = [
                    x for x in operations_list if OperationKindNames[x] in requested]

            if args.kernels == 'all':
                self.kernel_names = []
            else:
                self.kernel_names = [x for x in args.kernels.split(',') if x != '']

            self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']

            if args.kernel_filter_file is None:
                self.kernel_filter_list = []
            else:
                self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)

    def get_kernel_filters(self, kernelListFile):
        """Read a file of regular expressions, one per line.

        Lines starting with ``#`` and lines that are empty after stripping
        trailing whitespace are skipped.

        Args:
            kernelListFile: path of the filter file.

        Returns:
            A list of compiled patterns, or ``[]`` when the path is not a file.
        """
        # Imported here because this module has no file-level ``import re``.
        import re

        if not os.path.isfile(kernelListFile):
            return []

        with open(kernelListFile, 'r') as fileReader:
            lines = [line.rstrip() for line in fileReader if not line.startswith("#")]

        return [re.compile(line) for line in lines if line]

    def filter_out_kernels(self, kernel_name, kernel_filter_list):
        """Return True when any pattern in ``kernel_filter_list`` matches ``kernel_name``.

        Matching uses ``search``, so a pattern may match anywhere in the name.
        """
        return any(
            kernel_filter_re.search(kernel_name) is not None
            for kernel_filter_re in kernel_filter_list)

    def _filter_string_matches(self, filter_string, haystack):
        """Return True if all ``*``-separated substrings appear in ``haystack`` in order."""
        for sub in filter_string.split('*'):
            idx = haystack.find(sub)
            if idx < 0:
                return False
            # Continue searching after the end of this match.
            haystack = haystack[idx + len(sub):]
        return True

    def filter(self, operation):
        """Decide whether ``operation`` should be included in the manifest.

        Checks, in order: compute capability (and shared-memory usage for
        capabilities listed in ``SharedMemPerCC``), enabled operation kinds,
        duplicate procedural names, the include/exclude name lists, and the
        regular-expression filter list.

        Returns:
            True when the operation is kept, False otherwise.
        """
        # With compute-capability filtering disabled every operation starts enabled.
        enabled = not self.filter_by_cc

        tile = operation.tile_description
        for cc in self.compute_capabilities:
            in_range = tile.minimum_compute_capability <= cc <= tile.maximum_compute_capability
            if in_range and (cc not in SharedMemPerCC
                             or SharedMemPerCC[cc] >= CalculateSmemUsage(operation)):
                enabled = True
                break

        if not enabled:
            return False

        if self.operations_enabled and operation.operation_kind not in self.operations_enabled:
            return False

        # Eliminate duplicates.
        if operation.procedural_name() in self.operations_by_name:
            return False

        # Filter based on the list of valid substrings.
        if self.kernel_names:
            name = operation.procedural_name()

            # Compare against the include list.
            enabled = any(
                self._filter_string_matches(name_substr, name)
                for name_substr in self.kernel_names)

            # Compare against the exclude list; a match overrides the include list.
            if any(self._filter_string_matches(name_substr, name)
                   for name_substr in self.ignore_kernel_names):
                enabled = False

        # When a regex filter list is present it alone decides the result.
        if len(self.kernel_filter_list) > 0:
            enabled = self.filter_out_kernels(
                operation.procedural_name(), self.kernel_filter_list)

        # todo: filter based on compute data type
        return enabled

    def append(self, operation):
        """Insert ``operation`` if it passes :meth:`filter`.

        Accepted operations are stored as
        ``operation_kind -> configuration_name -> []`` and counted in
        ``operation_count``.
        """
        if not self.filter(operation):
            return

        self.selected_kernels.append(operation.procedural_name())
        self.operations_by_name[operation.procedural_name()] = operation

        # Add the configuration, creating the nested containers on first use.
        configuration_name = operation.configuration_name()
        by_configuration = self.operations.setdefault(operation.operation_kind, {})
        by_configuration.setdefault(configuration_name, []).append(operation)

        self.operation_count += 1

    def emit(self, target = GeneratorTarget.Library):
        """Write all generated sources and ``manifest.cmake``.

        The ``generated`` directory under ``curr_build_dir`` is deleted if it
        exists and then recreated, so previous output is discarded.

        Args:
            target: key selecting the emitter classes; only
                ``GeneratorTarget.Library`` is registered here.
        """
        operation_emitters = {
            GeneratorTarget.Library: EmitOperationKindLibrary
        }
        interface_emitters = {
            GeneratorTarget.Library: EmitInterfaceLibrary
        }

        generated_path = os.path.join(self.curr_build_dir, 'generated')

        # Create generated/ from scratch.
        if os.path.exists(generated_path):
            shutil.rmtree(generated_path)

        os.mkdir(generated_path)

        source_files = []

        with interface_emitters[target](generated_path, self.operation_count, self.args) as iface_emitter:
            for operation_kind in self.operations:
                iface_emitter.emit(OperationKindNames[operation_kind])

            source_files += iface_emitter.source_files

        # For each operation kind, emit the initializer for all configurations.
        for operation_kind, configurations in self.operations.items():
            with operation_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter:
                for configuration_name, operations in configurations.items():
                    operation_kind_emitter.emit(configuration_name, operations)

                source_files += operation_kind_emitter.source_files

        # Write the manifest.cmake file containing paths from all targets.
        manifest_path = os.path.join(generated_path, "manifest.cmake")
        with open(manifest_path, "w") as manifest_file:

            target_name = 'cutlass_library_objs'

            target_text = SubstituteTemplate("""cutlass_target_sources(
${target_name}
BATCH_SOURCES ON
PRIVATE
""", { 'target_name': target_name})

            manifest_file.write(target_text)

            for source_file in source_files:
                # Paths are written with forward slashes regardless of platform.
                manifest_file.write(" %s\n" % str(source_file.replace('\\', '/')))
            manifest_file.write(")")
400
+ #
401
+
402
+ ###################################################################################################
@@ -0,0 +1,96 @@
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ # Configuration file for the Sphinx documentation builder.
34
+ #
35
+ # This file only contains a selection of the most common options. For a full
36
+ # list see the documentation:
37
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
38
+
39
+ # -- Path setup --------------------------------------------------------------
40
+
41
+ # If extensions (or modules to document with autodoc) are in another directory,
42
+ # add these directories to sys.path here. If the directory is relative to the
43
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
44
+ #
45
+ # import os
46
+ # import sys
47
+ # sys.path.insert(0, os.path.abspath('.'))
48
+
49
+
50
+ # -- Project information -----------------------------------------------------
51
+
52
# Project identity. ``copyright`` is the year followed by the author list, so
# it is derived from ``author`` to keep the two in sync.
project = "PyCutlass"
author = "Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall"
copyright = "2022, " + author


# -- General configuration ---------------------------------------------------

# Sphinx extension modules to load, given as module-name strings: the
# 'sphinx.ext.*' entries ship with Sphinx, the others are separate packages.
extensions = [
    "sphinx.ext.duration",
    "sphinx.ext.doctest",
    "sphinx.ext.autodoc",
    "sphinx.ext.intersphinx",
    "enum_tools.autoenum",
    "sphinx.ext.autosummary",
    "m2r2",
]

# File suffixes treated as documentation sources.
source_suffix = [".rst", ".md"]

# Both autosummary switches are enabled.
autosummary_generate = autosummary_imported_members = True

# Template directories, relative to this directory.
templates_path = ["_templates"]

# Patterns (relative to the source directory) of files and directories to
# skip when looking for sources; this also affects html_static_path and
# html_extra_path. Nothing is excluded.
exclude_patterns = []


# -- Options for HTML output -------------------------------------------------

# Theme used for HTML and HTML Help pages.
html_theme = "bizstyle"
92
+
93
+ # Add any paths that contain custom static files (such as style sheets) here,
94
+ # relative to this directory. They are copied after the builtin static files,
95
+ # so a file named "default.css" will overwrite the builtin "default.css".
96
+ # html_static_path = ['_static']
@@ -0,0 +1,106 @@
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from pycutlass import *
34
+ import pycutlass
35
+ from pycutlass.epilogue import LinearCombination
36
+ from pycutlass.test.conv2d_testbed import Conv2dLauncher
37
+
38
+
39
if __name__ == "__main__":
    # Script entry point: build one Conv2d fprop operation, time it through
    # the Python launcher and through the CUTLASS profiler, and print the
    # ratio of the two results.

    # Both get_memory_pool arguments are 2**33.
    pool_size = 2**33
    pycutlass.get_memory_pool(pool_size, pool_size)
    pycutlass.compiler.nvcc()

    # TensorOp multiply-add, 16x8x16 instruction shape, float16 operands with
    # a float32 accumulator.
    instruction = MathInstruction(
        instruction_shape=[16, 8, 16],
        element_a=cutlass.float16,
        element_b=cutlass.float16,
        element_accumulator=cutlass.float32,
        opcode_class=cutlass.OpClass.TensorOp,
        math_operation=MathOperation.multiply_add,
    )

    # All three tensors are NHWC with alignment 8; A and B take their element
    # types from the math instruction, C is float32.
    operand_a = TensorDescription(
        element=instruction.element_a, layout=cutlass.TensorNHWC, alignment=8
    )
    operand_b = TensorDescription(
        element=instruction.element_b, layout=cutlass.TensorNHWC, alignment=8
    )
    operand_c = TensorDescription(
        element=cutlass.float32, layout=cutlass.TensorNHWC, alignment=8
    )

    tiling = TileDescription(
        threadblock_shape=[128, 128, 64],
        stages=4,
        warp_count=[2, 2, 1],
        math_instruction=instruction,
    )

    linear_combination = LinearCombination(
        cutlass.float32, 4, cutlass.float32, cutlass.float32
    )

    conv_operation = Conv2dOperation(
        conv_kind=cutlass.conv.Operator.fprop,
        iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
        arch=80,
        tile_description=tiling,
        A=operand_a,
        B=operand_b,
        C=operand_c,
        element_epilogue=cutlass.float32,
        stride_support=StrideSupport.Strided,
        epilogue_functor=linear_combination,
        swizzling_functor=cutlass.IdentitySwizzle1,
    )

    # Verification is switched off; only profiling is requested.
    launcher = Conv2dLauncher(conv_operation, verification=False, profiling=True)

    def make_problem_size():
        # Each call builds a fresh problem-size object, so the two runs below
        # never share one instance.
        return cutlass.conv.Conv2dProblemSize(
            cutlass.Tensor4DCoord(32, 224, 224, 128),
            cutlass.Tensor4DCoord(128, 3, 3, 128),
            cutlass.Tensor4DCoord(1, 1, 1, 1),
            cutlass.MatrixCoord(1, 1),
            cutlass.MatrixCoord(1, 1),
            cutlass.conv.Mode.cross_correlation,
            1, 1
        )

    python_runtime = launcher.run(
        problem_size=make_problem_size(),
        split_k_mode=cutlass.conv.SplitKMode.Serial,
    )

    cpp_runtime = launcher.run_cutlass_profiler(
        problem_size=make_problem_size(),
        split_k_mode=cutlass.conv.SplitKMode.Serial,
    )

    # Ratio of the profiler result to the Python launcher result.
    print(cpp_runtime / python_runtime)
@@ -0,0 +1,91 @@
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ import pycutlass
34
+ from pycutlass import *
35
+ from pycutlass.test import *
36
+ from pycutlass.test.gemm_testbed import GemmUniversalLauncher
37
+
38
if __name__ == '__main__':
    # Script entry point: build one universal GEMM operation, time it through
    # the Python launcher and through the CUTLASS profiler, and print the
    # ratio of the two results.

    # Both get_memory_pool arguments are 2**32.
    pool_size = 2**32
    pycutlass.get_memory_pool(pool_size, pool_size)
    pycutlass.compiler.nvcc()

    # TensorOp multiply-add, 16x8x16 instruction shape, float16 operands with
    # a float32 accumulator.
    instruction = MathInstruction(
        instruction_shape=[16, 8, 16],
        element_a=cutlass.float16,
        element_b=cutlass.float16,
        element_accumulator=cutlass.float32,
        opcode_class=cutlass.OpClass.TensorOp,
        math_operation=MathOperation.multiply_add,
    )

    tiling = TileDescription(
        threadblock_shape=[256, 128, 32],
        stages=3,
        warp_count=[4, 2, 1],
        math_instruction=instruction,
    )

    def half_row_major():
        # A and B share the same description; a fresh object is built per call.
        return TensorDescription(
            element=cutlass.float16, layout=cutlass.RowMajor,
            alignment=4
        )

    operand_a = half_row_major()
    operand_b = half_row_major()
    operand_c = TensorDescription(
        element=cutlass.float32, layout=cutlass.ColumnMajor,
        alignment=4
    )

    linear_combination = LinearCombination(
        cutlass.float32, 4, cutlass.float32, cutlass.float32
    )

    gemm_operation = GemmOperationUniversal(
        arch=80,
        tile_description=tiling,
        A=operand_a,
        B=operand_b,
        C=operand_c,
        element_epilogue=cutlass.float32,
        epilogue_functor=linear_combination,
        swizzling_functor=cutlass.IdentitySwizzle1,
    )

    # Verification is switched off; only profiling is requested.
    launcher = GemmUniversalLauncher(gemm_operation, verification=False, profiling=True)

    def launch_arguments():
        # Keyword arguments shared by both runs; the GemmCoord is rebuilt on
        # every call so the runs never share one instance.
        return {
            "mode": cutlass.gemm.Mode.Gemm,
            "problem_size": cutlass.gemm.GemmCoord(4096, 4096, 4096),
        }

    python_runtime = launcher.run(**launch_arguments())

    cpp_runtime = launcher.run_cutlass_profiler(**launch_arguments())

    # Ratio of the profiler result to the Python launcher result.
    print(cpp_runtime / python_runtime)