warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1026 @@
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ from ast import Num
34
+ from audioop import mul
35
+ from pipes import Template
36
+ import struct
37
+ from pycutlass.library import DataTypeTag
38
+ from pycutlass import *
39
+ import cutlass
40
+ from scipy.special import erf
41
+
42
+ from pycutlass.c_types import MatrixCoord_
43
+ from pycutlass.frontend import NumpyFrontend
44
+
45
+ from cuda import cuda
46
+ from cuda import cudart
47
+
48
# Map cutlass numeric types to the ctypes type with a matching in-memory
# layout. float16 has no native ctypes equivalent, so its 16 storage bits
# are carried in a c_uint16.
dtype2ctype = {
    cutlass.float16: ctypes.c_uint16,
    cutlass.float32: ctypes.c_float,
    cutlass.float64: ctypes.c_double,
    cutlass.int32: ctypes.c_int32,
}
54
+
55
+
56
+ #################################################################################################
57
+ #
58
+ # Epilogue Functors
59
+ #
60
+ #################################################################################################
61
+
62
class EpilogueFunctorBase:
    """Base class for thread-level epilogue functors.

    Subclasses define a ``tag`` (the fully qualified C++ functor name) and a
    list of already-emitted template arguments; :meth:`emit` renders the
    instantiated C++ type name.
    """

    def __init__(self) -> None:
        pass

    def emit(self, tag, template_argument):
        """Return ``tag<arg0, arg1, ...>`` as C++ source text.

        :param tag: fully qualified C++ functor name
        :param template_argument: iterable of template-argument strings
        """
        template = """${tag}<${arguments}>"""
        values = {
            "tag": tag,
            # ", ".join replaces the original manual index loop that
            # appended separators by hand.
            "arguments": ", ".join(template_argument),
        }
        return SubstituteTemplate(template, values)
82
+
83
+
84
+
85
class LinearCombination(EpilogueFunctorBase):
    """Apply a linear combination operator to an array of elements:
        D = alpha * accumulator + beta * source

    :param element_output: data type used to load and store tensors
    :param epilogue_vector_length: number of elements computed per operation.
        Usually 128/sizeof_bits<ElementOutput_>, but 64 and 32 are used when
        there are not enough data to store
    :param element_accumulator: accumulator data type (defaults to the
        output type)
    :param element_epilogue: data type used to compute the linear
        combination (defaults to the output type)
    """

    tag = "cutlass::epilogue::thread::LinearCombination"

    def __init__(self, element_output, epilogue_vector_length,
                 element_accumulator=None, element_epilogue=None) -> None:
        # TODO bind ScaleType
        super().__init__()

        if element_accumulator is None:
            element_accumulator = element_output
        if element_epilogue is None:
            element_epilogue = element_output

        self.element_output = element_output
        self.element_accumulator = element_accumulator
        self.element_epilogue = element_epilogue

        self.template_arguments = [
            DataTypeTag[element_output],
            str(epilogue_vector_length),
            DataTypeTag[element_accumulator],
            DataTypeTag[element_epilogue],
        ]

        # Build the ctypes struct passed as the epilogue output-op params;
        # close over the epilogue dtype so instances can convert alpha/beta.
        epilogue_ctype = dtype2ctype[self.element_epilogue]
        epilogue_dtype = self.element_epilogue

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha_data", ctypes.c_longlong * 2),
                ("beta_data", ctypes.c_longlong * 2),
                ("alpha", epilogue_ctype),
                ("beta", epilogue_ctype),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                # .storage yields the raw bit pattern the C++ side expects
                self.alpha = epilogue_dtype(alpha).storage
                self.beta = epilogue_dtype(beta).storage

        self.epilogue_type = _EpilogueOutputOpParams

    def emit(self):
        """Render the instantiated C++ type name for this functor."""
        return super().emit(self.tag, self.template_arguments)
140
+
141
+
142
class LinearCombinationClamp(LinearCombination):
    """Linear combination followed by a clamp before converting to the
    output element type:
        D = alpha * accumulator + beta * source + uniform

    :param element_output: data type used to load and store tensors
    :param epilogue_vector_length: number of elements computed per operation.
        Usually 128/sizeof_bits<ElementOutput_>, but 64 and 32 are used when
        there are not enough data to store
    :param element_accumulator: accumulator data type
    :param element_epilogue: data type used to compute the linear combination
    """

    tag = "cutlass::epilogue::thread::LinearCombinationClamp"

    def __init__(self, element_output, epilogue_vector_length,
                 element_accumulator=None, element_epilogue=None) -> None:
        # The base constructor resolves element-type defaults and builds the
        # template argument list.
        super().__init__(element_output, epilogue_vector_length,
                         element_accumulator, element_epilogue)

        epilogue_ctype = dtype2ctype[self.element_epilogue]
        epilogue_dtype = self.element_epilogue

        # Override the base parameter struct with one that omits the
        # alpha_data/beta_data payload fields.
        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", epilogue_ctype),
                ("beta", epilogue_ctype),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = epilogue_dtype(alpha).storage
                self.beta = epilogue_dtype(beta).storage

        self.epilogue_type = _EpilogueOutputOpParams
