warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,453 @@
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ ## Test case generator for SM80
34
+
35
+ import pycutlass
36
+ from pycutlass import *
37
+ from pycutlass.test import *
38
+ from pycutlass.utils.device import device_cc
39
+ import unittest
40
+
41
+ #
42
+ # Create GEMM operation
43
+ #
44
+
45
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixed=False,
                     epilogue_functor=None, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
    """
    Build a GEMM operation from a configuration and run its test harness.

    When device_cc() < 80, unittest.skipIf replaces this function with a wrapper that
    raises unittest.SkipTest on every call, which skips whichever test calls it.

    Args:
        gemm_kind: GemmKind.Universal or GemmKind.Grouped. Any other value raises
            NotImplementedError.
        math_inst: MathInstruction supplying the A/B/accumulator element types.
        layout: (A, B, C) tensor layouts.
        alignment: (A, B, C) alignments.
        tiling: sequence whose first three items are passed positionally to
            TileDescription (threadblock shape, stages, warp count).
        arch: target compute capability, forwarded to the operation.
        mixed: when True (or when A is bfloat16), C uses the accumulator element type
            instead of A's element type. Ignored when a "data_type" kwarg is given.
        epilogue_functor: epilogue to use. When None, a LinearCombination over C's
            element/alignment, the accumulator type and data_type[3] is built.
        swizzling_functor: threadblock swizzle passed to the operation.
        **kwargs: "data_type" -> [A, B, C, epilogue] element types overriding the
            derived ones. "precompute_mode" is required for grouped GEMMs.

    Returns:
        test_all_gemm(...) for universal GEMMs ("interleaved" mode when A uses a
        32-interleaved layout, "universal" otherwise), or TestbedGrouped.run(24) for
        grouped GEMMs.
    """
    if "data_type" in kwargs:
        data_type = kwargs["data_type"]
    else:
        # C is written in the accumulator type for mixed / bfloat16 runs, else in A's type.
        if mixed or math_inst.element_a == cutlass.bfloat16:
            element_c = math_inst.element_accumulator
        else:
            element_c = math_inst.element_a
        data_type = [math_inst.element_a, math_inst.element_b, element_c,
                     math_inst.element_accumulator]

    tile_description = TileDescription(tiling[0], tiling[1], tiling[2], math_inst)

    # One tensor description per operand, built in A, B, C order.
    A, B, C = [TensorDescription(data_type[i], layout[i], alignment[i]) for i in range(3)]

    element_epilogue = data_type[3]
    if epilogue_functor is None:
        epilogue_functor = LinearCombination(
            C.element, C.alignment, math_inst.element_accumulator, element_epilogue)

    if gemm_kind == GemmKind.Universal:
        operation = GemmOperationUniversal(
            arch=arch, tile_description=tile_description, A=A, B=B, C=C,
            epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor)
        interleaved_layouts = [cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32]
        mode = "interleaved" if A.layout in interleaved_layouts else "universal"
        return test_all_gemm(operation, mode)

    if gemm_kind == GemmKind.Grouped:
        operation = GemmOperationGrouped(
            arch, tile_description, A, B, C, epilogue_functor, swizzling_functor,
            precompute_mode=kwargs["precompute_mode"])
        return TestbedGrouped(operation=operation).run(24)

    raise NotImplementedError("the gemm kind is not implemented")
