warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300) hide show
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,402 @@
1
+ #
2
+ # \file manifest.py
3
+ #
4
+ # \brief Generates the CUTLASS Library's instances
5
+ #
6
+
7
+ import enum
8
+ import os.path
9
+ import shutil
10
+
11
+ from library import *
12
+ from gemm_operation import *
13
+ from rank_k_operation import *
14
+ from rank_2k_operation import *
15
+ from trmm_operation import *
16
+ from symm_operation import *
17
+ from conv2d_operation import *
18
+ from conv3d_operation import *
19
+
20
+ ###################################################################################################
21
+
22
class EmitOperationKindLibrary:
    """Context manager that writes the top-level source file for one operation kind.

    On entry it creates ``<generated_path>/<kind name>/`` and opens
    ``all_<kind name>_operations.cu`` inside it. Each ``emit`` call delegates
    one configuration to the configuration emitter registered for the kind and
    writes a prototype for it. On exit the entry-point function that calls
    every recorded configuration initializer is written and the file is closed.
    """

    def __init__(self, generated_path, kind, args):
        """Store the output location, operation kind and options, and set up templates.

        Args:
            generated_path: directory under which the per-kind directory is created.
            kind: ``OperationKind`` member selecting the configuration emitter.
            args: options object; stored but not read by this class.
        """
        self.generated_path = generated_path
        self.kind = kind
        self.args = args

        # Maps each operation kind to the emitter class used for its configurations.
        self.emitters = {
            OperationKind.Gemm: EmitGemmConfigurationLibrary,
            OperationKind.Conv2d: EmitConv2dConfigurationLibrary,
            OperationKind.Conv3d: EmitConv3dConfigurationLibrary,
            OperationKind.RankK: EmitRankKConfigurationLibrary,
            OperationKind.Rank2K: EmitRank2KConfigurationLibrary,
            OperationKind.Trmm: EmitTrmmConfigurationLibrary,
            OperationKind.Symm: EmitSymmConfigurationLibrary,
        }

        # Names of the configurations emitted so far, in emission order.
        self.configurations = []

        self.header_template ="""
/*
Generated by manifest.py - Do not edit.
*/

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

"""
        self.entry_template = """

//
// Entry point to construct operations
//
void initialize_all_${operation_name}_operations(Manifest &manifest) {
"""
        self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n"
        self.configuration_template =" initialize_${configuration_name}(manifest);\n"

        self.epilogue_template ="""

}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

"""

    def __enter__(self):
        """Create the per-kind directory, open the top-level file and write its header.

        Returns:
            This emitter.

        Raises:
            OSError: from ``os.mkdir`` or ``open`` (e.g. the directory already exists).
        """
        kind_name = OperationKindNames[self.kind]

        self.operation_path = os.path.join(self.generated_path, kind_name)
        os.mkdir(self.operation_path)

        self.top_level_path = os.path.join(self.operation_path, "all_%s_operations.cu" % kind_name)

        self.top_level_file = open(self.top_level_path, "w")
        self.top_level_file.write(self.header_template)

        # Every generated source path is recorded, starting with the top-level file.
        self.source_files = [self.top_level_path,]

        return self

    def emit(self, configuration_name, operations):
        """Emit one configuration and declare its initializer in the top-level file.

        Args:
            configuration_name: name substituted into the initializer prototype.
            operations: iterable of operations handed to the configuration emitter.
        """
        emitter_cls = self.emitters[self.kind]

        with emitter_cls(self.operation_path, configuration_name) as configuration_emitter:
            for operation in operations:
                configuration_emitter.emit(operation)

            self.source_files.append(configuration_emitter.configuration_path)

        self.configurations.append(configuration_name)
        self.top_level_file.write(SubstituteTemplate(
            self.configuration_prototype_template,
            {'configuration_name': configuration_name}))

    def __exit__(self, exception_type, exception_value, traceback):
        """Write the entry-point function and epilogue, then close the file.

        The file is closed even when one of the writes raises, so the handle is
        never leaked. Nothing is returned, so an exception raised inside the
        ``with`` body still propagates.
        """
        try:
            self.top_level_file.write(SubstituteTemplate(
                self.entry_template,
                {'operation_name': OperationKindNames[self.kind]}))

            for configuration_name in self.configurations:
                self.top_level_file.write(SubstituteTemplate(
                    self.configuration_template,
                    {'configuration_name': configuration_name}))

            self.top_level_file.write(self.epilogue_template)
        finally:
            self.top_level_file.close()