182
+
183
+
184
class FastLinearCombinationClamp(EpilogueFunctorBase):
    """Linear combination then clamp, using the fast int8 path.

    Note: only valid when problem_size_K <= 256 for signed int8 gemm or
    problem_size_K <= 128 for unsigned int8 gemm; the default approach is
    LinearCombinationClamp.

    :param element_output: data type used to load and store tensors
    :param epilogue_vector_length: number of elements computed per operation.
        Usually 128/sizeof_bits<ElementOutput_>, but 64 and 32 are used when
        there are not enough data to store
    """

    tag = "cutlass::epilogue::thread::FastLinearCombinationClamp"

    def __init__(self, element_output, epilogue_vector_length, *args) -> None:
        super().__init__()

        self.template_arguments = [
            DataTypeTag[element_output],
            str(epilogue_vector_length),
        ]

        # This fast path fixes the accumulator and epilogue compute types.
        self.element_accumulator = cutlass.int32
        self.element_epilogue = cutlass.float32

        epilogue_ctype = dtype2ctype[self.element_epilogue]
        epilogue_dtype = self.element_epilogue

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", epilogue_ctype),
                ("beta", epilogue_ctype),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = epilogue_dtype(alpha).storage
                self.beta = epilogue_dtype(beta).storage

        self.epilogue_type = _EpilogueOutputOpParams

    def emit(self):
        """Render the instantiated C++ type name for this functor."""
        return super().emit(self.tag, self.template_arguments)
230
+
231
+
232
class LinearCombinationGeneric(LinearCombination):
    """Linear combination followed by an activation function:
        D = activation(alpha * accumulator + beta * source)

    :param activation_functor: input activation functor
    :param element_output: data type used to load and store tensors
    :param epilogue_vector_length: number of elements computed per operation.
        Usually 128/sizeof_bits<ElementOutput_>, but 64 and 32 are used when
        there are not enough data to store
    :param element_accumulator: accumulator data type
    :param element_epilogue: data type used to compute linear combination
    """

    tag = "cutlass::epilogue::thread::LinearCombinationGeneric"

    def __init__(self, activation_functor,
                 element_output, epilogue_vector_length,
                 element_accumulator=None, element_epilogue=None) -> None:
        super().__init__(element_output, epilogue_vector_length,
                         element_accumulator, element_epilogue)

        # The activation's C++ tag leads the template argument list.
        self.template_arguments = [activation_functor.emit()] + self.template_arguments

        self.activation_functor = activation_functor
        # NOTE(review): unlike the base class, this stores the raw
        # caller-supplied value (possibly None), not the resolved default —
        # preserved as-is.
        self.element_epilogue = element_epilogue

        # The parameter struct is provided by the activation functor.
        self.epilogue_type = self.activation_functor.epilogue_output_op(self.element_epilogue)
268
+
269
+
270
class ActivationFunctor:
    """Base class for frequently used activation functions."""

    def __init__(self, element_compute) -> None:
        pass

    @staticmethod
    def numpy(x: np.ndarray):
        """Numpy reference implementation; must be overridden."""
        raise NotImplementedError()

    def emit(self):
        # Activations emit their bare C++ tag (no template arguments).
        return self.tag

    @staticmethod
    def epilogue_output_op(element_epilogue):
        """Build the default ctypes parameter struct (alpha/beta scalars
        plus their device-pointer slots)."""
        c_element_epilogue = dtype2ctype[element_epilogue]

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = element_epilogue(alpha).storage
                self.beta = element_epilogue(beta).storage

        return _EpilogueOutputOpParams
298
+
299
+ # identity operator
300
class identity(ActivationFunctor):
    """Identity activation: returns its input unchanged."""

    # @staticmethod added to match the base-class declaration and every
    # sibling functor; without it, calling numpy() on an *instance* would
    # bind the array argument as `self`.
    @staticmethod
    def numpy(x: np.ndarray):
        return x
