warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang has been flagged as possibly problematic. See the registry's advisory page for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
# Ported from: test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu
import pycutlass
from pycutlass.conv2d_operation import *
from pycutlass import *
from pycutlass.test import *
from pycutlass.utils.device import device_cc
import unittest


@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
class Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase):
    """Wgrad conv2d tests for f32 NHWC operands on the SIMT path (SM80+).

    The analytic- and optimized-iterator variants were exact duplicates except
    for the iterator algorithm, so the shared setup lives in a private helper.
    """

    def _run_wgrad(self, iterator_algorithm):
        """Build a wgrad Conv2dOperation with *iterator_algorithm* and run test_all_conv2d.

        Asserts (via the enclosing TestCase) that the full conv2d test sweep passes.
        """
        math_inst = MathInstruction(
            instruction_shape=[1, 1, 1],
            element_a=cutlass.float32, element_b=cutlass.float32,
            element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt,
            math_operation=MathOperation.multiply_add
        )

        A = TensorDescription(
            element=math_inst.element_a,
            layout=cutlass.TensorNHWC,
            alignment=4)
        B = TensorDescription(
            element=math_inst.element_b,
            layout=cutlass.TensorNHWC,
            alignment=4)
        # Output alignment is 1 (vs 4 for the inputs), matching the original test.
        C = TensorDescription(
            element=cutlass.float32,
            layout=cutlass.TensorNHWC,
            alignment=1)

        tile_description = TileDescription(
            threadblock_shape=[128, 128, 8], stages=4,
            warp_count=[2, 4, 1],
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=iterator_algorithm,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            stride_support=StrideSupport.Strided,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )

        self.assertTrue(test_all_conv2d(operation))

    def test_SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self):
        self._run_wgrad(cutlass.conv.IteratorAlgorithm.analytic)

    def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self):
        self._run_wgrad(cutlass.conv.IteratorAlgorithm.optimized)


if __name__ == '__main__':
    pycutlass.get_memory_pool(2**26, 2**26)
    unittest.main()
# Ported from: test/unit/conv/device/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu
import pycutlass
from pycutlass import *
from pycutlass.test import *
from pycutlass.utils.device import device_cc
import unittest


@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.TestCase):
    """Wgrad conv2d tests for tf32 NHWC operands on the tensor-op path (SM80+)."""

    def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self):
        # Tensor-op math instruction: 16x8x8 mma over f32 (tf32) operands.
        math_inst = MathInstruction(
            instruction_shape=[16, 8, 8],
            element_a=cutlass.float32, element_b=cutlass.float32,
            element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp,
            math_operation=MathOperation.multiply_add
        )

        operand_layout = cutlass.TensorNHWC
        A = TensorDescription(element=math_inst.element_a, layout=operand_layout, alignment=4)
        B = TensorDescription(element=math_inst.element_b, layout=operand_layout, alignment=4)
        C = TensorDescription(element=cutlass.float32, layout=operand_layout, alignment=8)

        tile_description = TileDescription(
            threadblock_shape=[128, 128, 16], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            stride_support=StrideSupport.Strided,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )

        # Run the default conv2d problem sweep for this operation.
        self.assertTrue(test_all_conv2d(operation))

    def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1(self):
        # Same tensor-op instruction, but exercising alignment-1 input operands.
        math_inst = MathInstruction(
            instruction_shape=[16, 8, 8],
            element_a=cutlass.float32, element_b=cutlass.float32,
            element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp,
            math_operation=MathOperation.multiply_add
        )

        operand_layout = cutlass.TensorNHWC
        A = TensorDescription(element=math_inst.element_a, layout=operand_layout, alignment=1)
        B = TensorDescription(element=math_inst.element_b, layout=operand_layout, alignment=1)
        C = TensorDescription(element=cutlass.float32, layout=operand_layout, alignment=4)

        tile_description = TileDescription(
            threadblock_shape=[128, 128, 32], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            stride_support=StrideSupport.Strided,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )

        # A single tiny problem (1x8x8x1 input, 3x3 filter, unit stride/dilation)
        # rather than the default sweep, since alignment 1 restricts valid sizes.
        problem_sizes = [
            cutlass.conv.Conv2dProblemSize(
                cutlass.Tensor4DCoord(1, 8, 8, 1),
                cutlass.Tensor4DCoord(1, 3, 3, 1),
                cutlass.Tensor4DCoord(1, 1, 1, 1),
                cutlass.MatrixCoord(1, 1),
                cutlass.MatrixCoord(1, 1),
                cutlass.conv.Mode.cross_correlation,
                1, 1
            ),
        ]

        self.assertTrue(test_all_conv2d(operation, problem_sizes))


if __name__ == '__main__':
    pycutlass.get_memory_pool(2**26, 2**26)
    unittest.main()
import pycutlass
import unittest
from pycutlass.memory_manager import *

if __name__ == '__main__':
    # Reserve device/host memory pools (64 MiB each) before running any kernels.
    pycutlass.get_memory_pool(2**32, 2**32)

    # Discover and run every conv2d test module in this directory.
    discovered = unittest.TestLoader().discover('./', 'conv2d_*.py')
    unittest.runner.TextTestRunner().run(discovered)
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
## Test cases for the PyTorch and CuPy tensor frontends.
import pycutlass
import unittest
from pycutlass import *
from pycutlass.utils.device import device_cc
import torch
import cupy as cp


class Test_Frontend(unittest.TestCase):
    """Check that GemmArguments accepts torch and cupy tensors directly."""

    def setUp(self) -> None:
        # Build one SIMT f32 GEMM operation shared by both frontend tests.
        cc = device_cc()
        math_inst = MathInstruction(
            [1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32,
            cutlass.OpClass.Simt, MathOperation.multiply_add
        )

        # Stages > 2 is supported only for compute capability 80 and beyond.
        num_stages = 4 if cc >= 80 else 2

        tile_description = TileDescription(
            [128, 128, 8], num_stages, [2, 4, 1],
            math_inst
        )

        A = TensorDescription(cutlass.float32, cutlass.RowMajor, 1)
        B = TensorDescription(cutlass.float32, cutlass.RowMajor, 1)
        C = TensorDescription(cutlass.float32, cutlass.RowMajor, 1)

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        self.operation = GemmOperationUniversal(
            arch=cc, tile_description=tile_description,
            A=A, B=B, C=C,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )

        pycutlass.compiler.add_module([self.operation,])

    def test_torch_frontend(self):
        problem_size = cutlass.gemm.GemmCoord(512, 256, 128)

        # Integer-valued f32 tensors so the reference comparison can be exact.
        def make_operand(rows, cols):
            t = torch.empty(size=(rows, cols), dtype=torch.float32, device="cuda")
            return torch.ceil(t.uniform_(-8.5, 7.5))

        tensor_A = make_operand(problem_size.m(), problem_size.k())
        tensor_B = make_operand(problem_size.k(), problem_size.n())
        tensor_C = make_operand(problem_size.m(), problem_size.n())
        tensor_D = torch.empty_like(tensor_C)

        alpha = 1.0
        beta = 0.0

        arguments = GemmArguments(
            operation=self.operation, problem_size=problem_size,
            A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
            output_op=self.operation.epilogue_type(alpha, beta),
            gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1
        )

        self.operation.run(arguments)
        arguments.sync()

        tensor_D_ref = alpha * tensor_A @ tensor_B + beta * tensor_C
        self.assertTrue(torch.equal(tensor_D, tensor_D_ref))

    def test_cupy_frontend(self):
        # NOTE(review): `rmm` is not imported explicitly here — presumably it is
        # re-exported by `from pycutlass import *`; confirm, else this raises NameError.
        cp.cuda.set_allocator(rmm.rmm_cupy_allocator)

        problem_size = cutlass.gemm.GemmCoord(512, 256, 128)

        # Integer-valued f32 arrays so the reference comparison can be exact.
        def make_operand(rows, cols):
            return cp.ceil(cp.random.uniform(low=-8.5, high=7.5, size=(rows, cols), dtype=cp.float32))

        tensor_A = make_operand(problem_size.m(), problem_size.k())
        tensor_B = make_operand(problem_size.k(), problem_size.n())
        tensor_C = make_operand(problem_size.m(), problem_size.n())
        tensor_D = cp.ones_like(tensor_C)

        alpha = 1.0
        beta = 1.0

        tensor_D_ref = alpha * tensor_A @ tensor_B + beta * tensor_C

        arguments = GemmArguments(
            operation=self.operation, problem_size=problem_size,
            A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
            output_op=self.operation.epilogue_type(alpha, beta),
            gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1
        )

        self.operation.run(arguments)
        arguments.sync()

        self.assertTrue(cp.array_equal(tensor_D, tensor_D_ref))


if __name__ == '__main__':
    pycutlass.get_memory_pool(2**32, 2**32)
    unittest.main()
import pycutlass
from pycutlass import *
from pycutlass.test import *
import unittest

from pycutlass.test.gemm_testbed import test_all_gemm
from pycutlass.utils.device import device_cc


@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
class GemmBF16TensorOpSm80(unittest.TestCase):
    """bf16 tensor-op GEMM tests for SM80."""

    # BUG FIX: this method previously lacked the "test_" prefix
    # ("SM80_Device_Gemm_bf16n_..."), so unittest discovery silently skipped it.
    def test_SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32_64x128x64_32x64x64(self):
        math_inst = MathInstruction(
            instruction_shape=[16, 8, 16],
            element_a=cutlass.bfloat16, element_b=cutlass.bfloat16,
            element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp,
            math_operation=MathOperation.multiply_add
        )

        tile_description = TileDescription(
            threadblock_shape=[64, 128, 64],
            stages=4, warp_count=[2, 2, 1],
            math_instruction=math_inst
        )

        A = TensorDescription(
            element=cutlass.bfloat16, layout=cutlass.ColumnMajor,
            alignment=8
        )
        B = TensorDescription(
            element=cutlass.bfloat16, layout=cutlass.ColumnMajor,
            alignment=8
        )
        C = TensorDescription(
            element=cutlass.float32, layout=cutlass.RowMajor,
            alignment=4
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        swizzling_functor = cutlass.IdentitySwizzle1

        operation = GemmOperationUniversal(
            arch=80, tile_description=tile_description,
            A=A, B=B, C=C,
            epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
        )

        self.assertTrue(test_all_gemm(operation, "universal"))

    def test_SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32_128x256x64_64x64x64(self):
        math_inst = MathInstruction(
            instruction_shape=[16, 8, 16],
            element_a=cutlass.bfloat16, element_b=cutlass.bfloat16,
            element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp,
            math_operation=MathOperation.multiply_add
        )

        # NOTE(review): the test name says 128x256x64, but the threadblock shape
        # below is [64, 128, 32] — confirm which was intended; config left unchanged.
        tile_description = TileDescription(
            threadblock_shape=[64, 128, 32],
            stages=6, warp_count=[2, 2, 1],
            math_instruction=math_inst
        )

        A = TensorDescription(
            element=cutlass.bfloat16, layout=cutlass.RowMajor,
            alignment=8
        )
        B = TensorDescription(
            element=cutlass.bfloat16, layout=cutlass.RowMajor,
            alignment=8
        )
        C = TensorDescription(
            element=cutlass.bfloat16, layout=cutlass.RowMajor,
            alignment=8
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        swizzling_functor = cutlass.IdentitySwizzle1

        operation = GemmOperationUniversal(
            arch=80, tile_description=tile_description,
            A=A, B=B, C=C,
            epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
        )

        self.assertTrue(test_all_gemm(operation, "multistage"))


if __name__ == '__main__':
    pycutlass.get_memory_pool(2**24, 2**24)
    unittest.main()