114
+
115
+
116
def TestConv2dOperator(math_inst, alignment, tiling, arch,
                       stride_supports=(StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided),
                       epilogue_functor=None,
                       swizzling_functor=cutlass.IdentitySwizzle1, interleaved=False, **kwargs):
    """
    Test Conv2d operations (fprop, dgrad, wgrad) built from one configuration.

    Args:
        math_inst: MathInstruction supplying the A/B/accumulator element types.
        alignment: (A, B, C) alignments.
        tiling: (threadblock_shape, stages, warp_count) for TileDescription.
        arch: target compute capability, forwarded to Conv2dOperation.
        stride_supports: StrideSupport per conv kind, in fprop/dgrad/wgrad order. Kinds
            beyond len(stride_supports) are not run, because the lists are zipped.
        epilogue_functor: epilogue applied to every conv kind. When None, a
            LinearCombination is built per kind from that kind's C element/alignment,
            the accumulator type and data_type[3].
        swizzling_functor: swizzle used for every kind except strided dgrad, which uses
            cutlass.StridedDgradIdentitySwizzle1.
        interleaved: forwarded to test_all_conv2d.
        **kwargs: "layout" -> (A, B, C) tensor layouts (default NHWC for all three).
            "data_type" -> [A, B, C, epilogue] element types overriding the derived
            ones. By default only dgrad writes C in the accumulator type (mixed); the
            other kinds use A's type unless A is bfloat16.

    Returns:
        List of test_all_conv2d results, one per conv kind that was run. dgrad and
        wgrad are skipped when A is int8.
    """
    conv_kinds = (cutlass.conv.Operator.fprop, cutlass.conv.Operator.dgrad, cutlass.conv.Operator.wgrad)
    mixeds = (False, True, False)
    backward_kinds = (cutlass.conv.Operator.dgrad, cutlass.conv.Operator.wgrad)

    if "layout" in kwargs:
        layout = kwargs["layout"]
    else:
        layout = (cutlass.TensorNHWC, cutlass.TensorNHWC, cutlass.TensorNHWC)

    results = []
    for mixed, conv_kind, stride_support in zip(mixeds, conv_kinds, stride_supports):
        if "data_type" in kwargs:
            data_type = kwargs["data_type"]
        elif mixed or math_inst.element_a == cutlass.bfloat16:
            data_type = [math_inst.element_a, math_inst.element_b,
                         math_inst.element_accumulator, math_inst.element_accumulator]
        else:
            data_type = [math_inst.element_a, math_inst.element_b,
                         math_inst.element_a, math_inst.element_accumulator]

        # Int8 backward (dgrad / wgrad) convolutions are not tested.
        if data_type[0] == cutlass.int8 and conv_kind in backward_kinds:
            continue

        A = TensorDescription(element=data_type[0], layout=layout[0], alignment=alignment[0])
        B = TensorDescription(element=data_type[1], layout=layout[1], alignment=alignment[1])
        C = TensorDescription(element=data_type[2], layout=layout[2], alignment=alignment[2])

        tile_description = TileDescription(
            threadblock_shape=tiling[0], stages=tiling[1],
            warp_count=tiling[2],
            math_instruction=math_inst
        )

        # Strided dgrad always uses the dedicated strided-dgrad swizzle.
        if conv_kind == cutlass.conv.Operator.dgrad and stride_support == StrideSupport.Strided:
            conv_swizzling_functor = cutlass.StridedDgradIdentitySwizzle1
        else:
            conv_swizzling_functor = swizzling_functor

        if epilogue_functor is None:
            conv_epilogue_functor = LinearCombination(
                C.element, C.alignment,
                math_inst.element_accumulator, data_type[3])
        else:
            # Previously this branch left the local unbound, so passing an explicit
            # epilogue functor raised UnboundLocalError below.
            conv_epilogue_functor = epilogue_functor

        operation = Conv2dOperation(
            conv_kind=conv_kind, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=arch, tile_description=tile_description, A=A, B=B, C=C,
            stride_support=stride_support,
            epilogue_functor=conv_epilogue_functor,
            swizzling_functor=conv_swizzling_functor
        )

        results.append(test_all_conv2d(operation, interleaved=interleaved))

    return results