303
+
304
+ # ReLu operator,
305
class relu(ActivationFunctor):
    """ReLu activation functor."""

    tag = "cutlass::epilogue::thread::ReLu"

    def __init__(self, element_compute):
        super().__init__(element_compute)

        class _Arguments(ctypes.Structure):
            # `threshold` is the cutoff below which outputs are zeroed.
            _fields_ = [("threshold", dtype2ctype[element_compute])]

            def __init__(self, threshold=0.) -> None:
                self.threshold = element_compute(threshold).storage

        self.argument_type = _Arguments

    def emit_visitor(self):
        return "cutlass::ReLUVisitor"

    @staticmethod
    def numpy(x: np.ndarray):
        """Numpy reference: elementwise max(x, 0)."""
        return np.maximum(x, 0)
324
+
325
+ # Leaky ReLu operator
326
class leaky_relu(ActivationFunctor):
    """Leaky ReLu activation functor."""

    tag = "cutlass::epilogue::thread::LeakyReLU"

    def __init__(self, element_compute) -> None:
        super().__init__(element_compute)

        class _Arguments(ctypes.Structure):
            # Slope applied to the negative half of the input.
            _fields_ = [("leaky_alpha", dtype2ctype[element_compute])]

            def __init__(self, leaky_alpha) -> None:
                self.leaky_alpha = element_compute(leaky_alpha).storage

        self.argument_type = _Arguments

    def emit_visitor(self):
        return "cutlass::LeakyReLUVisitor"

    @staticmethod
    def numpy(x: np.ndarray, leaky_alpha):
        """Numpy reference: x where x > 0, leaky_alpha * x elsewhere."""
        return np.maximum(x, 0) + np.minimum(x, 0) * leaky_alpha

    def epilogue_output_op(self, element_epilogue):
        """Parameter struct extended with the extra `leaky_alpha` scalar.

        Overrides the base staticmethod as an instance method; unlike the
        base struct this one also zero-initializes alpha_ptr/beta_ptr.
        """
        c_element_epilogue = dtype2ctype[element_epilogue]

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
                ("leaky_alpha", c_element_epilogue),
            ]

            def __init__(self, alpha, beta, leaky_alpha=0.2, *args) -> None:
                self.alpha = element_epilogue(alpha).storage
                self.beta = element_epilogue(beta).storage
                self.alpha_ptr = 0
                self.beta_ptr = 0
                self.leaky_alpha = element_epilogue(leaky_alpha).storage

        return _EpilogueOutputOpParams
363
+
364
+ # Tanh operator
365
class tanh(ActivationFunctor):
    """Tanh activation functor."""

    tag = "cutlass::epilogue::thread::Tanh"

    def __init__(self, element_compute) -> None:
        super().__init__(element_compute)

        class _Arguments(ctypes.Structure):
            # Tanh takes no runtime parameters; `tmp` keeps the struct
            # non-empty.
            _fields_ = [("tmp", ctypes.c_int)]

            def __init__(self, *args) -> None:
                self.tmp = 0

        self.argument_type = _Arguments

    def emit_visitor(self):
        return "cutlass::TanhVisitor"

    @staticmethod
    def numpy(x: np.ndarray):
        """Numpy reference: elementwise tanh."""
        return np.tanh(x)
384
+
385
def sigmoid_op(x: np.ndarray):
    """Numpy reference of the logistic sigmoid: 1 / (1 + e^-x)."""
    exp_neg = np.exp(-x)
    return 1. / (1. + exp_neg)
387
+
388
+ # Sigmoid operator
389
class sigmoid(ActivationFunctor):
    """Sigmoid activation functor."""

    tag = "cutlass::epilogue::thread::Sigmoid"

    @staticmethod
    def numpy(x: np.ndarray):
        """Numpy reference; delegates to the module-level sigmoid_op."""
        return sigmoid_op(x)
395
+
396
+ # SiLu operator
397
class silu(ActivationFunctor):
    """SiLu (sigmoid-weighted linear) activation functor."""

    tag = "cutlass::epilogue::thread::SiLu"

    @staticmethod
    def numpy(x: np.ndarray):
        """Numpy reference: x * sigmoid(x)."""
        return sigmoid_op(x) * x
403
+
404
+ # Hardswish operator
405
class hardswish(ActivationFunctor):
    """Hardswish activation functor."""

    tag = "cutlass::epilogue::thread::HardSwish"

    @staticmethod
    def numpy(x: np.ndarray):
        """Numpy reference: x * relu6(x + 3) / 6."""
        # np.clip(a, 0, 6) is equivalent to min(max(a, 0), 6).
        relu6 = np.clip(x + 3., 0., 6.)
        return x * relu6 / 6.
412
+
413
+ # GELU operator
414
# GELU operator
class gelu(ActivationFunctor):
    """Exact (erf-based) GELU activation functor."""

    tag = "cutlass::epilogue::thread::GELU"

    @staticmethod
    def numpy(x: np.ndarray):
        # GELU(x) = x * Phi(x), with Phi the standard normal CDF.
        return 0.5 * x * (1 + erf(x / np.sqrt(2.)))
420
+
421
+ # reduction operator
422
# reduction operator
def reduction_op(tensor, direction, math, factor):
    """Host reference for the epilogue partial-reduction visitors.

    :param tensor: (batch, m, n) array to reduce.
    :param direction: "row" or "column" — axis along which partials are formed.
    :param math: reduction operator; only "Add" is implemented.
    :param factor: tile extent along the reduced dimension; one partial
        sum is produced per ``factor``-wide slice.
    :return: flattened array of partial sums.
    :raises NotImplementedError: for unsupported ``math`` or ``direction``.
    """
    batch, m, n = tensor.shape
    if math != "Add":
        raise NotImplementedError
    if direction == "row":
        # One partial per threadblock tile along n, laid out tile-major.
        num_cta_n = (n + factor - 1) // factor
        partial = np.sum(tensor.reshape(batch, m, num_cta_n, factor), axis=-1)
        return np.transpose(partial, axes=[0, 2, 1]).flatten()
    if direction == "column":
        # One partial per threadblock tile along m.
        num_cta_m = (m + factor - 1) // factor
        return np.sum(tensor.reshape(batch, num_cta_m, factor, n), axis=-2).flatten()
    raise NotImplementedError
439
+
440
+ # # GELU operator implemented using the taylor series approximation
441
+ # class GELU_taylor(ActivationFunctor):
442
+ # tag = "cutlass::epilogue::thread::GELU_taylor"
443
+
444
+ # # Computes backwards pass for GELU operator
445
+ # class dGELU(ActivationFunctor):
446
+ # tag = "cutlass::epilogue::thread::dGELU"
447
+
448
+ ################################################################################
449
+ # Epilogue Visitor
450
+ ################################################################################
451
+
452
+
453
class LayerNorm(EpilogueFunctorBase):
    """
    Epilogue visitor that fuses layer-norm statistics into the GEMM epilogue.

    Applies the wrapped elementwise functor
    (D = alpha * accumulator + beta * source) and additionally produces
    per-row variance / mean (and shifted-K) partials consumed by a
    follow-up layer-norm kernel.

    :param elementwise_functor: functor providing the element-wise epilogue;
        its ``element_epilogue`` / ``element_output`` types are reused here.
    :param element_variance: storage type of the variance partials
        (defaults to the functor's output type).
    :param element_mean: storage type of the mean partials
        (defaults to the functor's output type).
    :param element_layer_norm_compute: compute type for the layer-norm
        statistics (defaults to the functor's compute type).
    :param shifted_k: whether the shifted-K variance formulation is used.
    """
    KernelTemplate = """
cutlass::epilogue::threadblock::EpilogueVisitorLayerNorm<
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    ${operation_name}_default::kThreadCount,
    ${operation_name}_default::Epilogue::OutputTileIterator,
    ${operation_name}_default::Epilogue::AccumulatorFragmentIterator::AccumulatorTile,
    ${element_compute}, // element_compute
    ${element_variance}, // element_variance
    ${element_mean}, // element_mean
    ${element_layer_norm_compute}, // element_layer_norm_compute
    ${epilogue_functor},
    ${shifted_k}>;
"""
    headers = ["gemm/gemm_universal_with_visitor.h",
               "epilogue/epilogue_visitor_with_layernorm.h"]

    def __init__(
            self, elementwise_functor,
            element_variance=None, element_mean=None,
            element_layer_norm_compute=None, shifted_k=True) -> None:  # TODO bind ScaleType
        super().__init__()

        self.elementwise_functor = elementwise_functor
        self.element_compute = elementwise_functor.element_epilogue
        self.element_output = elementwise_functor.element_output

        # BUGFIX: explicit overrides used to be silently dropped — the
        # attribute was only assigned in the None branch, so emit() raised
        # AttributeError whenever a caller passed a non-default type.
        self.element_variance = self.element_output if element_variance is None else element_variance
        self.element_mean = self.element_output if element_mean is None else element_mean
        self.element_layer_norm_compute = (
            self.element_compute if element_layer_norm_compute is None else element_layer_norm_compute)
        # The template expects a C++ boolean literal.
        self.shifted_k = "true" if shifted_k else "false"

        # get epilogue output op
        elementwise_params_type = self.elementwise_functor.epilogue_type

        class _EpilogueVisitorParams(ctypes.Structure):
            # Layout must match the C++ visitor's params struct; do not reorder.
            _fields_ = [
                ("element_wise", elementwise_params_type),
                ("ptr_Variance", ctypes.c_void_p),
                ("ptr_Mean_", ctypes.c_void_p),
                ("ptr_Shifted_K_", ctypes.c_void_p),
                ("extent", MatrixCoord_)
            ]

            def __init__(self, elementwise_params, variance, mean, shift_k, extent) -> None:
                self.element_wise = elementwise_params
                # NOTE(review): device buffers and extent are only initialized
                # for numpy inputs; other argument kinds leave the pointer
                # fields unset — confirm callers always pass ndarrays here.
                if isinstance(variance, np.ndarray):
                    self.buffer_variance = NumpyFrontend.argument(variance, False)
                    self.buffer_mean = NumpyFrontend.argument(mean, False)
                    self.buffer_shift_k = NumpyFrontend.argument(shift_k, False)
                    self.ptr_Variance = int(self.buffer_variance.ptr)
                    self.ptr_Mean_ = int(self.buffer_mean.ptr)
                    self.ptr_Shifted_K_ = int(self.buffer_shift_k.ptr)
                    self.extent = MatrixCoord_(extent[0], extent[1])

                    # Keep host arrays so sync() can copy results back.
                    self.host_variance = variance
                    self.host_mean = mean
                    self.host_shift_k = shift_k

            def sync(self, stream_sync=True):
                """Copy variance/mean/shifted-K partials back to the host arrays."""
                if stream_sync:
                    err, = cudart.cudaDeviceSynchronize()
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))

                # NOTE(review): only the status of the last copy is checked,
                # and cudart's return code is compared against cuda.CUresult —
                # both use 0 for success, but the enums are distinct.
                err, = cuda.cuMemcpyDtoH(
                    self.host_variance, cuda.CUdeviceptr(self.ptr_Variance),
                    self.host_variance.size * self.host_variance.itemsize)
                err, = cuda.cuMemcpyDtoH(
                    self.host_mean, cuda.CUdeviceptr(self.ptr_Mean_),
                    self.host_mean.size * self.host_mean.itemsize)
                err, = cuda.cuMemcpyDtoH(
                    self.host_shift_k, cuda.CUdeviceptr(self.ptr_Shifted_K_),
                    self.host_shift_k.size * self.host_shift_k.itemsize)
                if err != cuda.CUresult.CUDA_SUCCESS:
                    raise RuntimeError("CUDA Error %s" % str(err))

        self.epilogue_type = _EpilogueVisitorParams

    def emit(self, operation):
        """Render the C++ visitor type for ``operation`` from KernelTemplate."""
        values = {
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            'operation_name': operation.procedural_name(),
            'element_compute': DataTypeTag[self.element_compute],
            'element_variance': DataTypeTag[self.element_variance],
            'element_mean': DataTypeTag[self.element_mean],
            'element_layer_norm_compute': DataTypeTag[self.element_layer_norm_compute],
            'epilogue_functor': self.elementwise_functor.emit(),
            'shifted_k': self.shifted_k
        }
        return SubstituteTemplate(self.KernelTemplate, values)
566
+
567
+
568
+
569
class AccumulatorOp:
    """Leaf visitor: yields the raw GEMM accumulator fragment."""

    Template = """
using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpAccumulator<${element_accumulator}, ${elements_per_access}>;
"""
    # Monotonic id so each emitted instance gets a unique C++ type alias.
    counter = 0

    def __init__(self, element_accumulator, elements_per_access) -> None:
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access

        self.instance_name = f"AccumulatorOp{AccumulatorOp.counter}"
        AccumulatorOp.counter += 1

        class _Arguments(ctypes.Structure):
            # No runtime state; dummy field keeps the struct non-empty.
            _fields_ = [("tmp", ctypes.c_int)]

            def __init__(self):
                self.tmp = 0

        self.argument_type = _Arguments

    def emit(self, *args):
        """Render the C++ type alias for this visitor."""
        return SubstituteTemplate(self.Template, {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "elements_per_access": str(self.elements_per_access),
        })
