warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/__init__.py CHANGED
@@ -10,8 +10,9 @@
10
10
 
11
11
  from warp.types import array, array1d, array2d, array3d, array4d, constant
12
12
  from warp.types import indexedarray, indexedarray1d, indexedarray2d, indexedarray3d, indexedarray4d
13
+ from warp.fabric import fabricarray, fabricarrayarray, indexedfabricarray, indexedfabricarrayarray
13
14
 
14
- from warp.types import int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, float64
15
+ from warp.types import bool, int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, float64
15
16
  from warp.types import vec2, vec2b, vec2ub, vec2s, vec2us, vec2i, vec2ui, vec2l, vec2ul, vec2h, vec2f, vec2d
16
17
  from warp.types import vec3, vec3b, vec3ub, vec3s, vec3us, vec3i, vec3ui, vec3l, vec3ul, vec3h, vec3f, vec3d
17
18
  from warp.types import vec4, vec4b, vec4ub, vec4s, vec4us, vec4i, vec4ui, vec4l, vec4ul, vec4h, vec4f, vec4d
@@ -25,7 +26,9 @@ from warp.types import spatial_matrix, spatial_matrixh, spatial_matrixf, spatial
25
26
 
26
27
  # geometry types
27
28
  from warp.types import Bvh, Mesh, HashGrid, Volume, MarchingCubes
28
- from warp.types import bvh_query_t, mesh_query_aabb_t, hash_grid_query_t
29
+ from warp.types import bvh_query_t, hash_grid_query_t, mesh_query_aabb_t, mesh_query_point_t, mesh_query_ray_t
30
+
31
+
29
32
 
30
33
  # device-wide gemms
31
34
  from warp.types import matmul, adj_matmul, batched_matmul, adj_batched_matmul, from_ptr
@@ -34,7 +37,7 @@ from warp.types import matmul, adj_matmul, batched_matmul, adj_batched_matmul, f
34
37
  from warp.types import vector as vec
35
38
  from warp.types import matrix as mat
36
39
 
37
- from warp.context import init, func, kernel, struct, overload
40
+ from warp.context import init, func, func_grad, func_replay, func_native, kernel, struct, overload
38
41
  from warp.context import is_cpu_available, is_cuda_available, is_device_available
39
42
  from warp.context import get_devices, get_preferred_device
40
43
  from warp.context import get_cuda_devices, get_cuda_device_count, get_cuda_device, map_cuda_device, unmap_cuda_device
