warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300) hide show
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,432 @@
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+ from pycutlass import *
33
+ import cutlass
34
+ from cuda import cuda
35
+ from cuda import nvrtc
36
+ import tempfile
37
+ import os
38
+ import ctypes
39
+
40
+ #
41
+ import json
42
+ import sqlite3
43
+
44
+
45
# Template for a single ``#include`` line of generated source. The
# ``${include}`` placeholder is filled in through ``SubstituteTemplate`` when
# the device and host source buffers are assembled.
IncludeTemplate = r'''#include "${include}"
'''
47
+
48
+ #
49
+
50
+
51
class CompilationOptions:
    '''
    Compilation options.

    Bundles compiler flags, include directories and target architectures and
    renders them either as one command-line string (``get_str``) or as a list
    of byte strings (``get``).
    '''

    def __init__(self, flags, architectures=None, include_paths=None):
        '''
        :param flags: iterable of compiler flag strings
        :param architectures: iterable of integer architectures, each rendered
            as ``sm_<n>``; ``None`` (the default) means ``[80]``
        :param include_paths: iterable of include directories; ``None`` (the
            default) means an empty list
        '''
        self.includes = []
        # The defaults are created per instance. List literals in the
        # signature would be one shared object, so mutating one instance's
        # ``include_paths``/``architectures`` would leak into every later
        # instance. Explicitly passed lists are still stored as-is.
        self.include_paths = [] if include_paths is None else include_paths
        self.flags = flags
        self.architectures = [80] if architectures is None else architectures

    def _arch_option(self):
        '''Return the ``-arch=sm_a,sm_b,...`` option for the configured architectures.'''
        return "-arch=" + ",".join("sm_%d" % arch for arch in self.architectures)

    def get_str(self):
        '''
        Render every option into a single string.

        Each option is preceded by one space, in the order: flags, include
        paths (as ``--include-path=<dir>``), architecture list.
        '''
        pieces = [" " + flag for flag in self.flags]
        pieces.extend(' --include-path=%s' % incl for incl in self.include_paths)
        pieces.append(" " + self._arch_option())
        return "".join(pieces)

    def get(self):
        '''
        Render every option as a list of ``bytes`` objects.

        The order matches ``get_str``: flags, include paths, architecture list.
        '''
        options = [str.encode(flag) for flag in self.flags]
        options.extend(
            str.encode('--include-path=%s' % incl) for incl in self.include_paths)
        options.append(str.encode(self._arch_option()))
        return options
100
+
101
+
102
def convertToBinaryData(filename):
    """Read the file at ``filename`` in binary mode and return its full contents."""
    with open(filename, 'rb') as stream:
        return stream.read()
106
+
107
+
108
def CDLLBin(host_binary):
    """
    Load a shared library from its in-memory image.

    The bytes are written to a temporary ``host_func*.so`` file in the current
    working directory and that file is handed to ``ctypes.CDLL``. The
    temporary file is closed, and thereby deleted, before this function
    returns, whether or not loading succeeded.

    :param host_binary: bytes-like image of the shared library
    :return: the ``ctypes.CDLL`` handle
    :raises OSError: when ``ctypes.CDLL`` cannot load the file
    """
    # Kept from the original implementation: this redirects the tempfile
    # module's default directory to the working directory for the process.
    tempfile.tempdir = "./"
    # Write through the temp-file handle instead of opening the path a second
    # time, and close it deterministically rather than relying on garbage
    # collection to remove the file.
    with tempfile.NamedTemporaryFile(
            prefix='host_func', suffix='.so', delete=True) as temp_so:
        temp_so.write(host_binary)
        # The loader reads the file by name, so the data must be on disk.
        temp_so.flush()
        # NOTE(review): assumes the loaded library stays usable after the
        # backing file is removed (true for POSIX dlopen) - confirm for other
        # platforms.
        return ctypes.CDLL(temp_so.name)