598
+
599
+
600
class LinearCombinationOp:
    """Visitor node computing alpha * A + beta * B over two child visitors."""

    Template = """
${visitor_a}

${visitor_b}

using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpLinearCombination<
    ${element_accumulator}, ${element_compute},
    ${elements_per_access}, ${visitor_a_name}, ${visitor_b_name}>;
"""
    # Monotonic id so each emitted instance gets a unique C++ type alias.
    counter = 0

    def __init__(self, element_accumulator, element_compute,
                 elements_per_access, visitor_a, visitor_b) -> None:
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access
        self.visitor_a = visitor_a
        self.visitor_b = visitor_b

        self.instance_name = f"LinearCombinationOp{LinearCombinationOp.counter}"
        LinearCombinationOp.counter += 1

        class _Arguments(ctypes.Structure):
            # Scalars first, then both child argument structs (C layout).
            _fields_ = [
                ("alpha", dtype2ctype[self.element_compute]),
                ("beta", dtype2ctype[self.element_compute]),
                ("visitor_a", self.visitor_a.argument_type),
                ("visitor_b", self.visitor_b.argument_type)
            ]

            def __init__(self, alpha, beta, visitor_a_arg, visitor_b_arg) -> None:
                # Scalars pass through the compute type's storage representation.
                self.alpha = element_compute(alpha).storage
                self.beta = element_compute(beta).storage
                self.visitor_a = visitor_a_arg
                self.visitor_b = visitor_b_arg

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render this node and both children into C++ type aliases."""
        return SubstituteTemplate(self.Template, {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_compute": DataTypeTag[self.element_compute],
            "elements_per_access": str(self.elements_per_access),
            "visitor_a_name": self.visitor_a.instance_name,
            "visitor_b_name": self.visitor_b.instance_name,
            "visitor_a": self.visitor_a.emit(operation),
            "visitor_b": self.visitor_b.emit(operation),
        })
650
+
651
class VectorAdd:
    """Binary-op tag for element-wise addition (no runtime parameters)."""

    def __init__(self, *args) -> None:
        class _Arguments(ctypes.Structure):
            # Dummy field keeps the struct non-empty for the C ABI.
            _fields_ = [("tmp", ctypes.c_int)]

            def __init__(self, *args) -> None:
                self.tmp = 0

        self.argument_type = _Arguments

    def emit(self):
        # C++ functor implementing the addition.
        return "cutlass::VectorAdd"
663
+
664
class VectorMult:
    """Binary-op tag for element-wise multiplication (no runtime parameters)."""

    def __init__(self, *args) -> None:
        class _Arguments(ctypes.Structure):
            # Dummy field keeps the struct non-empty for the C ABI.
            _fields_ = [("tmp", ctypes.c_int)]

            def __init__(self, *args) -> None:
                self.tmp = 0

        self.argument_type = _Arguments

    def emit(self):
        # C++ functor implementing the multiplication.
        return "cutlass::VectorMult"
676
+
677
+
678
class BinaryOp:
    """Visitor node applying a binary functor to two child visitors."""

    Template = """
${visitor_a}

${visitor_b}

using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpBinary<
    ${element_accumulator}, ${element_compute},
    ${elements_per_access}, ${visitor_a_name}, ${visitor_b_name}, ${binary_op}>;
"""
    # Monotonic id so each emitted instance gets a unique C++ type alias.
    counter = 0

    def __init__(self, element_accumulator, element_compute,
                 elements_per_access, visitor_a, visitor_b, binary_op) -> None:
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access
        self.visitor_a = visitor_a
        self.visitor_b = visitor_b
        self.binary_op = binary_op

        self.instance_name = f"BinaryOp{BinaryOp.counter}"
        BinaryOp.counter += 1

        class _Arguments(ctypes.Structure):
            # Functor params first, then both child argument structs.
            _fields_ = [
                ("binary_param", binary_op.argument_type),
                ("visitor_a", self.visitor_a.argument_type),
                ("visitor_b", self.visitor_b.argument_type)
            ]

            def __init__(self, binary_param, visitor_a_arg, visitor_b_arg) -> None:
                self.binary_param = binary_param
                self.visitor_a = visitor_a_arg
                self.visitor_b = visitor_b_arg

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render this node and both children into C++ type aliases."""
        return SubstituteTemplate(self.Template, {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_compute": DataTypeTag[self.element_compute],
            "elements_per_access": str(self.elements_per_access),
            "visitor_a_name": self.visitor_a.instance_name,
            "visitor_b_name": self.visitor_b.instance_name,
            "visitor_a": self.visitor_a.emit(operation),
            "visitor_b": self.visitor_b.emit(operation),
            "binary_op": self.binary_op.emit(),
        })
727
+
728
+
729
class Mult:
    """Unary-op tag scaling each element by a constant alpha."""

    def __init__(self, element_compute) -> None:
        class _Arguments(ctypes.Structure):
            # Single scalar parameter in the compute type's storage repr.
            _fields_ = [("alpha", dtype2ctype[element_compute])]

            def __init__(self, alpha) -> None:
                self.alpha = element_compute(alpha).storage

        self.argument_type = _Arguments

    def emit_visitor(self):
        # C++ functor implementing the scaling.
        return "cutlass::Mult"