@@ -42,6 +45,8 @@ from warp.context import get_device, set_device, synchronize_device
42
45
  from warp.context import (
43
46
  zeros,
44
47
  zeros_like,
48
+ full,
49
+ full_like,
45
50
  clone,
46
51
  empty,
47
52
  empty_like,
@@ -54,15 +59,14 @@ from warp.context import (
54
59
  )
55
60
  from warp.context import set_module_options, get_module_options, get_module
56
61
  from warp.context import capture_begin, capture_end, capture_launch
57
- from warp.context import print_builtins, export_builtins, export_stubs
58
- from warp.context import Kernel, Function
62
+ from warp.context import Kernel, Function, Launch
59
63
  from warp.context import Stream, get_stream, set_stream, synchronize_stream
60
64
  from warp.context import Event, record_event, wait_event, wait_stream
61
65
  from warp.context import RegisteredGLBuffer
62
66
 
63
67
  from warp.tape import Tape
64
- from warp.utils import ScopedTimer, ScopedCudaGuard, ScopedDevice, ScopedStream
65
- from warp.utils import transform_expand
68
+ from warp.utils import ScopedTimer, ScopedDevice, ScopedStream
69
+ from warp.utils import transform_expand, quat_between_vectors
66
70
 
67
71
  from warp.torch import from_torch, to_torch
68
72
  from warp.torch import device_from_torch, device_to_torch
@@ -76,3 +80,7 @@ from warp.dlpack import from_dlpack, to_dlpack
76
80
  from warp.constants import *
77
81
 
78
82
  from . import builtins
83
+
84
+ import warp.config
85
+
86
+ __version__ = warp.config.version
warp/__init__.pyi ADDED
@@ -0,0 +1 @@
1
+ from .stubs import *
warp/bin/warp-clang.dll CHANGED
Binary file
warp/bin/warp.dll CHANGED
Binary file
warp/build.py CHANGED
@@ -6,142 +6,28 @@
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
8
  import os
9
- import sys
10
- import subprocess
11
- import ctypes
12
- import _ctypes
13
9
 
14
10
  import warp.config
15
- import warp.utils
16
- from warp.utils import ScopedTimer
17
11
  from warp.thirdparty import appdirs
18
12
 
19
- # return from funtions without type -> C++ compile error
20
- # array[i,j] += x -> augassign not handling target of subscript
21
-
22
-
23
- def run_cmd(cmd, capture=False):
24
- if warp.config.verbose:
25
- print(cmd)
26
-
27
- try:
28
- return subprocess.check_output(cmd, shell=True)
29
- except subprocess.CalledProcessError as e:
30
- print(e.output.decode())
31
- raise (e)
32
-
33
-
34
- # cut-down version of vcvars64.bat that allows using
35
- # custom toolchain locations
36
- def set_msvc_compiler(msvc_path, sdk_path):
37
- if "INCLUDE" not in os.environ:
38
- os.environ["INCLUDE"] = ""
39
-
40
- if "LIB" not in os.environ:
41
- os.environ["LIB"] = ""
42
-
43
- msvc_path = os.path.abspath(msvc_path)
44
- sdk_path = os.path.abspath(sdk_path)
45
-
46
- os.environ["INCLUDE"] += os.pathsep + os.path.join(msvc_path, "include")
47
- os.environ["INCLUDE"] += os.pathsep + os.path.join(sdk_path, "include/winrt")
48
- os.environ["INCLUDE"] += os.pathsep + os.path.join(sdk_path, "include/um")
49
- os.environ["INCLUDE"] += os.pathsep + os.path.join(sdk_path, "include/ucrt")
50
- os.environ["INCLUDE"] += os.pathsep + os.path.join(sdk_path, "include/shared")
51
-
52
- os.environ["LIB"] += os.pathsep + os.path.join(msvc_path, "lib/x64")
53
- os.environ["LIB"] += os.pathsep + os.path.join(sdk_path, "lib/ucrt/x64")
54
- os.environ["LIB"] += os.pathsep + os.path.join(sdk_path, "lib/um/x64")
55
-
56
- os.environ["PATH"] += os.pathsep + os.path.join(msvc_path, "bin/HostX64/x64")
57
- os.environ["PATH"] += os.pathsep + os.path.join(sdk_path, "bin/x64")
58
-
59
- warp.config.host_compiler = os.path.join(msvc_path, "bin", "HostX64", "x64", "cl.exe")
60
-
61
-
62
- def find_host_compiler():
63
- if os.name == "nt":
64
- try:
65
- # try and find an installed host compiler (msvc)
66
- # runs vcvars and copies back the build environment
67
-
68
- vswhere_path = r"%ProgramFiles(x86)%/Microsoft Visual Studio/Installer/vswhere.exe"
69
- vswhere_path = os.path.expandvars(vswhere_path)
70
- if not os.path.exists(vswhere_path):
71
- return ""
72
-
73
- vs_path = run_cmd(f'"{vswhere_path}" -latest -property installationPath').decode().rstrip()
74
- vsvars_path = os.path.join(vs_path, "VC\\Auxiliary\\Build\\vcvars64.bat")
75
-
76
- output = run_cmd(f'"{vsvars_path}" && set').decode()
77
-
78
- for line in output.splitlines():
79
- pair = line.split("=", 1)
80
- if len(pair) >= 2:
81
- os.environ[pair[0]] = pair[1]
82
-
83
- cl_path = run_cmd("where cl.exe").decode("utf-8").rstrip()
84
- cl_version = os.environ["VCToolsVersion"].split(".")
85
-
86
- # ensure at least VS2019 version, see list of MSVC versions here https://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B
87
- cl_required_major = 14
88
- cl_required_minor = 29
89
-
90
- if (
91
- (int(cl_version[0]) < cl_required_major)
92
- or (int(cl_version[0]) == cl_required_major)
93
- and int(cl_version[1]) < cl_required_minor
94
- ):
95
- print(
96
- f"Warp: MSVC found but compiler version too old, found {cl_version[0]}.{cl_version[1]}, but must be {cl_required_major}.{cl_required_minor} or higher, kernel host compilation will be disabled."
97
- )
98
- return ""
99
-
100
- return cl_path
101
-
102
- except Exception as e:
103
- # couldn't find host compiler
104
- return ""
105
- else:
106
- # try and find g++
107
- try:
108
- return run_cmd("which g++").decode()
109
- except:
110
- return ""
111
-
112
-
113
- def get_cuda_toolkit_version(cuda_home):
114
- try:
115
- # the toolkit version can be obtained by running "nvcc --version"
116
- nvcc_path = os.path.join(cuda_home, "bin", "nvcc")
117
- nvcc_version_output = subprocess.check_output([nvcc_path, "--version"]).decode("utf-8")
118
- # search for release substring (e.g., "release 11.5")
119
- import re
120
-
121
- m = re.search(r"(?<=release )\d+\.\d+", nvcc_version_output)
122
- if m is not None:
123
- return tuple(int(x) for x in m.group(0).split("."))
124
- else:
125
- raise Exception("Failed to parse NVCC output")
126
-
127
- except Exception as e:
128
- print(f"Failed to determine CUDA Toolkit version: {e}")
129
-
130
13
 
131
14
  # builds cuda source to PTX or CUBIN using NVRTC (output type determined by output_path extension)
132
15
  def build_cuda(cu_path, arch, output_path, config="release", verify_fp=False, fast_math=False):
133
- src_file = open(cu_path)
134
- src = src_file.read().encode("utf-8")
135
- src_file.close()
16
+ with open(cu_path, "rb") as src_file:
17
+ src = src_file.read()
18
+ cu_path = cu_path.encode("utf-8")
19
+ inc_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "native").encode("utf-8")
20
+ output_path = output_path.encode("utf-8")
136
21
 
