warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,432 @@
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+ from pycutlass import *
33
+ import cutlass
34
+ from cuda import cuda
35
+ from cuda import nvrtc
36
+ import tempfile
37
+ import os
38
+ import ctypes
39
+
40
+ #
41
+ import json
42
+ import sqlite3
43
+
44
+
45
# Template for a single ``#include`` line. ``${include}`` is the placeholder
# that is filled in through SubstituteTemplate when the device and host
# sources are assembled.
IncludeTemplate = r'''#include "${include}"
'''
47
+
48
+ #
49
+
50
+
51
class CompilationOptions:
    '''
    Compilation options.

    Holds compiler flags, include search paths and target architectures, and
    renders them either as a single command-line fragment (``get_str``) or as
    a list of byte strings (``get``).
    '''

    #
    def __init__(self, flags, architectures=None, include_paths=None):
        '''
        :param flags: iterable of flag strings, emitted verbatim and in order
        :param architectures: list of integer architectures, each rendered as
            ``sm_<n>``; defaults to ``[80]``
        :param include_paths: list of include directories, each rendered as
            ``--include-path=<dir>``; defaults to an empty list
        '''
        self.includes = []
        # Fresh lists are created per instance when the caller passes nothing,
        # so mutating one instance's lists never leaks into another instance.
        # Lists supplied by the caller are stored by reference, as before.
        self.include_paths = [] if include_paths is None else include_paths
        self.flags = flags
        self.architectures = [80] if architectures is None else architectures

    def _arch_option(self):
        '''Return the ``-arch=sm_a,sm_b,...`` option for the architectures.'''
        return "-arch=" + ",".join(
            "sm_%d" % arch for arch in self.architectures)

    def _option_strings(self):
        '''Return every option as text: flags, include paths, then -arch.'''
        rendered = list(self.flags)
        rendered.extend(
            '--include-path=%s' % incl for incl in self.include_paths)
        rendered.append(self._arch_option())
        return rendered

    def get_str(self):
        '''
        Return all options as one string.

        Every option is preceded by a single space, so the result starts
        with a space and can be appended directly to a command.
        '''
        return "".join(" " + option for option in self._option_strings())

    #
    def get(self):
        '''Return all options as a list of encoded ``bytes`` objects.'''
        return [bytes(str.encode(option)) for option in self._option_strings()]
100
+
101
+
102
def convertToBinaryData(filename):
    """Return the complete contents of ``filename``, read in binary mode."""
    with open(filename, 'rb') as stream:
        return stream.read()
106
+
107
+
108
def CDLLBin(host_binary):
    """
    Load a shared library from an in-memory binary image.

    The bytes in ``host_binary`` are written to a ``host_func*.so`` temporary
    file in the current working directory, and that file is opened with
    ``ctypes.CDLL``. The loaded library object is returned.
    """
    # This changes the process-wide default directory used by ``tempfile``.
    tempfile.tempdir = "./"
    shared_object = tempfile.NamedTemporaryFile(
        prefix='host_func', suffix='.so', delete=True)
    library_path = shared_object.name
    # The image is written through a second handle opened by name.
    with open(library_path, 'wb') as sink:
        sink.write(host_binary)
    return ctypes.CDLL(library_path)