742
+
743
class UnaryOp:
    """Visitor node applying a unary functor to one child visitor."""

    Template = """
${visitor}

using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpUnary<
    ${element_accumulator}, ${element_compute},
    ${elements_per_access}, ${visitor_name}, ${unary_op}>;
"""
    # Monotonic id so each emitted instance gets a unique C++ type alias.
    counter = 0

    def __init__(self, element_accumulator, element_compute,
                 elements_per_access, visitor, unary_op) -> None:
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access
        self.visitor = visitor
        self.unary_op = unary_op

        self.instance_name = f"UnaryOp{UnaryOp.counter}"
        UnaryOp.counter += 1

        class _Arguments(ctypes.Structure):
            # Functor params first, then the child argument struct.
            _fields_ = [
                ("unary_param", unary_op.argument_type),
                ("visitor_arg", self.visitor.argument_type)
            ]

            def __init__(self, unary_param, visitor_arg) -> None:
                self.unary_param = unary_param
                self.visitor_arg = visitor_arg

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render this node and its child into C++ type aliases."""
        return SubstituteTemplate(self.Template, {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_compute": DataTypeTag[self.element_compute],
            "elements_per_access": str(self.elements_per_access),
            "visitor_name": self.visitor.instance_name,
            "unary_op": self.unary_op.emit_visitor(),
            "visitor": self.visitor.emit(operation),
        })
786
+
787
+
788
+
789
class RowBroadcastOp:
    """Leaf visitor broadcasting a row vector across the output tile."""

    Template = """
using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpRowBroadcast<
    ${element_accumulator}, ${element_fragment}, ${input_tile_iterator}>;
"""
    # Monotonic id so each emitted instance gets a unique C++ type alias.
    counter = 0

    def __init__(self, element_accumulator, element_fragment) -> None:
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

        self.instance_name = f"RowBroadcastOp{RowBroadcastOp.counter}"
        RowBroadcastOp.counter += 1

        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("broadcast_ptr", ctypes.c_void_p),
                ("batch_stride", ctypes.c_longlong)
            ]

            def __init__(self, broadcast_ptr, batch_stride=0):
                self.broadcast_ptr = int(broadcast_ptr)
                self.batch_stride = batch_stride

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render the C++ type alias for this visitor."""
        return SubstituteTemplate(self.Template, {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_fragment": DataTypeTag[self.element_fragment],
            "input_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator",
        })
821
+
822
+
823
class ColumnBroadcastOp:
    """Leaf visitor broadcasting a column vector across the output tile."""

    Template = """
using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpColumnBroadcast<
    ${element_accumulator}, ${element_fragment}, ${input_tile_iterator}>;
"""
    # Monotonic id so each emitted instance gets a unique C++ type alias.
    counter = 0

    def __init__(self, element_accumulator, element_fragment) -> None:
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

        self.instance_name = f"ColumnBroadcastOp{ColumnBroadcastOp.counter}"
        ColumnBroadcastOp.counter += 1

        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("broadcast_ptr", ctypes.c_void_p),
                ("batch_stride", ctypes.c_longlong)
            ]

            def __init__(self, broadcast_ptr, batch_stride=0):
                self.broadcast_ptr = int(broadcast_ptr)
                self.batch_stride = batch_stride

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render the C++ type alias for this visitor."""
        return SubstituteTemplate(self.Template, {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_fragment": DataTypeTag[self.element_fragment],
            "input_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator",
        })
855
+
856
+
857
class TensorInputOp:
    """Leaf visitor reading an auxiliary input tensor."""

    Template = """
using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpTensorInput<
    ${element_accumulator}, ${input_tile_iterator}>;
"""
    # Monotonic id so each emitted instance gets a unique C++ type alias.
    counter = 0

    def __init__(self, element_accumulator) -> None:
        self.element_accumulator = element_accumulator

        self.instance_name = f"TensorInputOp{TensorInputOp.counter}"
        TensorInputOp.counter += 1

        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("input_ptr", ctypes.c_void_p),
                ("ldt", ctypes.c_int),
                ("batch_stride", ctypes.c_longlong)
            ]

            def __init__(self, input_ptr, ldt, batch_stride=0) -> None:
                self.input_ptr = int(input_ptr)
                self.ldt = ldt
                self.batch_stride = batch_stride

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render the C++ type alias for this visitor."""
        return SubstituteTemplate(self.Template, {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "input_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator",
        })