137
- inc_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "native").encode("utf-8")
138
- output_path = output_path.encode("utf-8")
22
+ if warp.config.llvm_cuda:
23
+ warp.context.runtime.llvm.compile_cuda(src, cu_path, inc_path, output_path, False)
139
24
 
140
- err = warp.context.runtime.core.cuda_compile_program(
141
- src, arch, inc_path, config == "debug", warp.config.verbose, verify_fp, fast_math, output_path
142
- )
143
- if err:
144
- raise Exception("CUDA build failed")
25
+ else:
26
+ err = warp.context.runtime.core.cuda_compile_program(
27
+ src, arch, inc_path, config == "debug", warp.config.verbose, verify_fp, fast_math, output_path
28
+ )
29
+ if err != 0:
30
+ raise Exception(f"CUDA kernel build failed with error code {err}")
145
31
 
146
32
 
147
33
  # load PTX or CUBIN as a CUDA runtime module (input type determined by input_path extension)
@@ -152,320 +38,16 @@ def load_cuda(input_path, device):
152
38
  return warp.context.runtime.core.cuda_load_module(device.context, input_path.encode("utf-8"))
153
39
 
154
40
 
155
- def quote(path):
156
- return '"' + path + '"'
157
-
158
-
159
- def build_dll(dll_path, cpp_paths, cu_path, libs=[], mode="release", verify_fp=False, fast_math=False, quick=False):
160
- cuda_home = warp.config.cuda_path
161
- cuda_cmd = None
162
-
163
- if quick:
164
- cutlass_includes = ""
165
- cutlass_enabled = "WP_ENABLE_CUTLASS=0"
166
- else:
167
- cutlass_home = "warp/native/cutlass"
168
- cutlass_includes = f'-I"{cutlass_home}/include" -I"{cutlass_home}/tools/util/include"'
169
- cutlass_enabled = "WP_ENABLE_CUTLASS=1"
170
-
171
- if quick or cu_path is None:
172
- cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=0"
173
- else:
174
- cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=1"
175
-
176
- import pathlib
177
-
178
- warp_home_path = pathlib.Path(__file__).parent
179
- warp_home = warp_home_path.resolve()
180
- nanovdb_home = warp_home_path.parent / "_build/host-deps/nanovdb/include"
41
+ def build_cpu(obj_path, cpp_path, mode="release", verify_fp=False, fast_math=False):
42
+ with open(cpp_path, "rb") as cpp:
43
+ src = cpp.read()
44
+ cpp_path = cpp_path.encode("utf-8")
45
+ inc_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "native").encode("utf-8")
46
+ obj_path = obj_path.encode("utf-8")
181
47
 