116
+
117
+
118
class ArtifactManager:
    """
    Artifact manager

    Compiles operations into device kernels and host helper libraries and
    caches the results in two places: in memory (two ``cutlass.CompileCache``
    objects, one for device kernels and one for host functions) and on disk
    in the SQLite database ``./compiled_cache.db``.
    """

    def __init__(self) -> None:
        # Preparing the on-disk cache is best effort: a database error here
        # is ignored, exactly as before. ``IF NOT EXISTS`` makes the common
        # "table already exists" case a no-op instead of a swallowed error.
        connection = None
        try:
            connection = sqlite3.connect("./compiled_cache.db")
            cursor = connection.cursor()
            sqlite_create_table_query = """CREATE TABLE IF NOT EXISTS compiled_operations(op_key TEXT NOT NULL UNIQUE, cubin BLOB NOT NULL, hostbin BLOB NOT NULL, op_name TEXT NOT NULL, op_attrs TEXT NOT NULL)"""
            cursor.execute(sqlite_create_table_query)
            connection.commit()
            cursor.close()
        except sqlite3.Error:
            pass
        finally:
            if connection is not None:
                connection.close()

        self.nvcc()
        self.compiled_cache_device = cutlass.CompileCache()
        self.compiled_cache_host = cutlass.CompileCache()

    def nvrtc(self):
        """Select the ``nvrtc`` backend and its default compile options."""
        self.backend = "nvrtc"
        self.default_compile_options = [
            '-std=c++11', '-default-device',
        ]

    def nvcc(self):
        """Select the ``nvcc`` backend and its default compile options."""
        self.backend = "nvcc"
        self.default_compile_options = [
            '-std=c++11',
        ]

    def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs):
        """
        Store one compiled operation in ``./compiled_cache.db``.

        :param op_key: unique key of the operation
        :param cubin: device binary image
        :param hostfile: path of the host shared library; its bytes are stored
        :param op_name: name of the operation
        :param op_attrs: JSON-serializable attributes of the operation

        A row whose ``op_key`` already exists is left untouched
        (``INSERT OR IGNORE``).
        """
        connection = sqlite3.connect("./compiled_cache.db")
        try:
            cursor = connection.cursor()
            sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)"""

            hostbin = convertToBinaryData(hostfile)

            data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs))

            cursor.execute(sqlite_insert_blob_query, data_tuple)
            connection.commit()
            cursor.close()
        finally:
            # Always release the connection, including on failure.
            connection.close()

    def load_operation(self, op_key):
        """
        Load a compiled operation from ``./compiled_cache.db`` into the
        in-memory caches.

        :param op_key: unique key of the operation
        :return: ``False`` when the database has no row for ``op_key``,
            ``True`` once the stored kernel and host functions are cached
        :raises RuntimeError: when the CUDA driver fails to load the stored
            module or to look up the kernel in it
        """
        connection = sqlite3.connect("./compiled_cache.db")
        try:
            cursor = connection.cursor()
            sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?"""
            cursor.execute(sqlite_fetch_blob_query, (op_key, ))
            record = cursor.fetchall()
            cursor.close()
        finally:
            # The rows are fully fetched, so the connection can go now; this
            # also covers the early "not found" return below.
            connection.close()

        if not record:
            return False

        for row in record:
            key, cubin_image, host_binary, operation_name, op_attr = row
            op_attr = json.loads(op_attr)
            err, module = cuda.cuModuleLoadData(cubin_image)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))

            err, kernel = cuda.cuModuleGetFunction(
                module, bytes(str.encode(operation_name)))
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))
            self.compiled_cache_device.insert(key, kernel)

            compiled_host_fns = {}
            host_lib = CDLLBin(host_binary)

            # The first stored attribute is the parameter size recorded by
            # add_module; it determines the size of the returned buffer.
            func_name = operation_name + '_get_params'
            func = getattr(host_lib, func_name)
            func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0])
            compiled_host_fns['get_args'] = func

            func_name = operation_name + '_shared_memory_size'
            func = getattr(host_lib, func_name)
            compiled_host_fns['shared_memory_capacity'] = func()

            # String attributes are the suffixes of extra host functions.
            for attr in op_attr:
                if isinstance(attr, str):
                    func_name = operation_name + '_' + attr
                    func = getattr(host_lib, func_name)
                    compiled_host_fns[attr] = func

            self.compiled_cache_host.insert(key, compiled_host_fns)
        return True

    def emit_compile_(self, operation_list, compilation_options):
        """
        Compile a list of kernels.

        One device source and one host source are assembled from the
        operations. The device source is compiled with the selected backend
        (NVRTC, otherwise ``nvcc`` through the shell) and the host source is
        compiled into a shared library with ``g++``.

        :param operation_list: operations to emit and compile
        :param compilation_options: object providing ``get()`` and
            ``get_str()`` (see ``CompilationOptions``)
        :return: ``(cubin_image, host_lib, temp)`` -- the device binary, the
            loaded host library, and the temporary file object holding the
            host library on disk
        :raises RuntimeError: on an NVRTC error or a failing ``nvcc`` run
        """
        source_buffer_device = ""
        source_buffer_host = ""
        # 1. include: every header once, in first-seen order
        includes = []
        for operation in operation_list:
            for incl in operation.emitter.includes:
                if incl not in includes:
                    includes.append(incl)

        includes_host = [
            "builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes
        for incl in includes:
            source_buffer_device += SubstituteTemplate(
                IncludeTemplate, {'include': incl})

        # Headers under a "/device/" path are left out of the host source.
        for incl in includes_host:
            if "/device/" not in incl:
                source_buffer_host += SubstituteTemplate(
                    IncludeTemplate, {'include': incl})

        # 2. Operations
        for operation in operation_list:
            source_buffer_device += operation.emit()
            source_buffer_host += operation.emit()
            values = {
                'operation_name': operation.name(),
                'operation_suffix': operation.emitter.operation_suffix
            }
            source_buffer_device += SubstituteTemplate(
                operation.KernelTemplate, values)
            source_buffer_host += SubstituteTemplate(
                operation.HostTemplate, values)

        if self.backend == "nvrtc":
            # 3. compile
            err, program = nvrtc.nvrtcCreateProgram(
                str.encode(source_buffer_device),
                bytes(str.encode("module.cu")),
                0, [], [])

            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))

            # Compile program
            options = compilation_options.get()

            err, = nvrtc.nvrtcCompileProgram(program, len(options), options)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:

                error_string = 'NVRTC Error: {}\n'.format(err)

                # Get log from compilation
                err, logSize = nvrtc.nvrtcGetProgramLogSize(program)
                if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                    raise RuntimeError('NVRTC Error: {}'.format(err))

                log = b' ' * logSize
                err, = nvrtc.nvrtcGetProgramLog(program, log)
                if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                    raise RuntimeError('NVRTC Error: {}'.format(err))

                # Report the error code, the compiler log and the source.
                raise RuntimeError(
                    error_string + log.decode() + source_buffer_device)

            # Get data from compilation
            err, dataSize = nvrtc.nvrtcGetCUBINSize(program)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))

            cubin_image = b' ' * dataSize
            err, = nvrtc.nvrtcGetCUBIN(program, cubin_image)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))
        else:  # with nvcc backend
            # emit code
            tempfile.tempdir = "./"
            temp_cu = tempfile.NamedTemporaryFile(
                prefix='kernel', suffix='.cu', delete=True)
            temp_cubin = tempfile.NamedTemporaryFile(
                prefix='kernel', suffix='.cubin', delete=True)
            with open(temp_cu.name, 'w') as file:
                file.write(source_buffer_device)

            # compile with nvcc
            cuda_install_path = os.getenv('CUDA_INSTALL_PATH')
            assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined."
            cmd_template = "${cuda_install_path}/bin/nvcc ${options} -cubin ${srcfile} -o ${tarfile}"
            values = {
                "cuda_install_path": cuda_install_path,
                "options": compilation_options.get_str(),
                "srcfile": temp_cu.name,
                "tarfile": temp_cubin.name
            }
            cmd = SubstituteTemplate(cmd_template, values)
            status = os.system(cmd)
            if status != 0:
                # Fail here instead of handing an empty cubin to the driver.
                raise RuntimeError(
                    'nvcc failed with exit status {}'.format(status))

            # load the cubin image
            with open(temp_cubin.name, 'rb') as file:
                cubin_image = file.read()

        # compile the host code
        # NOTE(review): the host source is passed through the shell inside
        # single quotes; a quote character in the source would break the
        # command. Kept as is, because changing it alters how options are
        # split by the shell.
        options = compilation_options.get()
        cmd = "echo '%s'|g++ -x c++ -fpermissive -w -fPIC" % source_buffer_host
        for opt in options:
            opt = opt.decode("utf-8")
            # Options listed here and the -arch option are not passed to g++.
            if opt not in ['-default-device', '-std=c++11', '-Xcicc', '-Xllc'] and '-arch=sm_' not in opt:
                if '--include-path=' in opt:
                    cmd += " " + opt.replace('--include-path=', '-I')
                else:
                    cmd += " " + opt

        tempfile.tempdir = "./"
        temp = tempfile.NamedTemporaryFile(
            prefix='host_func', suffix='.so', delete=True)

        cmd += ' - -shared -o %s' % temp.name
        # The exit status is not checked: a failed g++ run surfaces as an
        # error from ctypes.CDLL on the next line.
        os.system(cmd)
        host_lib = ctypes.CDLL(temp.name)

        return cubin_image, host_lib, temp

    def add_module(self, operations, compile_options=None):
        """
        Insert a new compiled device module

        Each operation is looked up in the in-memory cache, then in the
        database. Operations found there get their kernel and host functions
        attached directly. The remaining ones are compiled together, attached,
        cached in memory and written to the database.

        :param operations: operations exposing ``rt_module`` and
            ``procedural_name()``
        :param compile_options: options used for compilation; when ``None``
            they are built from the ``CUTLASS_PATH`` and ``CUDA_INSTALL_PATH``
            environment variables and the operations' architectures
        :raises RuntimeError: when the CUDA driver fails to load the compiled
            module or to look up a kernel in it
        """
        if compile_options is None:
            cutlass_path = os.getenv('CUTLASS_PATH')
            assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined."
            cuda_install_path = os.getenv('CUDA_INSTALL_PATH')
            assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined."
            architectures = []
            for operation in operations:
                if hasattr(operation, "tile_description"):
                    cc = operation.arch
                    if cc not in architectures:
                        architectures.append(cc)
            include_paths = [
                cuda_install_path + '/include',
                cutlass_path + '/include',
                cutlass_path + '/tools/util/include',
                cutlass_path + '/tools/library/scripts/pycutlass/src/cpp/include'
            ]
            compile_options = CompilationOptions(
                self.default_compile_options, architectures, include_paths)
        # save the cubin
        operation_key = []
        operation_list = []
        for operation in operations:
            # step 1: get kernel string as key
            key = operation.rt_module.emit() + operation.procedural_name() + self.backend
            # step 2: check if the operation is in cache
            compiled_kernel = self.compiled_cache_device.at(key)

            if compiled_kernel is None:
                hit = self.load_operation(key)
                if hit:
                    compiled_kernel = self.compiled_cache_device.at(key)
                    assert compiled_kernel is not None
            if compiled_kernel is not None:
                operation.rt_module.kernel = compiled_kernel
                compiled_host_fns = self.compiled_cache_host.at(key)
                assert compiled_host_fns is not None
                # A separate loop name keeps the cache key ``key`` intact.
                for fn_name in compiled_host_fns.keys():
                    setattr(operation.rt_module, fn_name,
                            compiled_host_fns[fn_name])
                operation.rt_module.initialize()
            else:
                operation_list.append(operation.rt_module)
                operation_key.append(key)
        if len(operation_list) > 0:
            cubin_image, host_lib, host_file = self.emit_compile_(
                operation_list, compile_options)

            err, module = cuda.cuModuleLoadData(cubin_image)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))

            operation_names = []
            operation_attrs = []
            for operation, key in zip(operation_list, operation_key):
                # get device kernels
                err, operation.kernel = cuda.cuModuleGetFunction(
                    module,
                    bytes(str.encode(operation.name()))
                )
                if err != cuda.CUresult.CUDA_SUCCESS:
                    raise RuntimeError('Cuda Error: {}'.format(err))
                operation_names.append(operation.name())
                self.compiled_cache_device.insert(key, operation.kernel)
                # get host functions
                compiled_host_fns = {}
                op_attr = []

                # get param size
                func_name = operation.name() + '_get_param_size'
                func = getattr(host_lib, func_name)
                param_size = func()

                func_name = operation.name() + '_get_params'
                func = getattr(host_lib, func_name)
                # NOTE(review): ctypes reads ``argtypes``; ``argtype`` is kept
                # from the original and presumably has no effect on the call
                # signature -- confirm before renaming.
                func.argtype = operation.argtype
                func.restype = ctypes.POINTER(ctypes.c_char * param_size)
                setattr(operation, 'get_args', func)
                compiled_host_fns['get_args'] = func

                # set shared memory size
                func_name = operation.name() + '_shared_memory_size'
                func = getattr(host_lib, func_name)
                shared_memory_capacity = func()
                setattr(operation, 'shared_memory_capacity',
                        shared_memory_capacity)
                compiled_host_fns['shared_memory_capacity'] = shared_memory_capacity
                # set the maximum dynamic shared size
                operation.initialize()

                # get extra functions
                # The parameter size goes first; load_operation reads it back
                # as the first stored attribute.
                op_attr.append(param_size)

                if hasattr(operation, "extra_funcs"):
                    for suffix in operation.extra_funcs:
                        func_name = operation.name() + '_' + suffix
                        func = getattr(host_lib, func_name)
                        setattr(operation, suffix, func)
                        compiled_host_fns[suffix] = func
                        op_attr.append(suffix)

                operation_attrs.append(op_attr)
                self.compiled_cache_host.insert(key, compiled_host_fns)

            # Persist every newly compiled operation; all of them share the
            # same cubin image and host library file.
            for key, name, attrs in zip(
                    operation_key, operation_names, operation_attrs):
                self.insert_operation(
                    key, cubin_image, host_file.name, name, attrs)