199
+
200
+
201
+
202
+ class Test_SM80(unittest.TestCase):
203
def test_SM80_TensorOp_16816(self):
    """Universal/grouped GEMM and Conv2d with 16x8x16 tensor-op multiply-add instructions.

    The instruction, layout, alignment and tiling lists are walked in lockstep. Each
    configuration runs one universal GEMM, one grouped GEMM with host-side precompute
    and the Conv2d tests with strided stride support for every kind.
    """
    # (A/B element type, accumulator type) of each instruction, in test order.
    operand_and_accumulator = [
        (cutlass.float16, cutlass.float32),
        (cutlass.float16, cutlass.float16),
        (cutlass.bfloat16, cutlass.float32),
    ]
    math_instructions = [
        MathInstruction(
            [16, 8, 16], operand, operand, accumulator,
            cutlass.OpClass.TensorOp, MathOperation.multiply_add)
        for operand, accumulator in operand_and_accumulator
    ]
    layouts = [
        (cutlass.RowMajor, cutlass.RowMajor, cutlass.RowMajor),
        (cutlass.ColumnMajor, cutlass.RowMajor, cutlass.RowMajor),
        (cutlass.RowMajor, cutlass.ColumnMajor, cutlass.RowMajor),
    ]
    alignments = [(8, 8, 8), (4, 8, 8), (8, 4, 8)]
    tilings = [
        ([256, 128, 32], 3, [4, 2, 1]),
        ([64, 256, 32], 4, [1, 4, 1]),
        ([128, 64, 64], 3, [2, 2, 1]),
    ]

    for config in zip(math_instructions, layouts, alignments, tilings):
        math_inst, layout, alignment, tiling = config
        universal_ok = TestGemmOperator(
            GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False)
        self.assertTrue(universal_ok)
        grouped_ok = TestGemmOperator(
            GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True,
            precompute_mode=SchedulerMode.Host)
        self.assertTrue(grouped_ok)
        conv_results = TestConv2dOperator(
            math_inst, alignment, tiling, 80,
            stride_supports=[StrideSupport.Strided] * 3)
        for conv_ok in conv_results:
            self.assertTrue(conv_ok)
242
+
243
def test_SM80_TensorOp_1688(self):
    """tf32 16x8x8 tensor-op test, reported as skipped.

    tf32 is not supported by most Python environments, so this test is intentionally
    not run. It calls skipTest so the runner reports a skip, rather than the previous
    vacuous assertTrue(True) that reported a pass without testing anything.
    """
    self.skipTest("tf32 is not supported by most Python environments")
246
+
247
def test_SM80_TensorOp_1688_fast_math(self):
    """Fast-math 16x8x8 tensor-op GEMM/Conv2d variants, all with float32 operands.

    Each instruction is paired in lockstep with a layout, alignment and tiling. Every
    configuration runs one universal GEMM, one grouped GEMM with device-side
    precompute, and the Conv2d tests (unity stride for fprop/wgrad, strided for dgrad).
    """
    # (A/B element type, math operation) per instruction; the accumulator is float32.
    instruction_specs = [
        (cutlass.tfloat32, MathOperation.multiply_add),
        (cutlass.float16, MathOperation.multiply_add_fast_f16),
        (cutlass.bfloat16, MathOperation.multiply_add_fast_bf16),
        (cutlass.float32, MathOperation.multiply_add_fast_f32),
    ]
    math_instructions = [
        MathInstruction(
            [16, 8, 8], operand, operand, cutlass.float32,
            cutlass.OpClass.TensorOp, math_operation)
        for operand, math_operation in instruction_specs
    ]
    layouts = [
        (cutlass.RowMajor, cutlass.RowMajor, cutlass.ColumnMajor),
        (cutlass.RowMajor, cutlass.ColumnMajor, cutlass.ColumnMajor),
        (cutlass.ColumnMajor, cutlass.RowMajor, cutlass.ColumnMajor),
        (cutlass.ColumnMajor, cutlass.ColumnMajor, cutlass.RowMajor),
    ]
    alignments = [(4, 4, 4), (4, 2, 4), (2, 4, 4), (2, 2, 4)]
    tilings = [
        ([128, 256, 16], 3, [4, 2, 1]),
        ([64, 256, 16], 4, [1, 4, 1]),
        ([128, 64, 32], 3, [2, 2, 1]),
        ([256, 64, 32], 3, [4, 2, 1]),
    ]
    # One shared list of A, B, C and epilogue element types.
    data_type = [cutlass.float32] * 4

    for config in zip(math_instructions, layouts, alignments, tilings):
        math_inst, layout, alignment, tiling = config
        universal_ok = TestGemmOperator(
            GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False,
            data_type=data_type)
        self.assertTrue(universal_ok)
        grouped_ok = TestGemmOperator(
            GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True,
            precompute_mode=SchedulerMode.Device, data_type=data_type)
        self.assertTrue(grouped_ok)
        conv_results = TestConv2dOperator(
            math_inst, alignment, tiling, 80,
            stride_supports=[StrideSupport.Unity, StrideSupport.Strided, StrideSupport.Unity],
            data_type=data_type)
        for conv_ok in conv_results:
            self.assertTrue(conv_ok)