182
- # ensure that dll is not loaded in the process
183
- force_unload_dll(dll_path)
184
-
185
- # output stale, rebuild
186
- if warp.config.verbose:
187
- print(f"Building {dll_path}")
188
-
189
- native_dir = os.path.join(warp_home, "native")
190
-
191
- if cu_path:
192
- # check CUDA Toolkit version
193
- min_ctk_version = (11, 5)
194
- ctk_version = get_cuda_toolkit_version(cuda_home) or min_ctk_version
195
- if ctk_version < min_ctk_version:
196
- raise Exception(
197
- f"CUDA Toolkit version {min_ctk_version[0]}.{min_ctk_version[1]}+ is required (found {ctk_version[0]}.{ctk_version[1]} in {cuda_home})"
198
- )
199
-
200
- gencode_opts = []
201
-
202
- if quick:
203
- # minimum supported architectures (PTX)
204
- gencode_opts += ["-gencode=arch=compute_52,code=compute_52", "-gencode=arch=compute_75,code=compute_75"]
205
- else:
206
- # generate code for all supported architectures
207
- gencode_opts += [
208
- # SASS for supported desktop/datacenter architectures
209
- "-gencode=arch=compute_52,code=sm_52", # Maxwell
210
- "-gencode=arch=compute_60,code=sm_60", # Pascal
211
- "-gencode=arch=compute_61,code=sm_61",
212
- "-gencode=arch=compute_70,code=sm_70", # Volta
213
- "-gencode=arch=compute_75,code=sm_75", # Turing
214
- "-gencode=arch=compute_80,code=sm_80", # Ampere
215
- "-gencode=arch=compute_86,code=sm_86",
216
- # SASS for supported mobile architectures (e.g. Tegra/Jetson)
217
- # "-gencode=arch=compute_53,code=sm_53",
218
- # "-gencode=arch=compute_62,code=sm_62",
219
- # "-gencode=arch=compute_72,code=sm_72",
220
- # "-gencode=arch=compute_87,code=sm_87",
221
- ]
222
-
223
- # support for Ada and Hopper is available with CUDA Toolkit 11.8+
224
- if ctk_version >= (11, 8):
225
- gencode_opts += [
226
- "-gencode=arch=compute_89,code=sm_89", # Ada
227
- "-gencode=arch=compute_90,code=sm_90", # Hopper
228
- # PTX for future hardware
229
- "-gencode=arch=compute_90,code=compute_90",
230
- ]
231
- else:
232
- gencode_opts += [
233
- # PTX for future hardware
234
- "-gencode=arch=compute_86,code=compute_86",
235
- ]
236
-
237
- nvcc_opts = gencode_opts + [
238
- "-t0", # multithreaded compilation
239
- "--extended-lambda",
240
- ]
241
-
242
- if fast_math:
243
- nvcc_opts.append("--use_fast_math")
244
-
245
- # is the library being built with CUDA enabled?
246
- cuda_enabled = "WP_ENABLE_CUDA=1" if (cu_path is not None) else "WP_ENABLE_CUDA=0"
247
-
248
- if os.name == "nt":
249
- # try loading warp-clang.dll, except when we're building warp-clang.dll or warp.dll
250
- clang = None
251
- if os.path.basename(dll_path) != "warp-clang.dll" and os.path.basename(dll_path) != "warp.dll":
252
- try:
253
- clang = warp.build.load_dll(f"{warp_home_path}/bin/warp-clang.dll")
254
- except RuntimeError as e:
255
- clang = None
256
-
257
- if warp.config.host_compiler:
258
- host_linker = os.path.join(os.path.dirname(warp.config.host_compiler), "link.exe")
259
- elif not clang:
260
- raise RuntimeError("Warp build error: No host or bundled compiler was found")
261
-
262
- cpp_includes = f' /I"{warp_home_path.parent}/external/llvm-project/out/install/{mode}/include"'
263
- cpp_includes += f' /I"{warp_home_path.parent}/_build/host-deps/llvm-project/include"'
264
- cuda_includes = f' /I"{cuda_home}/include"' if cu_path else ""
265
- includes = cpp_includes + cuda_includes
266
-
267
- # nvrtc_static.lib is built with /MT and _ITERATOR_DEBUG_LEVEL=0 so if we link it in we must match these options
268
- if cu_path or mode != "debug":
269
- runtime = "/MT"
270
- iter_dbg = "_ITERATOR_DEBUG_LEVEL=0"
271
- debug = "NDEBUG"
272
- else:
273
- runtime = "/MTd"
274
- iter_dbg = "_ITERATOR_DEBUG_LEVEL=2"
275
- debug = "_DEBUG"
276
-
277
- if "/NODEFAULTLIB" in libs:
278
- runtime = "/sdl- /GS-" # don't specify a runtime, and disable security checks which depend on it
279
-
280
- if warp.config.mode == "debug":
281
- cpp_flags = f'/nologo {runtime} /Zi /Od /D "{debug}" /D WP_ENABLE_DEBUG=1 /D "WP_CPU" /D "{cuda_enabled}" /D "{cutlass_enabled}" /D "{cuda_compat_enabled}" /D "{iter_dbg}" /I"{native_dir}" /I"{nanovdb_home}" {includes}'
282
- linkopts = ["/DLL", "/DEBUG"]
283
- elif warp.config.mode == "release":
284
- cpp_flags = f'/nologo {runtime} /Ox /D "{debug}" /D WP_ENABLE_DEBUG=0 /D "WP_CPU" /D "{cuda_enabled}" /D "{cutlass_enabled}" /D "{cuda_compat_enabled}" /D "{iter_dbg}" /I"{native_dir}" /I"{nanovdb_home}" {includes}'
285
- linkopts = ["/DLL"]
286
- else:
287
- raise RuntimeError(f"Unrecognized build configuration (debug, release), got: {mode}")
288
-
289
- if verify_fp:
290
- cpp_flags += ' /D "WP_VERIFY_FP"'
291
-
292
- if fast_math:
293
- cpp_flags += " /fp:fast"
294
-
295
- with ScopedTimer("build", active=warp.config.verbose):
296
- for cpp_path in cpp_paths:
297
- cpp_out = cpp_path + ".obj"
298
- linkopts.append(quote(cpp_out))
299
-
300
- if clang:
301
- with open(cpp_path, "rb") as cpp:
302
- clang.compile_cpp(
303
- cpp.read(), native_dir.encode("utf-8"), cpp_out.encode("utf-8"), warp.config.mode == "debug"
304
- )
305
-
306
- else:
307
- cpp_cmd = f'"{warp.config.host_compiler}" {cpp_flags} -c "{cpp_path}" /Fo"{cpp_out}"'
308
- run_cmd(cpp_cmd)
309
-
310
- if cu_path:
311
- cu_out = cu_path + ".o"
312
-
313
- if mode == "debug":
314
- cuda_cmd = f'"{cuda_home}/bin/nvcc" --compiler-options=/MT,/Zi,/Od -g -G -O0 -DNDEBUG -D_ITERATOR_DEBUG_LEVEL=0 -I"{native_dir}" -I"{nanovdb_home}" -line-info {" ".join(nvcc_opts)} -DWP_CUDA -DWP_ENABLE_CUDA=1 -D{cutlass_enabled} {cutlass_includes} -o "{cu_out}" -c "{cu_path}"'
315
-
316
- elif mode == "release":
317
- cuda_cmd = f'"{cuda_home}/bin/nvcc" -O3 {" ".join(nvcc_opts)} -I"{native_dir}" -I"{nanovdb_home}" -DNDEBUG -DWP_CUDA -DWP_ENABLE_CUDA=1 -D{cutlass_enabled} {cutlass_includes} -o "{cu_out}" -c "{cu_path}"'
318
-
319
- with ScopedTimer("build_cuda", active=warp.config.verbose):
320
- run_cmd(cuda_cmd)
321
- linkopts.append(quote(cu_out))
322
- linkopts.append(
323
- f'cudart_static.lib nvrtc_static.lib nvrtc-builtins_static.lib nvptxcompiler_static.lib ws2_32.lib user32.lib /LIBPATH:"{cuda_home}/lib/x64"'
324
- )
325
-
326
- with ScopedTimer("link", active=warp.config.verbose):
327
- # Link into a DLL, unless we have LLVM to load the object code directly
328
- if not clang:
329
- link_cmd = f'"{host_linker}" {" ".join(linkopts + libs)} /out:"{dll_path}"'
330
- run_cmd(link_cmd)
331
-
332
- else:
333
- clang = None
334
- try:
335
- if sys.platform == "darwin":
336
- # try loading libwarp-clang.dylib, except when we're building libwarp-clang.dylib or libwarp.dylib
337
- if (
338
- os.path.basename(dll_path) != "libwarp-clang.dylib"
339
- and os.path.basename(dll_path) != "libwarp.dylib"
340
- ):
341
- clang = warp.build.load_dll(f"{warp_home_path}/bin/libwarp-clang.dylib")
342
- else: # Linux
343
- # try loading warp-clang.so, except when we're building warp-clang.so or warp.so
344
- if os.path.basename(dll_path) != "warp-clang.so" and os.path.basename(dll_path) != "warp.so":
345
- clang = warp.build.load_dll(f"{warp_home_path}/bin/warp-clang.so")
346
- except RuntimeError as e:
347
- clang = None
348
-
349
- cpp_includes = f' -I"{warp_home_path.parent}/external/llvm-project/out/install/{mode}/include"'
350
- cpp_includes += f' -I"{warp_home_path.parent}/_build/host-deps/llvm-project/include"'
351
- cuda_includes = f' -I"{cuda_home}/include"' if cu_path else ""
352
- includes = cpp_includes + cuda_includes
353
-
354
- if mode == "debug":
355
- cpp_flags = f'-O0 -g -fno-rtti -D_DEBUG -DWP_ENABLE_DEBUG=1 -DWP_CPU -D{cuda_enabled} -D{cutlass_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden --std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -fkeep-inline-functions -I"{native_dir}" {includes}'
356
-
357
- if mode == "release":
358
- cpp_flags = f'-O3 -DNDEBUG -DWP_ENABLE_DEBUG=0 -DWP_CPU -D{cuda_enabled} -D{cutlass_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden --std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -I"{native_dir}" {includes}'
359
-
360
- if verify_fp:
361
- cpp_flags += " -DWP_VERIFY_FP"
362
-
363
- if fast_math:
364
- cpp_flags += " -ffast-math"
365
-
366
- ld_inputs = []
367
-
368
- with ScopedTimer("build", active=warp.config.verbose):
369
- for cpp_path in cpp_paths:
370
- cpp_out = cpp_path + ".o"
371
- ld_inputs.append(quote(cpp_out))
372
-
373
- if clang:
374
- with open(cpp_path, "rb") as cpp:
375
- clang.compile_cpp(
376
- cpp.read(), native_dir.encode("utf-8"), cpp_out.encode("utf-8"), warp.config.mode == "debug"
377
- )
378
-
379
- else:
380
- build_cmd = f'g++ {cpp_flags} -c "{cpp_path}" -o "{cpp_out}"'
381
- run_cmd(build_cmd)
382
-
383
- if cu_path:
384
- cu_out = cu_path + ".o"
385
-
386
- if mode == "debug":
387
- cuda_cmd = f'"{cuda_home}/bin/nvcc" -g -G -O0 --compiler-options -fPIC,-fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -line-info {" ".join(nvcc_opts)} -DWP_CUDA -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{cutlass_enabled} {cutlass_includes} -o "{cu_out}" -c "{cu_path}"'
388
-
389
- elif mode == "release":
390
- cuda_cmd = f'"{cuda_home}/bin/nvcc" -O3 --compiler-options -fPIC,-fvisibility=hidden {" ".join(nvcc_opts)} -DNDEBUG -DWP_CUDA -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{cutlass_enabled} {cutlass_includes} -o "{cu_out}" -c "{cu_path}"'
391
-
392
- with ScopedTimer("build_cuda", active=warp.config.verbose):
393
- run_cmd(cuda_cmd)
394
-
395
- ld_inputs.append(quote(cu_out))
396
- ld_inputs.append(
397
- f'-L"{cuda_home}/lib64" -lcudart_static -lnvrtc_static -lnvrtc-builtins_static -lnvptxcompiler_static -lpthread -ldl -lrt'
398
- )
399
-
400
- if sys.platform == "darwin":
401
- opt_no_undefined = "-Wl,-undefined,error"
402
- opt_exclude_libs = ""
403
- else:
404
- opt_no_undefined = "-Wl,--no-undefined"
405
- opt_exclude_libs = "-Wl,--exclude-libs,ALL"
406
-
407
- with ScopedTimer("link", active=warp.config.verbose):
408
- # Link into a DLL, unless we have LLVM to load the object code directly
409
- if not clang:
410
- origin = "@loader_path" if (sys.platform == "darwin") else "$ORIGIN"
411
- link_cmd = f"g++ -shared -Wl,-rpath,'{origin}' {opt_no_undefined} {opt_exclude_libs} -o '{dll_path}' {' '.join(ld_inputs + libs)}"
412
- run_cmd(link_cmd)
413
-
414
-
415
- def load_dll(dll_path):
416
- if sys.platform == "win32":
417
- if dll_path[-4:] != ".dll":
418
- return None
419
- elif sys.platform == "darwin":
420
- if dll_path[-6:] != ".dylib":
421
- return None
422
- else:
423
- if dll_path[-3:] != ".so":
424
- return None
425
-
426
- try:
427
- if sys.version_info[0] > 3 or sys.version_info[0] == 3 and sys.version_info[1] >= 8:
428
- dll = ctypes.CDLL(dll_path, winmode=0)
429
- else:
430
- dll = ctypes.CDLL(dll_path)
431
- except OSError:
432
- raise RuntimeError(f"Failed to load the shared library '{dll_path}'")
433
- return dll
434
-
435
-
436
- def unload_dll(dll):
437
- if dll is None:
438
- return
439
-
440
- handle = dll._handle
441
- del dll
442
-
443
- # force garbage collection to eliminate any Python references to the dll
444
- import gc
445
-
446
- gc.collect()
447
-
448
- # platform dependent unload, removes *all* references to the dll
449
- # note this should only be performed if you know there are no dangling
450
- # refs to the dll inside the Python program
451
- if os.name == "nt":
452
- max_attempts = 100
453
- for i in range(max_attempts):
454
- result = ctypes.windll.kernel32.FreeLibrary(ctypes.c_void_p(handle))
455
- if result == 0:
456
- return
457
- else:
458
- _ctypes.dlclose(handle)
459
-
460
-
461
- def force_unload_dll(dll_path):
462
- try:
463
- # force load/unload of the dll from the process
464
- dll = load_dll(dll_path)
465
- unload_dll(dll)
466
-
467
- except Exception as e:
468
- return
48
+ err = warp.context.runtime.llvm.compile_cpp(src, cpp_path, inc_path, obj_path, mode == "debug")
49
+ if err != 0:
50
+ raise Exception(f"CPU kernel build failed with error code {err}")
469
51
 
470
52
 
471
53
  kernel_bin_dir = None
@@ -481,9 +63,6 @@ def init_kernel_cache(path=None):
481
63
  To change the default cache location, set warp.config.kernel_cache_dir before calling warp.init().
482
64
  """
483
65
 
484
- warp_root_dir = os.path.dirname(os.path.realpath(__file__))
485
- warp_bin_dir = os.path.join(warp_root_dir, "bin")
486
-
487
66
  if path is not None:
488
67
  cache_root_dir = os.path.realpath(path)
489
68
  else: