warp-lang 0.9.0-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0

warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py
@@ -0,0 +1,557 @@
+ #################################################################################################
+ #
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+ #
+ # 1. Redistributions of source code must retain the above copyright notice, this
+ # list of conditions and the following disclaimer.
+ #
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
+ # this list of conditions and the following disclaimer in the documentation
+ # and/or other materials provided with the distribution.
+ #
+ # 3. Neither the name of the copyright holder nor the names of its
+ # contributors may be used to endorse or promote products derived from
+ # this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ #
+ #################################################################################################
+
+ from time import sleep
+ import pycutlass
+ from pycutlass import *
+ import cutlass
+ from cuda import cudart
+ from cuda import cuda
+ from bfloat16 import bfloat16
+ from .profiler import GpuTimer
+ import subprocess
+
+
+ def transpose(layout):
+     if layout == cutlass.RowMajor:
+         return cutlass.ColumnMajor
+     elif layout == cutlass.ColumnMajor:
+         return cutlass.RowMajor
+     elif layout == cutlass.ColumnMajorInterleaved32:
+         return cutlass.RowMajorInterleaved32
+     elif layout == cutlass.RowMajorInterleaved32:
+         return cutlass.ColumnMajorInterleaved32
+
+
+ def getTensorRef(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, operand: str, layout: cutlass.layout):
+     ptr = tensor.__array_interface__['data'][0]
+     if operand == "a":
+         tensor_coord = problem_size.mk()
+     elif operand == "b":
+         tensor_coord = problem_size.kn()
+     elif operand in ["c", "d"]:
+         tensor_coord = problem_size.mn()
+     else:
+         raise ValueError("unknonw operand: " + operand)
+
+     if layout == cutlass.RowMajor:
+         layout = cutlass.RowMajor.packed(tensor_coord)
+         layout_tag = "RowMajor"
+     elif layout == cutlass.ColumnMajor:
+         layout = cutlass.ColumnMajor.packed(tensor_coord)
+         layout_tag = "ColumnMajor"
+     elif layout == cutlass.ColumnMajorInterleaved32:
+         layout = cutlass.ColumnMajorInterleaved32.packed(tensor_coord)
+         layout_tag = "ColumnMajorInterleaved32"
+     elif layout == cutlass.RowMajorInterleaved32:
+         layout = cutlass.RowMajorInterleaved32.packed(tensor_coord)
+         layout_tag = "RowMajorInterleaved32"
+     else:
+         raise ValueError("unsupported layout")
+     if tensor.dtype == np.float32:
+         ref_name = "TensorRefF32" + layout_tag
+     elif tensor.dtype == np.float64:
+         ref_name = "TensorRefF64" + layout_tag
+     elif tensor.dtype == np.float16:
+         ref_name = "TensorRefF16" + layout_tag
+     elif tensor.dtype == bfloat16:
+         ref_name = "TensorRefBF16" + layout_tag
+     elif tensor.dtype == np.int8:
+         ref_name = "TensorRefS8" + layout_tag
+     elif tensor.dtype == np.int32:
+         ref_name = "TensorRefS32" + layout_tag
+     else:
+         raise ValueError("unsupported datatype %s" %
+                          ShortDataTypeNames[tensor.dtype])
+
+     return getattr(cutlass, ref_name)(ptr, layout)
+
+
+ def getTensorView(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, operand: str, layout: str):
+     tensor_ref = getTensorRef(tensor, problem_size, operand, layout)
+
+     if operand == "a":
+         tensor_coord = problem_size.mk()
+     elif operand == "b":
+         tensor_coord = problem_size.kn()
+     elif operand in ["c", "d"]:
+         tensor_coord = problem_size.mn()
+     else:
+         raise ValueError("unknonw operand: " + operand)
+
+     if layout == cutlass.RowMajor:
+         layout_tag = "RowMajor"
+     elif layout == cutlass.ColumnMajor:
+         layout_tag = "ColumnMajor"
+     elif layout == cutlass.ColumnMajorInterleaved32:
+         layout_tag = "ColumnMajorInterleaved32"
+     elif layout == cutlass.RowMajorInterleaved32:
+         layout_tag = "RowMajorInterleaved32"
+     else:
+         raise ValueError("unsupported layout")
+     if tensor.dtype == np.float32:
+         ref_name = "TensorViewF32" + layout_tag
+     elif tensor.dtype == np.float64:
+         ref_name = "TensorViewF64" + layout_tag
+     elif tensor.dtype == np.float16:
+         ref_name = "TensorViewF16" + layout_tag
+     elif tensor.dtype == bfloat16:
+         ref_name = "TensorViewBF16" + layout_tag
+     elif tensor.dtype == np.int32:
+         ref_name = "TensorViewS32" + layout_tag
+     elif tensor.dtype == np.int8:
+         ref_name = "TensorViewS8" + layout_tag
+     else:
+         raise ValueError("unsupported datatype")
+
+     return getattr(cutlass, ref_name)(tensor_ref, tensor_coord)
+
+
+ class GemmUniversalLauncher:
+     def __init__(self, operation: 'GemmOperationUniversal', seed: int = 2080, interleaved=False,
+                  verification=True, profiling=False, warmup_iterations=500, iterations=500, **kwargs) -> None:
+         # create the reduction kernel
+         self.reduction_operation: ReductionOperation = ReductionOperation(
+             shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment),
+             C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator,
+             element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor,
+             count=operation.C.alignment
+         )
+
+         self.math_operation = operation.tile_description.math_instruction.math_operation
+
+         #: verify the output result
+         self.verification = verification
+         #: profile the kernel's runtime
+         self.profiling = profiling
+
+         self.timer = GpuTimer()
+
+         self.warmup_iterations = warmup_iterations
+         self.iterations = iterations
+
+         if "sleep" in kwargs.keys():
+             self.sleep_time = kwargs["sleep"]
+         else:
+             self.sleep_time = 0
+
+         #
+         # Compile the operator
+         #
+
+         pycutlass.compiler.add_module([operation, self.reduction_operation])
+
+         self.operation = operation
+
+         self.dtype_A = GemmUniversalLauncher.numpy_type(operation.A.element)
+         self.dtype_B = GemmUniversalLauncher.numpy_type(operation.B.element)
+         self.dtype_C = GemmUniversalLauncher.numpy_type(operation.C.element)
+         self.dtype_D = GemmUniversalLauncher.numpy_type(operation.C.element)
+
+         accumulator_size = DataTypeSize[operation.tile_description.math_instruction.element_accumulator]
+         element_size = DataTypeSize[operation.A.element]
+
+         if element_size == 1:
+             self.scope_max = 1
+             self.scope_min = 0
+         elif element_size <= 8:
+             self.scope_max = 1
+             self.scope_min = -1
+         elif element_size == 16:
+             self.scope_max = 4
+             self.scope_min = -4
+         else:
+             self.scope_max = 8
+             self.scope_min = -8
+
+         #: seed
+         self.seed: int = seed
+
+         #: whether the layout is interleaved
+         self.interleaved = interleaved
+
+         #: compute type
+         self.compute_type = operation.epilogue_functor.element_epilogue
+         self.accumulator_type = operation.tile_description.math_instruction.element_accumulator
+
+     def print_problem_size(self, p, mode, batch_count):
+         if mode == cutlass.gemm.Mode.Gemm:
+             mode = "Gemm"
+         elif mode == cutlass.gemm.Mode.GemmSplitKParallel:
+             mode = "GemmSplitKParalel"
+         problem_size = "problem: %d, %d, %d\n batch_count: %d\n mode: %s" % (
+             p.m(), p.n(), p.k(), batch_count, mode)
+         print(problem_size)
+
+     @staticmethod
+     def numpy_type(type):
+         if type == cutlass.float64:
+             return np.float64
+         elif type == cutlass.float32:
+             return np.float32
+         elif type == cutlass.float16:
+             return np.float16
+         elif type == cutlass.bfloat16:
+             return bfloat16
+         elif type == cutlass.int32:
+             return np.int32
+         elif type == cutlass.int8:
+             return np.int8
+         else:
+             raise ValueError("unsupported type: %s" % ShortDataTypeNames[type])
+
+     def uniform_init(self, size, dtype):
+         if dtype in [np.float32, np.float16, bfloat16, np.float64]:
+             return np.ceil(
+                 np.random.uniform(
+                     low=self.scope_min - 0.5, high=self.scope_max - 0.5,
+                     size=size).astype(dtype)
+             )
+         else:
+             return np.random.uniform(
+                 low=self.scope_min - 1, high=self.scope_max + 1,
+                 size=size).astype(dtype)
+
+     def reorder_tensor_B(self, tensor_B, problem_size):
+         reordered_tensor_B = np.empty_like(tensor_B)
+         tensor_ref_B = getTensorRef(
+             tensor_B, problem_size, "b", self.operation.B.layout)
+         reordered_tensor_ref_B = getTensorRef(
+             reordered_tensor_B, problem_size, "b", self.operation.B.layout)
+         cutlass.gemm.host.reorder_column(
+             tensor_ref_B, reordered_tensor_ref_B, problem_size)
+         return reordered_tensor_B
+
+     def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
+         # TODO
+         tensor_D_ref = np.ones_like(tensor_C)
+         alpha = self.numpy_type(self.compute_type)(alpha)
+         beta = self.numpy_type(self.compute_type)(beta)
+         init_acc = 0
+
+         alpha = self.compute_type(alpha).value()
+         beta = self.compute_type(beta).value()
+         init_acc = self.accumulator_type(init_acc).value()
+
+         if self.operation.switched:
+             tensor_ref_A = getTensorRef(
+                 tensor_A, problem_size, "a", transpose(self.operation.B.layout))
+             tensor_ref_B = getTensorRef(
+                 tensor_B, problem_size, "b", transpose(self.operation.A.layout))
+             tensor_ref_C = getTensorRef(
+                 tensor_C, problem_size, "c", transpose(self.operation.C.layout))
+             tensor_ref_D_ref = getTensorRef(
+                 tensor_D_ref, problem_size, "d", transpose(self.operation.C.layout))
+         else:
+             tensor_ref_A = getTensorRef(
+                 tensor_A, problem_size, "a", self.operation.A.layout)
+             tensor_ref_B = getTensorRef(
+                 tensor_B, problem_size, "b", self.operation.B.layout)
+             tensor_ref_C = getTensorRef(
+                 tensor_C, problem_size, "c", self.operation.C.layout)
+             tensor_ref_D_ref = getTensorRef(
+                 tensor_D_ref, problem_size, "d", self.operation.C.layout)
+
+         if self.math_operation in [MathOperation.multiply_add_saturate]:
+             cutlass.test.gemm.host.gemm_saturate(
+                 problem_size, alpha, tensor_ref_A, tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc)
+         else:
+             cutlass.test.gemm.host.gemm(problem_size, alpha, tensor_ref_A,
+                                         tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc)
+
+         return tensor_D_ref
+
+     def equal(self, tensor_D, tensor_D_ref, problem_size):
+
+         tensor_view_D = getTensorView(
+             tensor_D, problem_size, "d", self.operation.C.layout)
+         tensor_view_D_ref = getTensorView(
+             tensor_D_ref, problem_size, "d", self.operation.C.layout)
+
+         return cutlass.test.gemm.host.equals(tensor_view_D, tensor_view_D_ref)
+
+     def bytes(self, problem_size, batch_count=1, alpha=1.0, beta=0.0):
+         m = problem_size.m()
+         n = problem_size.n()
+         k = problem_size.k()
+
+         bytes = \
+             (DataTypeSize[self.operation.A.element] * m // 8) * k + \
+             (DataTypeSize[self.operation.B.element] * n // 8) * k + \
+             (DataTypeSize[self.operation.C.element] * m // 8) * n
+
+         if beta != 0:
+             bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n
+
+         bytes *= batch_count
+
+         return bytes
+
+     def flops(self, problem_size, batch_count=1):
+         m = problem_size.m()
+         n = problem_size.n()
+         k = problem_size.k()
+
+         flops_ = (m * n * k + m * n) * 2 * batch_count
+
+         # TODO: complex
+         return flops_
+
+     def run_cutlass_profiler(self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0):
+
+         cutlass_path = os.getenv('CUTLASS_PATH')
+         assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined."
+
+         values = {
+             "profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler",
+             "kernel_name": self.operation.procedural_name(),
+             "verification_providers": "device",
+             "provider": "cutlass",
+             "m": str(problem_size.m()),
+             "n": str(problem_size.n()),
+             "k": str(problem_size.k()),
+             'split_k_slices': str(batch_count),
+             'alpha': str(alpha),
+             'beta': str(beta),
+             'warmup': str(self.warmup_iterations),
+             'profile': str(self.iterations)
+         }
+
+         cmd_template = \
+             "${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}" \
+             " --providers=${provider} --m=${m} --n=${n} --k=${k}"
+
+         cmd = SubstituteTemplate(cmd_template, values)
+         result = subprocess.getoutput(cmd)
+
+         m = re.search(r"Runtime:\s+(?P<runtime>\d+.\d+)", result)
+         runtime = float(m.group('runtime'))
+
+         m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
+         bytes = int(m.group('bytes'))
+
+         m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
+         flops = int(m.group('flops'))
+
+         # check if the problem size matches
+         assert bytes == self.bytes(problem_size, alpha, beta)
+         assert flops == self.flops(problem_size)
+
+         return runtime
+
+     def run(self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0):
+
+         assert get_allocated_size(
+         ) == 0, "%d byte of pool memory is not released in previous run" % get_allocated_size()
+
+         np.random.seed(self.seed)
+
+         tensor_A = self.uniform_init(
+             size=(problem_size.m() * problem_size.k(),), dtype=self.dtype_A)
+         tensor_B = self.uniform_init(
+             size=(problem_size.n() * problem_size.k(),), dtype=self.dtype_B)
+         tensor_C = self.uniform_init(
+             size=(problem_size.m() * problem_size.n(),), dtype=self.dtype_C)
+         tensor_D = np.zeros(
+             shape=(problem_size.m() * problem_size.n(),), dtype=self.dtype_D)
+
+         #
+         # Launch kernel
+         #
+
+         arguments = GemmArguments(
+             operation=self.operation, problem_size=problem_size,
+             A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
+             output_op=self.operation.epilogue_type(alpha, beta),
+             gemm_mode=mode, split_k_slices=batch_count
+         )
+
+         if mode == cutlass.gemm.Mode.GemmSplitKParallel:
+             reduction_arguments = ReductionArguments(
+                 self.reduction_operation, problem_size=[
+                     problem_size.m(), problem_size.n()],
+                 partitions=batch_count,
+                 workspace=arguments.ptr_D,
+                 destination=tensor_D,
+                 source=tensor_C,
+                 output_op=self.reduction_operation.epilogue_type(alpha, beta)
+             )
+
+         self.operation.run(arguments)
+
+         if mode == cutlass.gemm.Mode.GemmSplitKParallel:
+             self.reduction_operation.run(reduction_arguments)
+
+         passed = True
+
+         if self.verification:
+             if mode == cutlass.gemm.Mode.GemmSplitKParallel:
+                 reduction_arguments.sync()
+             else:
+                 arguments.sync()
+             tensor_D_ref = self.host_reference(
+                 problem_size, tensor_A, tensor_B, tensor_C, alpha, beta)
+             passed = self.equal(tensor_D, tensor_D_ref, problem_size)
+
+             try:
+                 assert passed
+             except AssertionError:
+                 self.print_problem_size(problem_size, mode, batch_count)
+
+         if self.profiling:
+             sleep(self.sleep_time)
+             for _ in range(self.warmup_iterations):
+                 self.operation.run(arguments)
+                 if mode == cutlass.gemm.Mode.GemmSplitKParallel:
+                     self.reduction_operation.run(reduction_arguments)
+
+             self.timer.start()
+             for _ in range(self.iterations):
+                 self.operation.run(arguments)
+                 if mode == cutlass.gemm.Mode.GemmSplitKParallel:
+                     self.reduction_operation.run(reduction_arguments)
+             self.timer.stop_and_wait()
+
+             runtime = self.timer.duration(self.iterations)
+
+         # free memory and clear buffers
+         del arguments
+         if mode == cutlass.gemm.Mode.GemmSplitKParallel:
+             del reduction_arguments
+
+         assert get_allocated_size(
+         ) == 0, "%d byte of pool memory is not released after current run" % get_allocated_size()
+
+         if self.profiling:
+             return runtime
+         return passed
+
+
+ def test_all_gemm(operation: 'GemmOperationUniversal', testcase="universal"):
+
+     passed = True
+
+     minimum_operand_element_size = min(
+         DataTypeSize[operation.A.element], DataTypeSize[operation.B.element])
+     opcode_class = operation.tile_description.math_instruction.opcode_class
+
+     if opcode_class == cutlass.OpClass.Simt:
+         alignment = 1
+     else:
+         alignment = 128 // minimum_operand_element_size
+
+     # int8_t gemm alignment constrainst
+     if opcode_class == cutlass.OpClass.Simt and operation.A.element == cutlass.int8 and operation.A.layout == cutlass.ColumnMajor:
+         alignment_m = 4
+     else:
+         alignment_m = alignment
+
+     if opcode_class == cutlass.OpClass.Simt and operation.B.element == cutlass.int8 and operation.A.layout == cutlass.RowMajor:
+         alignment_n = 4
+     else:
+         alignment_n = alignment
+
+     if opcode_class == cutlass.OpClass.Simt and operation.A.element == cutlass.int8 \
+             and operation.B.element == cutlass.int8 \
+             and (operation.A.layout == cutlass.RowMajor or operation.B.layout == cutlass.ColumnMajor):
+
+         alignment_k = 4
+     else:
+         alignment_k = alignment
+
+     threadblock_k = operation.tile_description.threadblock_shape[2]
+
+     if testcase == "interleaved":
+         if operation.A.layout in [cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32]:
+             interleavedk = 32
+         else:
+             raise ValueError("unknonw layout")
+
+     if testcase == "interleaved":
+         modes = [cutlass.gemm.Mode.Gemm, ]
+         problem_size_m = [interleavedk, 512+interleavedk]
+         problem_size_n = [interleavedk, 512+interleavedk]
+         problem_size_k = [interleavedk, threadblock_k *
+                           operation.tile_description.stages + interleavedk]
+         problem_alpha = [1.0]
+         problem_beta = [0.0]
+         batch_counts = [1, ]
+     elif testcase == "multistage":
+         modes = [cutlass.gemm.Mode.Gemm, ]
+         problem_size_m = [16, 528]
+         problem_size_n = [16, 528]
+         problem_size_k = [threadblock_k, threadblock_k * operation.tile_description.stages +
+                           operation.tile_description.math_instruction.instruction_shape[2]]
+         problem_alpha = [1.0]
+         problem_beta = [0.0]
+         batch_counts = [1, ]
+     else: # universal
+         modes = [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel]
+         problem_size_m = [alignment_m, 512 - 3 * alignment_m]
+         problem_size_n = [alignment_n, 512 - 2 * alignment_n]
+         problem_size_k = [
+             alignment_k,
+             threadblock_k * operation.tile_description.stages - alignment_k,
+             threadblock_k * operation.tile_description.stages * 3 - alignment_k]
+         batch_counts = [1, 2, 3, 5, 7]
+         problem_alpha = [1.0]
+         problem_beta = [2.0]
+
+     testbed = GemmUniversalLauncher(
+         operation, interleaved=(testcase == "interleaved"))
+
+     for mode in modes:
+         for m in problem_size_m:
+             for n in problem_size_n:
+                 for k in problem_size_k:
+                     for batch_count in batch_counts:
+                         for alpha in problem_alpha:
+                             for beta in problem_beta:
+                                 # skip very small K problems
+                                 if testcase == "universal":
+                                     if (k // batch_count < 2 * threadblock_k):
+                                         continue
+
+                                 problem_size = cutlass.gemm.GemmCoord(m, n, k)
+
+                                 passed = testbed.run(
+                                     mode, problem_size, batch_count, alpha, beta)
+
+                                 err, = cudart.cudaDeviceSynchronize()
+                                 if err != cuda.CUresult.CUDA_SUCCESS:
+                                     raise RuntimeError(
+                                         "CUDA Error %s" % str(err))
+
+                                 if not passed:
+                                     return False
+
+     return passed

warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py
@@ -0,0 +1,70 @@
+ #################################################################################################
+ #
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+ #
+ # 1. Redistributions of source code must retain the above copyright notice, this
+ # list of conditions and the following disclaimer.
+ #
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
+ # this list of conditions and the following disclaimer in the documentation
+ # and/or other materials provided with the distribution.
+ #
+ # 3. Neither the name of the copyright holder nor the names of its
+ # contributors may be used to endorse or promote products derived from
+ # this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ #
+ #################################################################################################
+
+ from cuda import cuda
+ from cuda import cudart
+
+
+ class GpuTimer:
+     def __init__(self) -> None:
+         self.events = [
+             cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
+             cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
+         ]
+
+     def start(self, stream=cuda.CUstream(0)):
+         err, = cuda.cuEventRecord(self.events[0], stream)
+         if err != cuda.CUresult.CUDA_SUCCESS:
+             raise RuntimeError("CUDA Error %s" % str(err))
+
+     def stop(self, stream=cuda.CUstream(0)):
+         err, = cuda.cuEventRecord(self.events[1], stream)
+         if err != cuda.CUresult.CUDA_SUCCESS:
+             raise RuntimeError("CUDA Error %s" % str(err))
+         pass
+
+     def stop_and_wait(self, stream=cuda.CUstream(0)):
+         self.stop(stream)
+         if stream:
+             err, = cuda.cuStreamSynchronize(stream)
+             if err != cuda.CUresult.CUDA_SUCCESS:
+                 raise RuntimeError("CUDA Error %s" % str(err))
+         else:
+             err, = cudart.cudaDeviceSynchronize()
+             if err != cuda.CUresult.CUDA_SUCCESS:
+                 raise RuntimeError("CUDA Error %s" % str(err))
+
+     def duration(self, iterations=1):
+         err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1])
+         if err != cuda.CUresult.CUDA_SUCCESS:
+             raise RuntimeError("CUDA Error %s" % str(err))
+         return duration / float(iterations)

warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py
@@ -0,0 +1,39 @@
+ ################################################################################
+ #
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+ #
+ # 1. Redistributions of source code must retain the above copyright notice, this
+ # list of conditions and the following disclaimer.
+ #
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
+ # this list of conditions and the following disclaimer in the documentation
+ # and/or other materials provided with the distribution.
+ #
+ # 3. Neither the name of the copyright holder nor the names of its
+ # contributors may be used to endorse or promote products derived from
+ # this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ #
+ ################################################################################
+
+ from typing import Union
+ from typeguard import typechecked
+
+
+ GemmOperation = 'Union[GemmOperationUniversal, GemmOperationGrouped]'
+
+ Tensor = 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]'

warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py
@@ -0,0 +1 @@
+ from pycutlass.utils.reference_model import *