110
+
111
class EmitInterfaceLibrary:
    """Context manager that writes ``initialize_all.cpp`` for the generated library.

    ``emit`` records one prototype and one call per operation kind name. On
    exit those are substituted into the prologue template, together with the
    operation count, and the file is closed.
    """

    def __init__(self, generated_path, operation_count, args):
        """Store the output directory, operation count and options.

        Args:
            generated_path: directory in which ``initialize_all.cpp`` is written.
            operation_count: number passed to ``manifest.reserve``; stored as a string.
            args: options object; stored but not read by this class.
        """
        self.generated_path = generated_path
        self.args = args

        # Accumulated C++ declarations and calls, one entry per emitted kind.
        self.prototypes = []
        self.fn_calls = []
        self.operation_count = str(operation_count)

        self.top_level_hdr_template = '''
/*
Generated by manifest.py - Do not edit.
*/
'''
        self.top_level_prologue = '''

#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

namespace cutlass {
\tnamespace library {

${prototypes}

\t\tvoid initialize_all(Manifest &manifest) {
\t\t\tmanifest.reserve(${operation_count});\n\n
${fn_calls}
\t\t\t}

\t} // namespace library
} // namespace cutlass

'''

    def __enter__(self):
        """Open ``initialize_all.cpp`` and write the generated-file banner.

        Returns:
            This emitter.

        Raises:
            OSError: from ``open`` if the file cannot be created.
        """
        self.top_level_path = os.path.join(self.generated_path, 'initialize_all.cpp')

        self.top_level_file = open(self.top_level_path, "w")
        self.top_level_file.write(self.top_level_hdr_template)

        self.source_files = [self.top_level_path,]

        return self

    def emit(self, operation_name):
        """Record the prototype and the call for one operation kind.

        Args:
            operation_name: kind name substituted into
                ``initialize_all_<name>_operations``.
        """
        substitutions = {'operation_kind': operation_name}

        self.prototypes.append(SubstituteTemplate(
            "\t\tvoid initialize_all_${operation_kind}_operations(Manifest &manifest);",
            substitutions))
        self.fn_calls.append(SubstituteTemplate(
            "\t\t\tinitialize_all_${operation_kind}_operations(manifest);",
            substitutions))

    def __exit__(self, exception_type, exception_value, traceback):
        """Write the prologue with all recorded prototypes and calls, then close the file.

        The file is closed even when template substitution or the write raises,
        so the handle is never leaked. Nothing is returned, so an exception
        raised inside the ``with`` body still propagates.
        """
        try:
            self.top_level_file.write(SubstituteTemplate(self.top_level_prologue, {
                'prototypes': "\n".join(self.prototypes),
                'fn_calls': "\n".join(self.fn_calls),
                'operation_count': self.operation_count,
            }))
        finally:
            self.top_level_file.close()
174
+
175
+ ###################################################################################################
176
+ ###################################################################################################
177
+
178
class Options:
    """Empty container object; instances start with no attributes of their own."""

    def __init__(self):
        """Create an instance without setting any attribute."""
        return None
181
+
182
+ ###################################################################################################
183
+
184
+ #
185
class Manifest:
    """Collection of operations selected for generation, plus the code that emits them.

    Operations offered through ``append`` are kept only when ``filter`` accepts
    them. ``emit`` then writes the generated sources and a ``manifest.cmake``
    file listing them under ``<curr_build_dir>/generated``.
    """

    def __init__(self, args=None):
        """Initialise empty state and, when ``args`` is given, read the filters from it.

        Args:
            args: optional options object. When truthy, the attributes
                ``kernels``, ``curr_build_dir``, ``architectures``,
                ``filter_by_cc``, ``operations``, ``ignore_kernels`` and
                ``kernel_filter_file`` are read from it.
        """
        # operation_kind -> configuration_name -> list of operations
        self.operations = {}
        self.args = args
        self.operation_count = 0
        # procedural name -> operation; used to reject duplicates
        self.operations_by_name = {}

        self.kernel_filter = ''
        self.kernel_filter_list = []
        self.kernel_names = []
        self.operations_enabled = []
        self.selected_kernels = []
        self.ignore_kernel_names = []
        self.compute_capabilities = [50,]
        self.curr_build_dir = '.'
        self.filter_by_cc = True

        if self.args:
            self.kernel_filter = self.args.kernels
            self.curr_build_dir = args.curr_build_dir

            # ';'-separated list of architectures; an empty string falls back to 50.
            architectures = args.architectures.split(';') if len(args.architectures) else ['50',]
            self.compute_capabilities = [int(x) for x in architectures]

            if args.filter_by_cc in ['false', 'False', '0']:
                self.filter_by_cc = False

        if args.operations == 'all' if self.args else False:
            self.operations_enabled = []
        elif self.args:
            # Every kind that has a registered emitter can be requested by name.
            # Rank2K was previously missing here, so asking for it either
            # dropped it or (when it was the only kind requested) left the
            # list empty and thereby disabled the kind filter altogether.
            operations_list = [
                OperationKind.Gemm,
                OperationKind.Conv2d,
                OperationKind.Conv3d,
                OperationKind.RankK,
                OperationKind.Rank2K,
                OperationKind.Trmm,
                OperationKind.Symm,
            ]
            requested = args.operations.split(',')
            self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in requested]

        if self.args:
            if args.kernels == 'all':
                self.kernel_names = []
            else:
                self.kernel_names = [x for x in args.kernels.split(',') if x != '']

            self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']

            if args.kernel_filter_file is None:
                self.kernel_filter_list = []
            else:
                self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)

    def get_kernel_filters(self, kernelListFile):
        """Read a kernel filter file and return its patterns as compiled regexes.

        Lines starting with ``#`` and lines that are empty after stripping
        trailing whitespace are skipped.

        Args:
            kernelListFile: path of the filter file.

        Returns:
            List of compiled patterns; empty when the path is not a file.
        """
        # Imported locally: this file does not import ``re`` itself, so the
        # method must not rely on the name leaking in through a star import.
        import re

        if not os.path.isfile(kernelListFile):
            return []

        with open(kernelListFile, 'r') as fileReader:
            lines = [line.rstrip() for line in fileReader if not line.startswith("#")]

        return [re.compile(line) for line in lines if line]

    def filter_out_kernels(self, kernel_name, kernel_filter_list):
        """Return True when any compiled pattern in the list is found in ``kernel_name``."""
        return any(
            kernel_filter_re.search(kernel_name) is not None
            for kernel_filter_re in kernel_filter_list)

    def _filter_string_matches(self, filter_string, haystack):
        ''' Returns true if all substrings appear in the haystack in order'''
        # '*' separates the pieces; each piece must be found after the end of
        # the previous match.
        for sub in filter_string.split('*'):
            idx = haystack.find(sub)
            if idx < 0:
                return False
            haystack = haystack[idx + len(sub):]
        return True

    def filter(self, operation):
        ''' Filtering operations based on various criteria'''

        # filter based on compute capability
        enabled = not (self.filter_by_cc)

        for cc in self.compute_capabilities:
            if cc >= operation.tile_description.minimum_compute_capability and \
               cc <= operation.tile_description.maximum_compute_capability and \
               (cc not in SharedMemPerCC or SharedMemPerCC[cc] >= CalculateSmemUsage(operation)):

                enabled = True
                break

        if not enabled:
            return False

        # An empty enabled list means no restriction on operation kind.
        if self.operations_enabled and operation.operation_kind not in self.operations_enabled:
            return False

        # eliminate duplicates
        if operation.procedural_name() in self.operations_by_name:
            return False

        # Filter based on list of valid substrings
        if self.kernel_names:
            name = operation.procedural_name()

            # compare against the include list
            enabled = any(
                self._filter_string_matches(name_substr, name)
                for name_substr in self.kernel_names)

            # compare against the exclude list
            # NOTE(review): the exclude list is only consulted when an include
            # list is set - confirm that this is intended.
            if any(self._filter_string_matches(name_substr, name)
                   for name_substr in self.ignore_kernel_names):
                enabled = False

        # A non-empty regex list overrides the result of the substring filters.
        if len(self.kernel_filter_list) > 0:
            enabled = self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list)

        # todo: filter based on compute data type
        return enabled

    def append(self, operation):
        '''
        Inserts the operation.

        operation_kind -> configuration_name -> []
        '''
        if not self.filter(operation):
            return

        procedural_name = operation.procedural_name()
        self.selected_kernels.append(procedural_name)
        self.operations_by_name[procedural_name] = operation

        # add the configuration
        configuration_name = operation.configuration_name()

        configurations = self.operations.setdefault(operation.operation_kind, {})
        configurations.setdefault(configuration_name, []).append(operation)
        self.operation_count += 1

    def emit(self, target = GeneratorTarget.Library):
        """Write all generated sources and ``manifest.cmake``.

        Any existing ``<curr_build_dir>/generated`` directory is deleted and
        recreated first.

        Args:
            target: ``GeneratorTarget`` member selecting the emitter classes.
        """
        operation_emitters = {
            GeneratorTarget.Library: EmitOperationKindLibrary
        }
        interface_emitters = {
            GeneratorTarget.Library: EmitInterfaceLibrary
        }

        generated_path = os.path.join(self.curr_build_dir, 'generated')

        # create generated/
        if os.path.exists(generated_path):
            shutil.rmtree(generated_path)

        os.mkdir(generated_path)

        source_files = []

        with interface_emitters[target](generated_path, self.operation_count, self.args) as iface_emitter:
            for operation_kind in self.operations:
                iface_emitter.emit(OperationKindNames[operation_kind])

            source_files += iface_emitter.source_files

        # for each operation kind, emit initializer for all configurations
        for operation_kind, configurations in self.operations.items():
            with operation_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter:
                for configuration_name, operations in configurations.items():
                    operation_kind_emitter.emit(configuration_name, operations)

                source_files += operation_kind_emitter.source_files

        # write the manifest.cmake file containing paths from all targets
        manifest_path = os.path.join(generated_path, "manifest.cmake")
        with open(manifest_path, "w") as manifest_file:

            target_name = 'cutlass_library_objs'

            target_text = SubstituteTemplate("""cutlass_target_sources(
${target_name}
BATCH_SOURCES ON
PRIVATE
""", { 'target_name': target_name})

            manifest_file.write(target_text)

            for source_file in source_files:
                # Backslashes are replaced so the listed paths use '/' only.
                manifest_file.write(" %s\n" % str(source_file.replace('\\', '/')))
            manifest_file.write(")")
400
+ #
401
+
402
+ ###################################################################################################
@@ -0,0 +1,96 @@
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ # Configuration file for the Sphinx documentation builder.
34
+ #
35
+ # This file only contains a selection of the most common options. For a full
36
+ # list see the documentation:
37
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
38
+
39
+ # -- Path setup --------------------------------------------------------------
40
+
41
+ # If extensions (or modules to document with autodoc) are in another directory,
42
+ # add these directories to sys.path here. If the directory is relative to the
43
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
44
+ #
45
+ # import os
46
+ # import sys
47
+ # sys.path.insert(0, os.path.abspath('.'))
48
+
49
+
50
+ # -- Project information -----------------------------------------------------
51
+
52
# -- Project information -----------------------------------------------------

project = "PyCutlass"
author = "Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall"
# Sphinx reads the copyright notice from a module-level variable with this
# exact name, which is why the builtin of the same name is shadowed here.
copyright = "2022, " + author


# -- General configuration ---------------------------------------------------

# Sphinx extension modules to load, given by their import names. Entries
# under "sphinx.ext" ship with Sphinx; the others are separate packages.
extensions = [
    "sphinx.ext.duration",
    "sphinx.ext.doctest",
    "sphinx.ext.autodoc",
    "sphinx.ext.intersphinx",
    "enum_tools.autoenum",
    "sphinx.ext.autosummary",
    "m2r2",
]

# File extensions treated as documentation sources.
source_suffix = [".rst", ".md"]

# autosummary settings.
autosummary_generate = True
autosummary_imported_members = True

# Template directories, relative to the directory holding this file.
templates_path = ["_templates"]

# Patterns (relative to the source directory) of files and directories to
# skip when collecting sources. Also applies to html_static_path and
# html_extra_path.
exclude_patterns = []


# -- Options for HTML output -------------------------------------------------

# Theme used for HTML and HTML Help pages.
html_theme = "bizstyle"
92
+
93
+ # Add any paths that contain custom static files (such as style sheets) here,
94
+ # relative to this directory. They are copied after the builtin static files,
95
+ # so a file named "default.css" will overwrite the builtin "default.css".
96
+ # html_static_path = ['_static']
@@ -0,0 +1,106 @@
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from pycutlass import *
34
+ import pycutlass
35
+ from pycutlass.epilogue import LinearCombination
36
+ from pycutlass.test.conv2d_testbed import Conv2dLauncher
37
+
38
+
39
def main():
    """Run one Conv2d problem through both launch paths and print their ratio.

    Builds a Conv2d fprop operation (arch 80, float16 inputs, float32
    accumulator and output), runs it through ``Conv2dLauncher.run`` and
    ``Conv2dLauncher.run_cutlass_profiler`` with the same problem size, and
    prints ``cpp_runtime / python_runtime``.
    """
    pycutlass.get_memory_pool(2**33, 2**33)
    pycutlass.compiler.nvcc()

    math_inst = MathInstruction(
        instruction_shape=[16, 8, 16],
        element_a=cutlass.float16,
        element_b=cutlass.float16,
        element_accumulator=cutlass.float32,
        opcode_class=cutlass.OpClass.TensorOp,
        math_operation=MathOperation.multiply_add,
    )

    def nhwc(element):
        # All three operands share the NHWC layout and an alignment of 8.
        return TensorDescription(
            element=element, layout=cutlass.TensorNHWC, alignment=8)

    A = nhwc(math_inst.element_a)
    B = nhwc(math_inst.element_b)
    C = nhwc(cutlass.float32)

    tile_description = TileDescription(
        threadblock_shape=[128, 128, 64],
        stages=4,
        warp_count=[2, 2, 1],
        math_instruction=math_inst,
    )

    epilogue_functor = LinearCombination(
        cutlass.float32, 4, cutlass.float32, cutlass.float32)

    operation = Conv2dOperation(
        conv_kind=cutlass.conv.Operator.fprop,
        iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
        arch=80,
        tile_description=tile_description,
        A=A, B=B, C=C,
        element_epilogue=cutlass.float32,
        stride_support=StrideSupport.Strided,
        epilogue_functor=epilogue_functor,
        swizzling_functor=cutlass.IdentitySwizzle1,
    )

    profiler = Conv2dLauncher(operation, verification=False, profiling=True)

    def make_problem_size():
        # A fresh object is built for every launch, so neither run can see
        # state left behind by the other.
        # NOTE(review): the positional arguments presumably describe input,
        # filter, padding, stride and dilation; confirm against the
        # Conv2dProblemSize signature.
        return cutlass.conv.Conv2dProblemSize(
            cutlass.Tensor4DCoord(32, 224, 224, 128),
            cutlass.Tensor4DCoord(128, 3, 3, 128),
            cutlass.Tensor4DCoord(1, 1, 1, 1),
            cutlass.MatrixCoord(1, 1),
            cutlass.MatrixCoord(1, 1),
            cutlass.conv.Mode.cross_correlation,
            1, 1
        )

    python_runtime = profiler.run(
        problem_size=make_problem_size(),
        split_k_mode=cutlass.conv.SplitKMode.Serial,
    )

    cpp_runtime = profiler.run_cutlass_profiler(
        problem_size=make_problem_size(),
        split_k_mode=cutlass.conv.SplitKMode.Serial,
    )

    print(cpp_runtime / python_runtime)


if __name__ == "__main__":
    main()
@@ -0,0 +1,91 @@
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ import pycutlass
34
+ from pycutlass import *
35
+ from pycutlass.test import *
36
+ from pycutlass.test.gemm_testbed import GemmUniversalLauncher
37
+
38
def main():
    """Run one GEMM problem through both launch paths and print their ratio.

    Builds a universal GEMM operation (arch 80, float16 row-major A and B,
    float32 column-major C), runs a 4096 x 4096 x 4096 problem through
    ``GemmUniversalLauncher.run`` and
    ``GemmUniversalLauncher.run_cutlass_profiler``, and prints
    ``cpp_runtime / python_runtime``.
    """
    pycutlass.get_memory_pool(2**32, 2**32)
    pycutlass.compiler.nvcc()

    math_inst = MathInstruction(
        instruction_shape=[16, 8, 16],
        element_a=cutlass.float16,
        element_b=cutlass.float16,
        element_accumulator=cutlass.float32,
        opcode_class=cutlass.OpClass.TensorOp,
        math_operation=MathOperation.multiply_add,
    )

    tile_description = TileDescription(
        threadblock_shape=[256, 128, 32],
        stages=3,
        warp_count=[4, 2, 1],
        math_instruction=math_inst,
    )

    def describe(element, layout):
        # Every operand in this benchmark uses an alignment of 4.
        return TensorDescription(element=element, layout=layout, alignment=4)

    A = describe(cutlass.float16, cutlass.RowMajor)
    B = describe(cutlass.float16, cutlass.RowMajor)
    C = describe(cutlass.float32, cutlass.ColumnMajor)

    element_epilogue = cutlass.float32
    epilogue_functor = LinearCombination(
        cutlass.float32, 4, cutlass.float32, cutlass.float32)
    swizzling_functor = cutlass.IdentitySwizzle1

    operation = GemmOperationUniversal(
        arch=80,
        tile_description=tile_description,
        A=A, B=B, C=C,
        element_epilogue=element_epilogue,
        epilogue_functor=epilogue_functor,
        swizzling_functor=swizzling_functor,
    )

    profiler = GemmUniversalLauncher(operation, verification=False, profiling=True)

    def launch_arguments():
        # Rebuilt for each launch so both runs receive their own GemmCoord.
        return {
            "mode": cutlass.gemm.Mode.Gemm,
            "problem_size": cutlass.gemm.GemmCoord(4096, 4096, 4096),
        }

    python_runtime = profiler.run(**launch_arguments())
    cpp_runtime = profiler.run_cutlass_profiler(**launch_arguments())

    print(cpp_runtime / python_runtime)


if __name__ == '__main__':
    main()