889
+
890
class TensorOutputOp:
    """Visitor node writing its child's result to an output tensor."""

    Template = """
${visitor}

using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpTensorOutput<
    ${element_accumulator}, ${output_tile_iterator}, ${visitor_name}>;
"""
    # Monotonic id so each emitted instance gets a unique C++ type alias.
    counter = 0

    def __init__(self, element_accumulator, visitor) -> None:
        self.element_accumulator = element_accumulator
        self.visitor = visitor

        self.instance_name = f"TensorOutputOp{TensorOutputOp.counter}"
        TensorOutputOp.counter += 1

        class _Arguments(ctypes.Structure):
            # Output description first, then the child argument struct.
            _fields_ = [
                ("output_ptr", ctypes.c_void_p),
                ("ldt", ctypes.c_int),
                ("batch_stride", ctypes.c_longlong),
                ("visitor_arg", self.visitor.argument_type)
            ]

            def __init__(self, output_ptr, ldt, visitor_arg, batch_stride=0) -> None:
                self.output_ptr = int(output_ptr)
                self.ldt = int(ldt)
                self.visitor_arg = visitor_arg
                self.batch_stride = batch_stride

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render this node and its child into C++ type aliases."""
        return SubstituteTemplate(self.Template, {
            "instance_name": self.instance_name,
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "output_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator",
            "visitor_name": self.visitor.instance_name,
            "visitor": self.visitor.emit(operation),
        })
929
+
930
+
931
class ColumnReductionOp:
    """Visitor node reducing its child's result along columns."""

    Template = """
${visitor}

using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpColumnReduction<
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    ${element_accumulator}, ${element_reduction}, ${element_reduction_accumulator},
    ${output_tile_iterator}, ${visitor_name}>;
"""
    # Monotonic id so each emitted instance gets a unique C++ type alias.
    counter = 0

    def __init__(self, element_accumulator, element_reduction,
                 element_reduction_accumulator, visitor) -> None:
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator
        self.visitor = visitor

        self.instance_name = f"ColumnReductionOp{ColumnReductionOp.counter}"
        ColumnReductionOp.counter += 1

        class _Arguments(ctypes.Structure):
            # Reduction buffer first, then the child argument struct.
            _fields_ = [
                ("reduction_ptr", ctypes.c_void_p),
                ("batch_stride", ctypes.c_longlong),
                ("visitor_arg", self.visitor.argument_type)
            ]

            def __init__(self, reduction_ptr, visitor_arg, batch_stride=0) -> None:
                self.reduction_ptr = reduction_ptr
                self.batch_stride = batch_stride
                self.visitor_arg = visitor_arg

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render this node and its child into C++ type aliases."""
        return SubstituteTemplate(self.Template, {
            "instance_name": self.instance_name,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_reduction": DataTypeTag[self.element_reduction],
            "element_reduction_accumulator": DataTypeTag[self.element_reduction_accumulator],
            "output_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator",
            "visitor_name": self.visitor.instance_name,
            "visitor": self.visitor.emit(operation),
        })
978
+
979
+
980
class RowReductionOp:
    """Visitor node reducing its child's result along rows."""

    Template = """
${visitor}

using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpRowReduction<
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    ${element_accumulator}, ${element_reduction}, ${element_reduction_accumulator},
    ${output_tile_iterator}, ${visitor_name}>;
"""
    # Monotonic id so each emitted instance gets a unique C++ type alias.
    counter = 0

    def __init__(self, element_accumulator, element_reduction,
                 element_reduction_accumulator, visitor) -> None:
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator
        self.visitor = visitor

        self.instance_name = f"RowReductionOp{RowReductionOp.counter}"
        RowReductionOp.counter += 1

        class _Arguments(ctypes.Structure):
            # Reduction buffer first, then the child argument struct.
            _fields_ = [
                ("reduction_ptr", ctypes.c_void_p),
                ("batch_stride", ctypes.c_longlong),
                ("visitor_arg", self.visitor.argument_type)
            ]

            def __init__(self, reduction_ptr, visitor_arg, batch_stride=0) -> None:
                self.reduction_ptr = reduction_ptr
                self.visitor_arg = visitor_arg
                self.batch_stride = batch_stride

        self.argument_type = _Arguments

    def emit(self, operation):
        """Render this node and its child into C++ type aliases."""
        return SubstituteTemplate(self.Template, {
            "instance_name": self.instance_name,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            "element_accumulator": DataTypeTag[self.element_accumulator],
            "element_reduction": DataTypeTag[self.element_reduction],
            "element_reduction_accumulator": DataTypeTag[self.element_reduction_accumulator],
            "output_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator",
            "visitor_name": self.visitor.instance_name,
            "visitor": self.visitor.emit(operation),
        })