116
+
117
+
118
class ArtifactManager:
    """
    Artifact manager

    Compiles operations into a device cubin plus a host helper library, keeps
    the resulting kernels and host functions in two in-memory caches and
    persists them in a SQLite database in the working directory.
    """

    # On-disk cache location, relative to the current working directory.
    _CACHE_DB_PATH = "./compiled_cache.db"

    def __init__(self) -> None:
        """Create the cache table if needed and select the nvcc backend."""
        try:
            connection = sqlite3.connect(self._CACHE_DB_PATH)
            try:
                # IF NOT EXISTS replaces the former "create, then swallow
                # every exception" approach for an already existing table.
                connection.execute(
                    """CREATE TABLE IF NOT EXISTS compiled_operations(op_key TEXT NOT NULL UNIQUE, cubin BLOB NOT NULL, hostbin BLOB NOT NULL, op_name TEXT NOT NULL, op_attrs TEXT NOT NULL)""")
                connection.commit()
            finally:
                connection.close()
        except sqlite3.Error:
            # Still best effort: a database problem does not prevent
            # construction. Only SQLite errors are ignored now, instead of
            # every exception.
            pass

        self.nvcc()
        self.compiled_cache_device = cutlass.CompileCache()
        self.compiled_cache_host = cutlass.CompileCache()

    def nvrtc(self):
        """Select the NVRTC backend and its default compile options."""
        self.backend = "nvrtc"
        self.default_compile_options = [
            '-std=c++11', '-default-device',
        ]

    def nvcc(self):
        """Select the nvcc backend and its default compile options."""
        self.backend = "nvcc"
        self.default_compile_options = [
            '-std=c++11',
        ]

    @staticmethod
    def _run_tool(args, tool, stdin_data=None):
        """Run an external compiler without a shell; raise RuntimeError if it cannot start or exits non-zero."""
        import subprocess

        try:
            completed = subprocess.run(args, input=stdin_data)
        except OSError as exc:
            raise RuntimeError(
                '{} could not be started: {}'.format(tool, exc)) from exc
        if completed.returncode != 0:
            raise RuntimeError(
                '{} failed with exit code {}'.format(tool, completed.returncode))

    def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs):
        """
        Store one compiled operation in the database.

        :param op_key: unique key of the operation
        :param cubin: cubin image, stored as a blob
        :param hostfile: path of the host library file; its bytes are stored
        :param op_name: operation name
        :param op_attrs: attributes, stored as JSON text

        A row whose ``op_key`` already exists is left untouched
        (``INSERT OR IGNORE``).
        """
        sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)"""
        hostbin = convertToBinaryData(hostfile)
        data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs))

        connection = sqlite3.connect(self._CACHE_DB_PATH)
        try:
            connection.execute(sqlite_insert_blob_query, data_tuple)
            connection.commit()
        finally:
            # The connection used to be left open on every call.
            connection.close()

    def load_operation(self, op_key):
        """
        Load a stored operation from the database into the in-memory caches.

        :param op_key: key to look up
        :return: ``False`` when no row matches, ``True`` once the kernel and
            host functions have been inserted into the caches
        :raises RuntimeError: when a CUDA driver call reports an error
        """
        sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?"""
        connection = sqlite3.connect(self._CACHE_DB_PATH)
        try:
            record = connection.execute(
                sqlite_fetch_blob_query, (op_key, )).fetchall()
        finally:
            connection.close()

        if not record:
            return False

        for row in record:
            key, cubin_image, host_binary, operation_name, op_attr = row
            op_attr = json.loads(op_attr)
            err, module = cuda.cuModuleLoadData(cubin_image)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))

            err, kernel = cuda.cuModuleGetFunction(
                module, bytes(str.encode(operation_name)))
            # This status used to be ignored, which would have cached an
            # unusable kernel handle.
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))
            self.compiled_cache_device.insert(key, kernel)

            compiled_host_fns = {}
            host_lib = CDLLBin(host_binary)

            # The first stored attribute is the parameter size written by
            # add_module; it sizes the buffer type returned by *_get_params.
            func = getattr(host_lib, operation_name + '_get_params')
            func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0])
            compiled_host_fns['get_args'] = func

            func = getattr(host_lib, operation_name + '_shared_memory_size')
            compiled_host_fns['shared_memory_capacity'] = func()

            # String attributes are suffixes of extra host functions.
            for attr in op_attr:
                if isinstance(attr, str):
                    compiled_host_fns[attr] = getattr(
                        host_lib, operation_name + '_' + attr)

            self.compiled_cache_host.insert(key, compiled_host_fns)
        return True

    def emit_compile_(self, operation_list, compilation_options):
        """
        Compile a list of kernels and store them into database

        :param operation_list: operations providing ``emitter``, ``emit()``,
            ``name()``, ``KernelTemplate`` and ``HostTemplate``
        :param compilation_options: object providing ``get()`` and ``get_str()``
        :return: ``(cubin_image, host_lib, temp)`` where ``temp`` is the still
            open temporary file holding the host library
        :raises RuntimeError: when NVRTC reports an error, or when nvcc / g++
            cannot be started or exit with a non-zero status
        """
        import shlex

        # 1. include
        includes = []
        for operation in operation_list:
            for incl in operation.emitter.includes:
                if incl not in includes:
                    includes.append(incl)

        includes_host = [
            "builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes

        device_parts = [
            SubstituteTemplate(IncludeTemplate, {'include': incl})
            for incl in includes]
        host_parts = [
            SubstituteTemplate(IncludeTemplate, {'include': incl})
            for incl in includes_host if "/device/" not in incl]

        # 2. Operations
        for operation in operation_list:
            # emit() is deliberately called once per buffer, as before.
            device_parts.append(operation.emit())
            host_parts.append(operation.emit())
            values = {
                'operation_name': operation.name(),
                'operation_suffix': operation.emitter.operation_suffix
            }
            device_parts.append(
                SubstituteTemplate(operation.KernelTemplate, values))
            host_parts.append(
                SubstituteTemplate(operation.HostTemplate, values))

        source_buffer_device = "".join(device_parts)
        source_buffer_host = "".join(host_parts)

        if self.backend == "nvrtc":
            # 3. compile
            err, program = nvrtc.nvrtcCreateProgram(
                str.encode(source_buffer_device),
                bytes(str.encode("module.cu")),
                0, [], [])

            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))

            # Compile program
            options = compilation_options.get()

            err, = nvrtc.nvrtcCompileProgram(program, len(options), options)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:

                error_string = 'NVRTC Error: {}\n'.format(err)

                # Get log from compilation
                err, logSize = nvrtc.nvrtcGetProgramLogSize(program)
                if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                    raise RuntimeError('NVRTC Error: {}'.format(err))

                log = b' ' * logSize
                err, = nvrtc.nvrtcGetProgramLog(program, log)
                if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                    raise RuntimeError('NVRTC Error: {}'.format(err))

                raise RuntimeError(
                    error_string + log.decode() + source_buffer_device)

            # Get data from compilation
            err, dataSize = nvrtc.nvrtcGetCUBINSize(program)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))

            cubin_image = b' ' * dataSize
            err, = nvrtc.nvrtcGetCUBIN(program, cubin_image)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))
        else:  # with nvcc backend
            # emit code
            tempfile.tempdir = "./"
            with tempfile.NamedTemporaryFile(
                    mode='w', prefix='kernel', suffix='.cu', delete=True) as temp_cu:
                with tempfile.NamedTemporaryFile(
                        prefix='kernel', suffix='.cubin', delete=True) as temp_cubin:
                    temp_cu.write(source_buffer_device)
                    # nvcc reads the source by file name.
                    temp_cu.flush()

                    # compile with nvcc
                    cuda_install_path = os.getenv('CUDA_INSTALL_PATH')
                    assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined."
                    cmd_template = "${cuda_install_path}/bin/nvcc ${options} -cubin ${srcfile} -o ${tarfile}"
                    values = {
                        "cuda_install_path": cuda_install_path,
                        "options": compilation_options.get_str(),
                        "srcfile": temp_cu.name,
                        "tarfile": temp_cubin.name
                    }
                    cmd = SubstituteTemplate(cmd_template, values)
                    # Same command line as before, tokenised here instead of
                    # by a shell, and with the exit status checked.
                    self._run_tool(shlex.split(cmd), 'nvcc')

                    # load the cubin image
                    with open(temp_cubin.name, 'rb') as file:
                        cubin_image = file.read()

        # compile the host code
        gxx_args = ["g++", "-x", "c++", "-fpermissive", "-w", "-fPIC"]
        for opt in compilation_options.get():
            opt = opt.decode("utf-8")
            if opt not in ['-default-device', '-std=c++11', '-Xcicc', '-Xllc'] and '-arch=sm_' not in opt:
                if '--include-path=' in opt:
                    gxx_args.append(opt.replace('--include-path=', '-I'))
                else:
                    # Tokenise like the former shell command line did.
                    gxx_args.extend(shlex.split(opt))

        tempfile.tempdir = "./"
        temp = tempfile.NamedTemporaryFile(
            prefix='host_func', suffix='.so', delete=True)

        gxx_args += ["-", "-shared", "-o", temp.name]
        # The source is fed to g++ on stdin directly. It used to be spliced
        # into "echo '<source>' | g++ ...", which broke as soon as the source
        # contained a single quote and left the text at the mercy of the
        # shell's echo. The trailing newline mirrors what echo appended.
        self._run_tool(
            gxx_args, 'g++', stdin_data=(source_buffer_host + "\n").encode())
        host_lib = ctypes.CDLL(temp.name)

        return cubin_image, host_lib, temp

    def add_module(self, operations, compile_options=None):
        """
        Insert a new compiled device module

        Operations already present in the in-memory cache or the database are
        wired up directly; all others are compiled together, registered in the
        caches and written to the database.

        :param operations: operations providing ``rt_module`` and
            ``procedural_name()``
        :param compile_options: optional options object; when ``None`` it is
            built from ``CUTLASS_PATH``, ``CUDA_INSTALL_PATH`` and the
            architectures of the given operations
        :raises RuntimeError: when a CUDA driver call or the compilation fails
        """
        if compile_options is None:
            cutlass_path = os.getenv('CUTLASS_PATH')
            assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined."
            cuda_install_path = os.getenv('CUDA_INSTALL_PATH')
            assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined."
            architectures = []
            for operation in operations:
                if hasattr(operation, "tile_description"):
                    cc = operation.arch
                    if cc not in architectures:
                        architectures.append(cc)
            include_paths = [
                cuda_install_path + '/include',
                cutlass_path + '/include',
                cutlass_path + '/tools/util/include',
                cutlass_path + '/tools/library/scripts/pycutlass/src/cpp/include'
            ]
            compile_options = CompilationOptions(
                self.default_compile_options, architectures, include_paths)

        # Operations that still need compiling, with their cache keys.
        operation_key = []
        operation_list = []
        for operation in operations:
            # step 1: get kernel string as key
            key = operation.rt_module.emit() + operation.procedural_name() + self.backend
            # step 2: check if the operation is in cache
            compiled_kernel = self.compiled_cache_device.at(key)

            if compiled_kernel is None:
                if self.load_operation(key):
                    compiled_kernel = self.compiled_cache_device.at(key)
                    assert compiled_kernel is not None

            if compiled_kernel is not None:
                operation.rt_module.kernel = compiled_kernel
                compiled_host_fns = self.compiled_cache_host.at(key)
                assert compiled_host_fns is not None
                # A dedicated loop variable: the former loop reused ``key``.
                for fn_name in compiled_host_fns.keys():
                    setattr(operation.rt_module, fn_name, compiled_host_fns[fn_name])
                operation.rt_module.initialize()
            else:
                operation_list.append(operation.rt_module)
                operation_key.append(key)

        if len(operation_list) > 0:
            cubin_image, host_lib, host_file = self.emit_compile_(
                operation_list, compile_options)

            err, module = cuda.cuModuleLoadData(cubin_image)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))

            operation_names = []
            operation_attrs = []
            for operation, key in zip(operation_list, operation_key):
                # get device kernels
                err, operation.kernel = cuda.cuModuleGetFunction(
                    module,
                    bytes(str.encode(operation.name()))
                )
                # This status used to be ignored.
                if err != cuda.CUresult.CUDA_SUCCESS:
                    raise RuntimeError('Cuda Error: {}'.format(err))
                operation_names.append(operation.name())
                self.compiled_cache_device.insert(key, operation.kernel)

                # get host functions
                compiled_host_fns = {}
                op_attr = []

                # get param size
                func = getattr(host_lib, operation.name() + '_get_param_size')
                param_size = func()

                func = getattr(host_lib, operation.name() + '_get_params')
                # NOTE(review): ctypes reads ``argtypes``; ``argtype`` is only
                # stored as a plain attribute. Left unchanged because
                # enforcing argument types could break existing call sites -
                # confirm the intent.
                func.argtype = operation.argtype
                func.restype = ctypes.POINTER(ctypes.c_char * param_size)
                setattr(operation, 'get_args', func)
                compiled_host_fns['get_args'] = func

                # set shared memory size (queried once, used for both targets)
                func = getattr(host_lib, operation.name() + '_shared_memory_size')
                shared_memory_capacity = func()
                setattr(operation, 'shared_memory_capacity', shared_memory_capacity)
                compiled_host_fns['shared_memory_capacity'] = shared_memory_capacity
                # set the maximum dynamic shared size
                operation.initialize()

                # get extra functions; the parameter size comes first so that
                # load_operation can read it back as op_attr[0]
                op_attr.append(param_size)

                if hasattr(operation, "extra_funcs"):
                    for suffix in operation.extra_funcs:
                        func = getattr(host_lib, operation.name() + '_' + suffix)
                        setattr(operation, suffix, func)
                        compiled_host_fns[suffix] = func
                        op_attr.append(suffix)

                operation_attrs.append(op_attr)
                self.compiled_cache_host.insert(key, compiled_host_fns)

            # save the cubin and host library of every newly compiled operation
            for key, name, attrs in zip(operation_key, operation_names, operation_attrs):
                self.insert_operation(
                    key, cubin_image, host_file.name, name, attrs)