298
+
299
def test_SM80_TensorOp_884(self):
    """float64 8x8x4 tensor-op GEMM (universal, grouped) and Conv2d, column-major.

    The grouped GEMM uses device-side precompute. Conv2d uses unity stride for
    fprop/wgrad and strided dgrad.
    """
    f64 = cutlass.float64
    math_inst = MathInstruction(
        [8, 8, 4], f64, f64, f64,
        cutlass.OpClass.TensorOp, MathOperation.multiply_add)
    layout = (cutlass.ColumnMajor, cutlass.ColumnMajor, cutlass.ColumnMajor)
    alignment = (1, 1, 1)
    tiling = ([64, 256, 16], 3, [2, 4, 1])
    data_type = [f64] * 4

    # Positional arguments shared by both GEMM runs (after gemm_kind).
    gemm_args = (math_inst, layout, alignment, tiling, 80)
    self.assertTrue(TestGemmOperator(
        GemmKind.Universal, *gemm_args, False, data_type=data_type))
    self.assertTrue(TestGemmOperator(
        GemmKind.Grouped, *gemm_args, True,
        precompute_mode=SchedulerMode.Device, data_type=data_type))

    conv_results = TestConv2dOperator(
        math_inst, alignment, tiling, 80,
        stride_supports=[StrideSupport.Unity, StrideSupport.Strided, StrideSupport.Unity],
        data_type=data_type)
    for conv_ok in conv_results:
        self.assertTrue(conv_ok)
315
+
316
def test_SM80_TensorOp_16832_TN(self):
    """int8 16x8x32 saturating tensor-op GEMM and Conv2d with a TN (row/column) layout.

    The universal GEMM and the Conv2d tests write int32 C. The grouped GEMM (device-side
    precompute) writes int8 C with a float32 epilogue and wider C alignment.
    """
    i8, i32 = cutlass.int8, cutlass.int32
    math_inst = MathInstruction(
        [16, 8, 32], i8, i8, i32,
        cutlass.OpClass.TensorOp, MathOperation.multiply_add_saturate)
    layout = (cutlass.RowMajor, cutlass.ColumnMajor, cutlass.ColumnMajor)
    tiling = ([128, 256, 64], 3, [2, 4, 1])

    alignment, data_type = (16, 16, 4), [i8, i8, i32, i32]
    alignment_mixed, data_type_mixed = (16, 16, 16), [i8, i8, i8, cutlass.float32]

    universal_ok = TestGemmOperator(
        GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False,
        data_type=data_type)
    self.assertTrue(universal_ok)
    grouped_ok = TestGemmOperator(
        GemmKind.Grouped, math_inst, layout, alignment_mixed, tiling, 80, True,
        precompute_mode=SchedulerMode.Device, data_type=data_type_mixed)
    self.assertTrue(grouped_ok)

    conv_results = TestConv2dOperator(
        math_inst, alignment, tiling, 80,
        stride_supports=[StrideSupport.Strided] * 3,
        data_type=data_type)
    for conv_ok in conv_results:
        self.assertTrue(conv_ok)
