warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300) hide show
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,80 @@
1
+ import distutils.cmd
2
+ from setuptools import setup
3
+ import setuptools.command.build_py
4
+ import os
5
+
6
+ # build rmm dependency
7
class BuildRMM(distutils.cmd.Command):
    """Custom setup command that fetches, builds and installs RMM when it is missing.

    The command takes no options. When ``import rmm`` already succeeds it does
    nothing; otherwise it clones the ``branch-22.08`` branch of the RMM
    repository into ``./rmm`` and runs its build and install steps.
    """

    user_options = []

    def initialize_options(self):
        """No options to initialise (required by the Command interface)."""
        pass

    def finalize_options(self):
        """No options to validate (required by the Command interface)."""
        pass

    def run(self):
        """Build and install RMM unless it can already be imported.

        Raises:
            subprocess.CalledProcessError: if any of the external build steps
                exits with a non-zero status.
        """
        # Local imports keep this block self-contained; both are stdlib.
        import subprocess
        import sys

        try:
            import rmm  # noqa: F401 -- imported only to probe availability
        except ImportError:
            pass
        else:
            return

        print("installing rmm")

        # Absolute paths plus ``cwd=`` replace the original os.chdir calls, so
        # the process working directory is untouched when this command returns.
        rmm_dir = os.path.abspath("rmm")
        python_dir = os.path.join(rmm_dir, "python")

        # check=True stops at the first failing step instead of carrying on
        # with a half-built tree.
        subprocess.run(
            ["git", "clone", "-b", "branch-22.08", "--recurse-submodules",
             "https://github.com/rapidsai/rmm.git"],
            check=True)
        subprocess.run(
            [os.path.join(rmm_dir, "build.sh"), "librmm", "rmm"],
            cwd=rmm_dir, check=True)
        # sys.executable guarantees the package is installed into the
        # interpreter that is running this setup script.
        subprocess.run(
            [sys.executable, "setup.py", "build_ext", "--inplace"],
            cwd=python_dir, check=True)
        subprocess.run(
            [sys.executable, "setup.py", "install"],
            cwd=python_dir, check=True)
24
+
25
def _require_env(name):
    """Return the value of environment variable *name*.

    Raises:
        RuntimeError: if the variable is not set. An explicit raise is used
            instead of ``assert`` so the check still runs under ``python -O``.
    """
    value = os.getenv(name)
    if value is None:
        raise RuntimeError("Environment variable '%s' is not defined." % name)
    return value


# Locations of the CUTLASS checkout and the CUDA toolkit; both are mandatory.
cutlass_path = _require_env('CUTLASS_PATH')
cuda_install_path = _require_env('CUDA_INSTALL_PATH')

# Stays empty when pybind11 is not importable, so setup() can still run
# (pybind11 is also listed in setup_requires).
ext_modules = []

try:
    from pybind11.setup_helpers import Pybind11Extension, build_ext
    include_dirs = [
        cutlass_path + "/include",
        cuda_install_path + "/include",
        cutlass_path + "/tools/util/include",
        cutlass_path + "/test",
        cutlass_path + "/tools/library/scripts/pycutlass/googletest/googletest/include"
    ]

    ext_modules = [
        Pybind11Extension("cutlass",
                          ["src/cpp/cutlass.cpp"],
                          include_dirs=include_dirs,
                          extra_compile_args=["-fpermissive", "-w"])
    ]
except ImportError:
    # Deliberate best effort: without pybind11 no extension module is declared.
    pass
50
+
51
# Keyword arguments for setup(), gathered first so the call itself stays short.
# NOTE(review): the classifier below says MIT while the sibling pycutlass
# sources carry a BSD-3-Clause SPDX header -- confirm which is intended.
# NOTE(review): 'typing' is part of the standard library for every Python
# covered by python_requires -- confirm the PyPI backport is really wanted.
_setup_kwargs = dict(
    name="PyCutlass",
    version="0.0.1",
    author="Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall",
    author_email="zhaodongc@nvidia.com",
    description="Python interface for CUTLASS",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    package_dir={"": "src"},
    packages=['pycutlass', 'pycutlass.utils', 'pycutlass.test'],
    setup_requires=["pybind11", "numpy<1.23"],
    install_requires=[
        "numpy<1.23",
        'pybind11',
        'cuda-python<11.7.0',
        'typeguard',
        'bfloat16',
        'typing',
        'scikit-build',
        'treelib'
    ],
    # Exposes the RMM bootstrap step as ``python setup.py rmm``.
    cmdclass={
        'rmm': BuildRMM
    },
    ext_modules=ext_modules,
    python_requires=">=3.6",
)