335
+
336
def test_SM80_Simt_f32(self):
    """float32 SIMT GEMM (universal, grouped) and Conv2d, row-major operands.

    The grouped GEMM uses host-side precompute. All Conv2d kinds use strided stride
    support.
    """
    f32 = cutlass.float32
    math_inst = MathInstruction(
        [1, 1, 1], f32, f32, f32,
        cutlass.OpClass.Simt, MathOperation.multiply_add)
    layout = (cutlass.RowMajor, cutlass.RowMajor, cutlass.RowMajor)
    alignment = (1, 1, 1)
    tiling = ([128, 256, 8], 4, [2, 4, 1])
    data_type = [f32] * 4

    # Positional arguments shared by both GEMM runs (after gemm_kind).
    gemm_args = (math_inst, layout, alignment, tiling, 80)
    self.assertTrue(TestGemmOperator(
        GemmKind.Universal, *gemm_args, False, data_type=data_type))
    self.assertTrue(TestGemmOperator(
        GemmKind.Grouped, *gemm_args, True,
        precompute_mode=SchedulerMode.Host, data_type=data_type))

    conv_results = TestConv2dOperator(
        math_inst, alignment, tiling, 80,
        stride_supports=[StrideSupport.Strided] * 3,
        data_type=data_type)
    for conv_ok in conv_results:
        self.assertTrue(conv_ok)
352
+
353
def test_SM80_Simt_f64(self):
    """Exercise the float64 SIMT [1, 1, 1] instruction on SM80 with a
    (RowMajor, RowMajor, ColumnMajor) layout and alignment 1.

    Runs a universal GEMM, then a grouped GEMM with ``SchedulerMode.Device``,
    then the strided Conv2d cases. All three share one float64 element list,
    and every reported outcome must be truthy.
    """
    instruction = MathInstruction(
        [1, 1, 1],
        cutlass.float64,
        cutlass.float64,
        cutlass.float64,
        cutlass.OpClass.Simt,
        MathOperation.multiply_add,
    )
    gemm_layout = (cutlass.RowMajor, cutlass.RowMajor, cutlass.ColumnMajor)
    unit_alignment = (1, 1, 1)
    # NOTE(review): presumably (threadblock shape, stages, warp count) -- confirm.
    tile_config = ([64, 128, 8], 5, [2, 2, 1])
    # One list object, passed to all three operator checks below.
    f64_types = [cutlass.float64] * 4

    universal_ok = TestGemmOperator(
        GemmKind.Universal, instruction, gemm_layout, unit_alignment,
        tile_config, 80, False, data_type=f64_types,
    )
    self.assertTrue(universal_ok)

    grouped_ok = TestGemmOperator(
        GemmKind.Grouped, instruction, gemm_layout, unit_alignment,
        tile_config, 80, True,
        precompute_mode=SchedulerMode.Device, data_type=f64_types,
    )
    self.assertTrue(grouped_ok)

    strided = [StrideSupport.Strided] * 3
    for conv_ok in TestConv2dOperator(
        instruction, unit_alignment, tile_config, 80,
        stride_supports=strided, data_type=f64_types,
    ):
        self.assertTrue(conv_ok)
369
+
370
def test_SM80_TensorOp_16832_Interleaved(self):
    """Exercise the int8 TensorOp [16, 8, 32] instruction on SM80 with
    32-interleaved layouts.

    Runs a universal GEMM that uses a FastLinearCombinationClamp epilogue,
    then the strided Conv2d cases with the NC32HW32 / C32RSK32 tensor
    layouts and ``interleaved=True``. Every reported outcome must be truthy.
    """
    instruction = MathInstruction(
        [16, 8, 32],
        cutlass.int8,
        cutlass.int8,
        cutlass.int32,
        cutlass.OpClass.TensorOp,
        MathOperation.multiply_add_saturate,
    )

    # GEMM and Conv2d use different layout descriptors; keep them in separate
    # names rather than rebinding one variable between the two runs.
    gemm_layout = (
        cutlass.ColumnMajorInterleaved32,
        cutlass.RowMajorInterleaved32,
        cutlass.ColumnMajorInterleaved32,
    )
    alignment = (16, 16, 8)
    # NOTE(review): presumably (threadblock shape, stages, warp count) -- confirm.
    tile_config = ([256, 64, 64], 4, [4, 1, 1])
    element_types = [cutlass.int8, cutlass.int8, cutlass.int8, cutlass.float32]

    # The epilogue is built from the third element type and the third alignment.
    clamp_epilogue = FastLinearCombinationClamp(element_types[2], alignment[2])

    universal_ok = TestGemmOperator(
        GemmKind.Universal, instruction, gemm_layout, alignment,
        tile_config, 80, False,
        data_type=element_types, epilogue_functor=clamp_epilogue,
    )
    self.assertTrue(universal_ok)

    strided = [StrideSupport.Strided] * 3
    conv_layout = [cutlass.TensorNC32HW32, cutlass.TensorC32RSK32, cutlass.TensorNC32HW32]
    for conv_ok in TestConv2dOperator(
        instruction, alignment, tile_config, 80,
        stride_supports=strided, data_type=element_types,
        layout=conv_layout, interleaved=True,
    ):
        self.assertTrue(conv_ok)
391
+
392
# Placeholders for SM80 kernel families that are not yet covered.
# None of these names starts with ``test``, so unittest's default loader
# (TestLoader.testMethodPrefix == "test") never collects or runs them.
# Each body is intentionally empty and returns None.
def SM80_SparseTensorOp_16832(self):
    ...

def SM80_PlanarComplexTensorOp_16816(self):
    ...

def SM80_SparseTensorOp_16816_fast_math(self):
    ...

def SM80_TensorOp_1688_complex(self):
    ...

def SM80_TensorOp_1688_fast_fp32_math_complex(self):
    ...

def SM80_TensorOp_1688_rank_k(self):
    ...

def SM80_TensorOp_1688_rank_k_complex(self):
    ...

def SM80_TensorOp_1688_trmm(self):
    ...

def SM80_TensorOp_1688_trmm_complex(self):
    ...

def SM80_TensorOp_1688_symm(self):
    ...

def SM80_TensorOp_1688_symm_complex(self):
    ...

def SM80_TensorOp_884_complex(self):
    ...

def SM80_TensorOp_884_complex_gaussian(self):
    ...

def SM80_TensorOp_884_rank_k(self):
    ...

def SM80_TensorOp_884_rank_k_complex(self):
    ...

def SM80_TensorOp_884_rank_k_complex_gaussian(self):
    ...

def SM80_TensorOp_884_trmm(self):
    ...

def SM80_TensorOp_884_trmm_complex(self):
    ...

def SM80_TensorOp_884_trmm_complex_gaussian(self):
    ...

def SM80_TensorOp_884_symm(self):
    ...

def SM80_TensorOp_884_symm_complex(self):
    ...

def SM80_TensorOp_884_symm_complex_gaussian(self):
    ...

def SM80_SparseTensorOp_16864_TN(self):
    ...

def SM80_TensorOp_16864_TN(self):
    ...

def SM80_SparseTensorOp_168128_TN(self):
    ...

def SM80_TensorOp_16864_Interleaved(self):
    ...

def SM80_TensorOp_168256(self):
    ...

def SM80_Simt_complex(self):
    ...
448
+
449
+
450
if __name__ == '__main__':
    # NOTE(review): the two arguments are presumably the initial and maximum
    # memory-pool sizes in bytes (1 MiB and 16 GiB); confirm against
    # pycutlass.get_memory_pool.
    initial_pool_size, max_pool_size = 1 << 20, 1 << 34
    pycutlass.get_memory_pool(initial_pool_size, max_pool_size)
    # Presumably selects nvcc as pycutlass's kernel compiler before any test
    # builds kernels -- confirm against pycutlass.compiler.
    pycutlass.compiler.nvcc()
    unittest.main()