setup(**_setup_kwargs)
@@ -0,0 +1,48 @@
1
+ import re
2
+
3
+
4
def SubstituteTemplate(template, values):
    """Expand ``${key}`` placeholders in *template* using *values*.

    Substitution is repeated until a full pass over *values* changes nothing,
    so a value may itself contain placeholders that are expanded on a later
    pass. A value that keeps re-introducing its own placeholder alongside
    other text therefore never terminates.

    Keys and values are treated as literal text: a backslash in a value is
    copied verbatim rather than being interpreted as a regex replacement
    escape.

    Args:
        template: text containing ``${key}`` placeholders.
        values: mapping from placeholder name to replacement string.

    Returns:
        The template with every resolvable placeholder expanded.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            # Plain string replacement: no regex metacharacters in the key and
            # no escape processing of the value.
            newtext = text.replace("${%s}" % key, value)
            if newtext != text:
                changed = True
            text = newtext
    return text
16
+
17
+ from pycutlass.type_hint import *
18
+ from pycutlass.tensor_ref import *
19
+ from pycutlass.operation import *
20
+ from pycutlass.epilogue import *
21
+ from pycutlass.parser import *
22
+ from pycutlass.compiler import ArtifactManager
23
+ from pycutlass.memory_manager import *
24
+ from pycutlass.arguments import *
25
+ from pycutlass.library import *
26
+ from pycutlass.c_types import *
27
+ from pycutlass.gemm_operation import *
28
+ from pycutlass.conv2d_operation import *
29
+ from pycutlass.compiler import *
30
+ from pycutlass.utils import *
31
+ from pycutlass.frontend import *
32
+ from pycutlass.reduction_operation import *
33
+ from pycutlass.compiler import *
34
+
35
+ # module-wide variables
36
+
37
import sys

# Handle to this module object, so the names below become module attributes.
this = sys.modules[__name__]

# artifact manager
# (stored under the attribute name ``compiler`` of this module)
this.compiler = ArtifactManager()


def get_memory_pool(init_pool_size=0, max_pool_size=2**34):
    """Create a PoolMemoryManager, publish it as ``memory_pool`` on this module, and return it.

    Every call builds a new manager and replaces the previously stored one.

    Args:
        init_pool_size: forwarded to PoolMemoryManager as ``init_pool_size``.
        max_pool_size: forwarded to PoolMemoryManager as ``max_pool_size``.

    Returns:
        The newly created PoolMemoryManager.
    """
    pool = PoolMemoryManager(init_pool_size=init_pool_size,
                             max_pool_size=max_pool_size)
    this.memory_pool = pool
    return pool
@@ -0,0 +1,118 @@
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+ from .frontend import CupyFrontend
33
+ from typeguard import typechecked
34
+ from pycutlass.frontend import *
35
+ from typing import Union
36
+ import numpy as np
37
+ from cuda import cuda
38
+ try:
39
+ import torch
40
+ torch_available = True
41
+ except ImportError:
42
+ torch_available = False
43
+ from cuda import cudart
44
+ try:
45
+ import cupy as cp
46
+ cupy_available = True
47
+ except ImportError:
48
+ cupy_available = False
49
+
50
+
51
+ # @typechecked
52
class ArgumentBase:
    """
    Base class for operation arguments.

    Normalises the four operand tensors A, B, C and D into the attributes
    ``ptr_A`` .. ``ptr_D``. The handling is selected by the type of ``A``
    alone; B, C and D are processed the same way as A.
    """

    def __init__(self,
                 A: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
                 B: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
                 C: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
                 D: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
                 **kwargs) -> None:
        """Store pointers for the operands.

        Args:
            A, B, C, D: operand tensors, all of the same frontend type.
            **kwargs: only ``bias`` is read here.

        Raises:
            TypeError: if ``A`` is none of the handled types.
        """
        # tensor_C can be interpreted as the bias with bias=True in keyword
        # args; by default, tensor_C is not bias
        self.bias = kwargs.get("bias", False)

        # preprocessing input tensors
        if isinstance(A, np.ndarray):
            # Keep the host array so sync() can copy the result back into it.
            self.host_D = D
            # The second argument is False for A/B/C and True for D --
            # presumably it marks the output buffer; confirm in NumpyFrontend.
            self.buffer_A = NumpyFrontend.argument(A, False)
            self.buffer_B = NumpyFrontend.argument(B, False)
            self.buffer_C = NumpyFrontend.argument(C, False)
            self.buffer_D = NumpyFrontend.argument(D, True)
            self.ptr_A = self.buffer_A.ptr
            self.ptr_B = self.buffer_B.ptr
            self.ptr_C = self.buffer_C.ptr
            self.ptr_D = self.buffer_D.ptr
            # number of elements in C
            self.tensor_c_numel = C.size
        elif torch_available and isinstance(A, torch.Tensor):
            self.ptr_A = TorchFrontend.argument(A)
            self.ptr_B = TorchFrontend.argument(B)
            self.ptr_C = TorchFrontend.argument(C)
            self.ptr_D = TorchFrontend.argument(D)
            # number of elements in C
            self.tensor_c_numel = C.numel()
        elif isinstance(A, cuda.CUdeviceptr):
            # Raw device pointers are used as given.
            # NOTE(review): tensor_c_numel is not set on this path -- confirm
            # that no caller reads it for CUdeviceptr arguments.
            self.ptr_A = A
            self.ptr_B = B
            self.ptr_C = C
            self.ptr_D = D

        elif cupy_available and isinstance(A, cp.ndarray):
            self.ptr_A = CupyFrontend.argument(A)
            self.ptr_B = CupyFrontend.argument(B)
            self.ptr_C = CupyFrontend.argument(C)
            self.ptr_D = CupyFrontend.argument(D)
            # number of elements in C
            self.tensor_c_numel = C.size
        else:
            # NOTE(review): the message names only numpy and torch although
            # CUdeviceptr and cupy are handled above; text kept unchanged.
            raise TypeError(
                "Unsupported Frontend. Only support numpy and torch")

    def sync(self, stream_sync=True):
        """Wait for the device and, for numpy operands, copy D back to the host.

        Args:
            stream_sync: when true, call ``cudaDeviceSynchronize`` first.

        Raises:
            RuntimeError: if the synchronisation or the copy reports an error.
        """
        if stream_sync:
            err, = cudart.cudaDeviceSynchronize()
            # cudaDeviceSynchronize is a runtime-API call, so its status is
            # checked against the runtime enum rather than the driver one.
            if err != cudart.cudaError_t.cudaSuccess:
                raise RuntimeError("CUDA Error %s" % str(err))

        # host_D exists only when the operands were numpy arrays.
        if hasattr(self, "host_D"):
            err, = cuda.cuMemcpyDtoH(
                self.host_D, self.ptr_D, self.host_D.size * self.host_D.itemsize)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))
@@ -0,0 +1,241 @@
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ import ctypes
34
+ from pycutlass.library import *
35
+
36
+ # 12B
37
+
38
+
39
class GemmCoord_(ctypes.Structure):
    """ctypes structure holding three C ints named ``m``, ``n`` and ``k``."""

    _fields_ = [("m", ctypes.c_int), ("n", ctypes.c_int), ("k", ctypes.c_int)]

    def __init__(self, gemm_coord) -> None:
        """Fill the fields by calling ``gemm_coord.m()``, ``.n()`` and ``.k()``."""
        for entry in self._fields_:
            name = entry[0]
            # The source object exposes each dimension as a same-named method.
            accessor = getattr(gemm_coord, name)
            setattr(self, name, accessor())
49
+
50
+
51
class MatrixCoord_(ctypes.Structure):
    """ctypes structure holding two C ints named ``row`` and ``column``."""

    _fields_ = [("row", ctypes.c_int), ("column", ctypes.c_int)]
56
+
57
+
58
# Lookup from a cutlass element type to the ctypes type paired with it.
# float16 is paired with a 16-bit unsigned integer (ctypes has no half type).
dtype2ctype = {
    cutlass.float16: ctypes.c_uint16,
    cutlass.float32: ctypes.c_float,
    cutlass.float64: ctypes.c_double,
    cutlass.int32: ctypes.c_int32,
}
64
+
65
+
66
def get_gemm_arguments(epilogue_functor):
    """Build the ctypes argument structure for a GEMM using the given epilogue.

    Args:
        epilogue_functor: object whose ``epilogue_type`` attribute is the
            ctypes type embedded as the ``epilogue`` member.

    Returns:
        A tuple ``(_GemmArguments, epilogue_type)``.
    """
    epilogue_params = epilogue_functor.epilogue_type

    int64 = ctypes.c_longlong
    pointer = ctypes.c_void_p

    layout = [
        # Arguments from UniversalArgumentsBase
        ("mode", ctypes.c_int),
        ("problem_size", GemmCoord_),
        ("batch_count", ctypes.c_int),
        ("batch_stride_D", int64),
        # Remaining arguments
        ("epilogue", epilogue_params),
        ("ptr_A", pointer),
        ("ptr_B", pointer),
        ("ptr_C", pointer),
        ("ptr_D", pointer),
        ("batch_stride_A", int64),
        ("batch_stride_B", int64),
        ("batch_stride_C", int64),
        ("stride_a", int64),
        ("stride_b", int64),
        ("stride_c", int64),
        ("stride_d", int64),
        ("lda", int64),
        ("ldb", int64),
        ("ldc", int64),
        ("ldd", int64),
        ("ptr_gather_A_indices", pointer),
        # Spelling ("gether") kept as-is: the field name is part of the interface.
        ("ptr_gether_B_indices", pointer),
        ("ptr_scatter_D_indices", pointer)
    ]

    class _GemmArguments(ctypes.Structure):
        _fields_ = layout

    return _GemmArguments, epilogue_params
100
+
101
+
102
+ ###########################################################################################
103
+ # GEMM Grouped
104
+ ###########################################################################################
105
+
106
+ # include/cutlass/gemm/kernel/gemm_grouped.h
107
+
108
def get_gemm_grouped_arguments(epilogue_functor):
    """Build the ctypes argument structure for a grouped GEMM.

    Args:
        epilogue_functor: object whose ``epilogue_type`` attribute is the
            ctypes type embedded as the ``output_op`` member.

    Returns:
        A tuple ``(_GEMMGroupedArguments, epilogue_type)``.
    """
    epilogue_params = epilogue_functor.epilogue_type
    pointer = ctypes.c_void_p

    layout = [
        ("problem_sizes", pointer),
        ("problem_count", ctypes.c_int),
        ("threadblock_count", ctypes.c_int),
        ("output_op", epilogue_params),
        ("ptr_A", pointer),
        ("ptr_B", pointer),
        ("ptr_C", pointer),
        ("ptr_D", pointer),
        # The leading dimensions are pointer-typed fields here.
        ("lda", pointer),
        ("ldb", pointer),
        ("ldc", pointer),
        ("ldd", pointer),
        ("host_problem_sizes", pointer)
    ]

    class _GEMMGroupedArguments(ctypes.Structure):
        _fields_ = layout

    return _GEMMGroupedArguments, epilogue_params
129
+
130
+ ############################################################################################
131
+ # Convolution2D
132
+ ############################################################################################
133
+
134
+
135
+ # We use the arguments as the interface
136
+
137
+
138
+ # include/cutlass/conv/conv2d_problem_size.h
139
+ # 72B (18 fields of ctypes.c_int, 4 bytes each)
140
class Conv2DProblemSize(ctypes.Structure):
    """ctypes structure of eighteen C ints describing a 2-D convolution problem."""

    _fields_ = [
        (label, ctypes.c_int)
        for label in (
            "N", "H", "W", "C",
            "P", "Q",
            "K", "R", "S",
            "pad_h", "pad_w",
            "stride_h", "stride_w",
            "dilation_h", "dilation_w",
            "mode",  # kCrossCorrelation: 0, kConvolution: 1
            "split_k_slices",
            "groups",
        )
    ]

    def __init__(self, problem_size) -> None:
        """Copy every field from the same-named attribute of *problem_size*."""
        for name, _ctype in self._fields_:
            value = getattr(problem_size, name)
            setattr(self, name, value)
165
+
166
+
167
+ # include/cutlass/layout/tensor.h
168
+ # 12B
169
class Layout4D(ctypes.Structure):
    """ctypes structure with one member, ``stride``, an array of three C ints."""

    _fields_ = [("stride", ctypes.c_int * 3)]

    def __init__(self, tensor_ref):
        """Read three strides via ``tensor_ref.stride().at(i)`` for i = 0, 1, 2."""
        strides = tensor_ref.stride()
        self.stride = (strides.at(0), strides.at(1), strides.at(2))
177
+
178
+ # TODO: Tensor 5-D takes ("stride", ctypes.c_int * 4)
179
+
180
+
181
+ # include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h
182
+ # TensorRef is basically cutlass::TensorRef<Element, Layout>;
183
+ # include/cutlass/tensor_ref.h
184
+ # 24B
185
class TensorRef_(ctypes.Structure):
    """ctypes structure pairing a ``ptr`` (void pointer) with a ``layout`` (Layout4D)."""

    _fields_ = [("ptr", ctypes.c_void_p), ("layout", Layout4D)]

    def __init__(self, tensor_ref):
        """Take the pointer from ``tensor_ref.data()`` and the layout from ``tensor_ref.layout()``."""
        self.ptr = tensor_ref.data()
        self.layout = Layout4D(tensor_ref.layout())
194
+
195
+
196
class TensorRef2D_(ctypes.Structure):
    """ctypes structure pairing a ``ptr`` (void pointer) with a single C int ``stride``."""

    _fields_ = [("ptr", ctypes.c_void_p), ("stride", ctypes.c_int)]
201
+
202
+
203
+ # include/cutlass/conv/kernel/implicit_gemm_convolution.h
204
+ # split_k_mode: kNone: 0, kSerial: 1, kParallel: 2, kParallelSerial: 3, kInvalid: 4
205
+
206
def get_conv2d_arguments(epilogue_functor):
    """Build the ctypes argument structure for a 2-D convolution.

    Args:
        epilogue_functor: object whose ``epilogue_type`` attribute is the
            ctypes type embedded as the ``output_op`` member.

    Returns:
        A tuple ``(_Conv2dArguments, epilogue_type)``.
    """
    epilogue_params = epilogue_functor.epilogue_type

    # Trailing numbers are the byte offsets noted by the original author.
    layout = [
        ("problem_size", Conv2DProblemSize),  # 0
        ("ref_A", TensorRef_),  # 72
        ("ref_B", TensorRef_),  # 96
        ("ref_C", TensorRef_),  # 120
        ("ref_D", TensorRef_),  # 144
        ("output_op", epilogue_params),  # 168
        ("split_k_mode", ctypes.c_int)  # 192
    ]

    class _Conv2dArguments(ctypes.Structure):
        _fields_ = layout

    return _Conv2dArguments, epilogue_params
221
+
222
+
223
+ ############################################################################################
224
+ # Reduction
225
+ ############################################################################################
226
+
227
+
228
def get_reduction_params(epilogue_functor):
    """Build the ctypes parameter structure for a reduction.

    Args:
        epilogue_functor: object whose ``epilogue_type`` attribute is the
            ctypes type embedded as the ``output_op`` member.

    Returns:
        A tuple ``(_ReductionParams, epilogue_type)``.
    """
    epilogue_params = epilogue_functor.epilogue_type

    layout = [
        ("problem_size", MatrixCoord_),
        ("partitions", ctypes.c_int),
        ("partition_stride", ctypes.c_longlong),
        ("workspace", TensorRef2D_),
        ("destination", TensorRef2D_),
        ("source", TensorRef2D_),
        ("output_op", epilogue_params)
    ]

    class _ReductionParams(ctypes.Structure):
        _fields_ = layout

    return _ReductionParams, epilogue_params