warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/_src/build_dll.py ADDED
@@ -0,0 +1,642 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import concurrent.futures
19
+ import os
20
+ import platform
21
+ import subprocess
22
+ import sys
23
+ import time
24
+
25
+ from warp._src.utils import ScopedTimer
26
+
27
+ _wp_module_name_ = "warp.build_dll"
28
+
29
+ verbose_cmd = True # print command lines before executing them
30
+
31
+ MIN_CTK_VERSION = (12, 0)
32
+
33
+
34
+ def machine_architecture() -> str:
35
+ """Return a canonical machine architecture string.
36
+ - "x86_64" for x86-64, aka. AMD64, aka. x64
37
+ - "aarch64" for AArch64, aka. ARM64
38
+ """
39
+ machine = platform.machine()
40
+ if machine == "x86_64" or machine == "AMD64":
41
+ return "x86_64"
42
+ if machine == "aarch64" or machine == "arm64":
43
+ return "aarch64"
44
+ raise RuntimeError(f"Unrecognized machine architecture {machine}")
45
+
46
+
47
def run_cmd(cmd):
    """Run a command line through the shell and return its combined output.

    Args:
        cmd: Command line string. It is interpreted by the shell (``shell=True``), so it
            may use shell syntax such as ``&&``. Only pass trusted, internally built
            command lines.

    Returns:
        The command's stdout and stderr as ``bytes``.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero status. The
            exit code and captured output are printed before re-raising.
    """
    if verbose_cmd:
        print(cmd)

    try:
        return subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
    except subprocess.CalledProcessError as e:
        print("Command failed with exit code:", e.returncode)
        print("Command output was:")
        # Tool output is not guaranteed to be UTF-8 (e.g. a console using a legacy
        # code page); a decode error here must not hide the actual failure.
        print(e.output.decode(errors="replace"))
        raise
58
+
59
+
60
+ # cut-down version of vcvars64.bat that allows using
61
+ # custom toolchain locations, returns the compiler program path
62
+ def set_msvc_env(msvc_path, sdk_path):
63
+ if "INCLUDE" not in os.environ:
64
+ os.environ["INCLUDE"] = ""
65
+
66
+ if "LIB" not in os.environ:
67
+ os.environ["LIB"] = ""
68
+
69
+ msvc_path = os.path.abspath(msvc_path)
70
+ sdk_path = os.path.abspath(sdk_path)
71
+
72
+ os.environ["INCLUDE"] += os.pathsep + os.path.join(msvc_path, "include")
73
+ os.environ["INCLUDE"] += os.pathsep + os.path.join(sdk_path, "include/winrt")
74
+ os.environ["INCLUDE"] += os.pathsep + os.path.join(sdk_path, "include/um")
75
+ os.environ["INCLUDE"] += os.pathsep + os.path.join(sdk_path, "include/ucrt")
76
+ os.environ["INCLUDE"] += os.pathsep + os.path.join(sdk_path, "include/shared")
77
+
78
+ os.environ["LIB"] += os.pathsep + os.path.join(msvc_path, "lib/x64")
79
+ os.environ["LIB"] += os.pathsep + os.path.join(sdk_path, "lib/ucrt/x64")
80
+ os.environ["LIB"] += os.pathsep + os.path.join(sdk_path, "lib/um/x64")
81
+
82
+ os.environ["PATH"] += os.pathsep + os.path.join(msvc_path, "bin/HostX64/x64")
83
+ os.environ["PATH"] += os.pathsep + os.path.join(sdk_path, "bin/x64")
84
+
85
+ return os.path.join(msvc_path, "bin", "HostX64", "x64", "cl.exe")
86
+
87
+
88
def find_host_compiler():
    """Locate a host C++ compiler.

    On Windows, this finds the latest Visual Studio installation with ``vswhere.exe``,
    runs its ``vcvars64.bat``, and copies the resulting environment into ``os.environ``.
    It then returns the path to ``cl.exe`` if the toolset version is at least 14.29
    (VS2019). On other platforms it searches ``PATH`` for ``g++``.

    Returns:
        The compiler path, or ``""`` when no suitable compiler is found.
    """
    if os.name == "nt":
        return _find_msvc_compiler()

    import shutil

    # shutil.which() returns None when g++ is absent. Map that to "" so both platforms
    # share the same "not found" result, and no trailing newline ends up in the path.
    return shutil.which("g++") or ""


def _find_msvc_compiler():
    """Return the path of a sufficiently recent ``cl.exe``, or ``""`` if unavailable."""
    vswhere_path = os.path.expandvars(r"%ProgramFiles(x86)%/Microsoft Visual Studio/Installer/vswhere.exe")
    if not os.path.exists(vswhere_path):
        return ""

    vs_path = run_cmd(f'"{vswhere_path}" -latest -property installationPath').decode().strip()
    if not vs_path:
        # vswhere ran but reported no Visual Studio installation.
        return ""
    vsvars_path = os.path.join(vs_path, "VC\\Auxiliary\\Build\\vcvars64.bat")

    # Run vcvars64.bat, dump the resulting environment with `set`, and copy it back.
    output = run_cmd(f'"{vsvars_path}" && set').decode()
    for line in output.splitlines():
        key, sep, value = line.partition("=")
        if sep and key:
            os.environ[key] = value

    # `where` may list several matches, one per line; keep the first.
    where_lines = run_cmd("where cl.exe").decode("utf-8").splitlines()
    cl_path = where_lines[0].strip() if where_lines else ""

    # ensure at least VS2019 version, see list of MSVC versions here https://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B
    cl_required_major = 14
    cl_required_minor = 29

    version_parts = os.environ.get("VCToolsVersion", "").split(".")
    try:
        cl_major, cl_minor = int(version_parts[0]), int(version_parts[1])
    except (IndexError, ValueError):
        print(
            "Warp: MSVC found but its toolset version (VCToolsVersion) could not be determined, kernel host compilation will be disabled."
        )
        return ""

    if (cl_major, cl_minor) < (cl_required_major, cl_required_minor):
        print(
            f"Warp: MSVC found but compiler version too old, found {cl_major}.{cl_minor}, but must be {cl_required_major}.{cl_required_minor} or higher, kernel host compilation will be disabled."
        )
        return ""

    return cl_path
128
+
129
+
130
def get_cuda_toolkit_version(cuda_home) -> tuple[int, int]:
    """Query ``<cuda_home>/bin/nvcc --version`` for the CUDA Toolkit (major, minor) version.

    This is best-effort. On any failure (nvcc missing or failing, or output without a
    ``release X.Y`` token), a warning is printed and ``MIN_CTK_VERSION`` is returned.
    """
    import re

    try:
        compiler = os.path.join(cuda_home, "bin", "nvcc")
        banner = subprocess.check_output([compiler, "--version"]).decode("utf-8")

        # nvcc prints e.g. "Cuda compilation tools, release 11.5, V11.5.119"
        found = re.search(r"release (\d+)\.(\d+)", banner)
        if found is None:
            raise Exception("Failed to parse NVCC output")

        major_text, minor_text = found.groups()
        return (int(major_text), int(minor_text))
    except Exception as err:
        print(f"Warning: Failed to determine CUDA Toolkit version: {err}")
        return MIN_CTK_VERSION
148
+
149
+
150
def quote(path):
    """Wrap ``path`` in double quotes for a command line; embedded quotes are not escaped."""
    return "".join(('"', path, '"'))
152
+
153
+
154
def add_llvm_bin_to_path(args):
    """Prepend ``<args.llvm_path>/bin`` to the PATH environment variable.

    Args:
        args: Argument namespace that may carry an ``llvm_path`` attribute.

    Returns:
        ``True`` if PATH was updated. ``False`` if ``llvm_path`` is absent or empty, or
        its ``bin`` subdirectory does not exist (a warning is printed in that case).
    """
    llvm_root = getattr(args, "llvm_path", None)
    if not llvm_root:
        return False

    bin_dir = os.path.join(llvm_root, "bin")
    if os.path.isdir(bin_dir):
        # Prepend so these LLVM tools take precedence over any others already on PATH.
        current = os.environ.get("PATH", "")
        os.environ["PATH"] = os.pathsep.join((bin_dir, current))
        print(f"Added {bin_dir} to PATH")
        return True

    print(f"Warning: LLVM bin directory not found at {bin_dir}")
    return False
179
+
180
+
181
+ def _get_architectures_cu12(
182
+ ctk_version: tuple[int, int], arch: str, target_platform: str, quick_build: bool = False
183
+ ) -> tuple[list[str], list[str]]:
184
+ """Get architecture flags for CUDA 12.x."""
185
+ gencode_opts = []
186
+ clang_arch_flags = []
187
+
188
+ if quick_build:
189
+ gencode_opts = ["-gencode=arch=compute_52,code=compute_52", "-gencode=arch=compute_75,code=compute_75"]
190
+ clang_arch_flags = ["--cuda-gpu-arch=sm_52", "--cuda-gpu-arch=sm_75"]
191
+ else:
192
+ if arch == "aarch64" and target_platform == "linux" and ctk_version == (12, 9):
193
+ # Skip certain architectures for aarch64 with CUDA 12.9 due to CCCL bug
194
+ print(
195
+ "[INFO] Skipping sm_52, sm_60, sm_61, and sm_70 targets for ARM due to a CUDA Toolkit bug. "
196
+ "See https://nvidia.github.io/warp/installation.html#cuda-12-9-limitation-on-linux-arm-platforms "
197
+ "for details."
198
+ )
199
+ else:
200
+ gencode_opts.extend(
201
+ [
202
+ "-gencode=arch=compute_52,code=sm_52", # Maxwell
203
+ "-gencode=arch=compute_60,code=sm_60", # Pascal
204
+ "-gencode=arch=compute_61,code=sm_61",
205
+ "-gencode=arch=compute_70,code=sm_70", # Volta
206
+ ]
207
+ )
208
+ clang_arch_flags.extend(
209
+ [
210
+ "--cuda-gpu-arch=sm_52",
211
+ "--cuda-gpu-arch=sm_60",
212
+ "--cuda-gpu-arch=sm_61",
213
+ "--cuda-gpu-arch=sm_70",
214
+ ]
215
+ )
216
+
217
+ # Desktop architectures
218
+ gencode_opts.extend(
219
+ [
220
+ "-gencode=arch=compute_75,code=sm_75", # Turing
221
+ "-gencode=arch=compute_75,code=compute_75", # Turing (PTX)
222
+ "-gencode=arch=compute_80,code=sm_80", # Ampere
223
+ "-gencode=arch=compute_86,code=sm_86",
224
+ "-gencode=arch=compute_89,code=sm_89", # Ada
225
+ "-gencode=arch=compute_90,code=sm_90", # Hopper
226
+ ]
227
+ )
228
+ clang_arch_flags.extend(
229
+ [
230
+ "--cuda-gpu-arch=sm_75", # Turing
231
+ "--cuda-gpu-arch=sm_80", # Ampere
232
+ "--cuda-gpu-arch=sm_86",
233
+ "--cuda-gpu-arch=sm_89", # Ada
234
+ "--cuda-gpu-arch=sm_90", # Hopper
235
+ ]
236
+ )
237
+
238
+ if ctk_version >= (12, 8):
239
+ gencode_opts.extend(["-gencode=arch=compute_100,code=sm_100", "-gencode=arch=compute_120,code=sm_120"])
240
+ clang_arch_flags.extend(["--cuda-gpu-arch=sm_100", "--cuda-gpu-arch=sm_120"])
241
+
242
+ # Mobile architectures for aarch64 Linux
243
+ if arch == "aarch64" and target_platform == "linux":
244
+ gencode_opts.extend(
245
+ [
246
+ "-gencode=arch=compute_87,code=sm_87", # Orin
247
+ "-gencode=arch=compute_53,code=sm_53", # X1
248
+ "-gencode=arch=compute_62,code=sm_62", # X2
249
+ "-gencode=arch=compute_72,code=sm_72", # Xavier
250
+ ]
251
+ )
252
+ clang_arch_flags.extend(
253
+ [
254
+ "--cuda-gpu-arch=sm_87",
255
+ "--cuda-gpu-arch=sm_53",
256
+ "--cuda-gpu-arch=sm_62",
257
+ "--cuda-gpu-arch=sm_72",
258
+ ]
259
+ )
260
+
261
+ # Thor support in CUDA 12.8+
262
+ if ctk_version >= (12, 8):
263
+ gencode_opts.append("-gencode=arch=compute_101,code=sm_101") # Thor (CUDA 12 numbering)
264
+ clang_arch_flags.append("--cuda-gpu-arch=sm_101")
265
+
266
+ if ctk_version >= (12, 9):
267
+ gencode_opts.append("-gencode=arch=compute_121,code=sm_121")
268
+ clang_arch_flags.append("--cuda-gpu-arch=sm_121")
269
+
270
+ # PTX for future hardware (use highest available compute capability)
271
+ if ctk_version >= (12, 9):
272
+ gencode_opts.extend(["-gencode=arch=compute_121,code=compute_121"])
273
+ elif ctk_version >= (12, 8):
274
+ gencode_opts.extend(["-gencode=arch=compute_120,code=compute_120"])
275
+ else:
276
+ gencode_opts.append("-gencode=arch=compute_90,code=compute_90")
277
+
278
+ return gencode_opts, clang_arch_flags
279
+
280
+
281
+ def _get_architectures_cu13(
282
+ ctk_version: tuple[int, int], arch: str, target_platform: str, quick_build: bool = False
283
+ ) -> tuple[list[str], list[str]]:
284
+ """Get architecture flags for CUDA 13.x."""
285
+ gencode_opts = []
286
+ clang_arch_flags = []
287
+
288
+ if quick_build:
289
+ gencode_opts = ["-gencode=arch=compute_75,code=compute_75"]
290
+ clang_arch_flags = ["--cuda-gpu-arch=sm_75"]
291
+ else:
292
+ # Desktop architectures
293
+ gencode_opts.extend(
294
+ [
295
+ "-gencode=arch=compute_75,code=sm_75", # Turing
296
+ "-gencode=arch=compute_75,code=compute_75", # Turing (PTX)
297
+ "-gencode=arch=compute_80,code=sm_80", # Ampere
298
+ "-gencode=arch=compute_86,code=sm_86",
299
+ "-gencode=arch=compute_89,code=sm_89", # Ada
300
+ "-gencode=arch=compute_90,code=sm_90", # Hopper
301
+ "-gencode=arch=compute_100,code=sm_100", # Blackwell
302
+ "-gencode=arch=compute_120,code=sm_120", # Blackwell
303
+ ]
304
+ )
305
+ clang_arch_flags.extend(
306
+ [
307
+ "--cuda-gpu-arch=sm_75", # Turing
308
+ "--cuda-gpu-arch=sm_80", # Ampere
309
+ "--cuda-gpu-arch=sm_86",
310
+ "--cuda-gpu-arch=sm_89", # Ada
311
+ "--cuda-gpu-arch=sm_90", # Hopper
312
+ "--cuda-gpu-arch=sm_100", # Blackwell
313
+ "--cuda-gpu-arch=sm_120", # Blackwell
314
+ ]
315
+ )
316
+
317
+ # Mobile architectures for aarch64 Linux
318
+ if arch == "aarch64" and target_platform == "linux":
319
+ gencode_opts.extend(
320
+ [
321
+ "-gencode=arch=compute_87,code=sm_87", # Orin
322
+ "-gencode=arch=compute_110,code=sm_110", # Thor
323
+ "-gencode=arch=compute_121,code=sm_121", # Spark
324
+ ]
325
+ )
326
+ clang_arch_flags.extend(
327
+ [
328
+ "--cuda-gpu-arch=sm_87",
329
+ "--cuda-gpu-arch=sm_110",
330
+ "--cuda-gpu-arch=sm_121",
331
+ ]
332
+ )
333
+
334
+ # PTX for future hardware (use highest available compute capability)
335
+ gencode_opts.extend(["-gencode=arch=compute_121,code=compute_121"])
336
+
337
+ return gencode_opts, clang_arch_flags
338
+
339
+
340
def build_dll_for_arch(args, dll_path, cpp_paths, cu_paths, arch, libs: list[str] | None = None, mode=None):
    """Compile and link the Warp native shared library for one target architecture.

    C++ sources are compiled with the host compiler (``args.host_compiler`` on
    Windows, ``g++`` or ``clang++`` elsewhere). CUDA sources are compiled with
    ``nvcc``, or with ``clang++`` when ``args.clang_build_toolchain`` is set.
    CUDA sources are compiled only when ``cu_paths`` is non-empty. All objects are
    then linked into ``dll_path``. Unix release builds are additionally stripped.

    Args:
        args: Parsed build options. Attributes read here include ``mode``,
            ``cuda_path``, ``quick``, ``verbose``, ``compile_time_trace``,
            ``fast_math``, ``libmathdx_path``, ``host_compiler``, ``verify_fp``,
            ``jobs`` and, optionally, ``clang_build_toolchain``.
        dll_path: Output path of the shared library.
        cpp_paths: C++ source files to compile.
        cu_paths: CUDA source files to compile, or ``None`` for a CPU-only build.
        arch: Target architecture name (e.g. ``"aarch64"``). It selects include
            directories and CUDA targets, and sets the macOS ``--target`` triple.
        libs: Extra linker inputs appended to the link command.
        mode: ``"debug"`` or ``"release"``. Defaults to ``args.mode``, and is
            used consistently for every compile and link step.

    Raises:
        RuntimeError: If ``mode`` is not ``"debug"`` or ``"release"``, or if no
            host compiler is configured on Windows.
        Exception: If the CUDA Toolkit at ``args.cuda_path`` is older than
            ``MIN_CTK_VERSION``.
    """
    mode = args.mode if (mode is None) else mode
    # Validate once for both platforms. Previously only Windows checked, so an
    # unknown mode on Unix produced unoptimised C++ and an invalid CUDA command.
    if mode not in ("debug", "release"):
        raise RuntimeError(f"Unrecognized build configuration (debug, release), got: {mode}")

    cuda_home = args.cuda_path

    # Add LLVM bin directory to PATH
    add_llvm_bin_to_path(args)

    # The compatibility define is off for quick builds and CPU-only builds
    if args.quick or cu_paths is None:
        cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=0"
    else:
        cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=1"

    if libs is None:
        libs = []

    import pathlib

    warp_home_path = pathlib.Path(__file__).parent.parent
    warp_home = warp_home_path.resolve()

    if args.verbose:
        print(f"Building {dll_path}")

    native_dir = os.path.join(warp_home, "native")

    if cu_paths:
        # check CUDA Toolkit version
        ctk_version = get_cuda_toolkit_version(cuda_home)
        if ctk_version < MIN_CTK_VERSION:
            raise Exception(
                f"CUDA Toolkit version {MIN_CTK_VERSION[0]}.{MIN_CTK_VERSION[1]}+ is required (found {ctk_version[0]}.{ctk_version[1]} in {cuda_home})"
            )

        # Get architecture flags based on CUDA version
        if ctk_version >= (13, 0):
            gencode_opts, clang_arch_flags = _get_architectures_cu13(ctk_version, arch, sys.platform, args.quick)
        else:
            gencode_opts, clang_arch_flags = _get_architectures_cu12(ctk_version, arch, sys.platform, args.quick)

        nvcc_opts = [
            *gencode_opts,
            "-t0",  # multithreaded compilation
            "--extended-lambda",
        ]

        # Clang options
        clang_opts = [
            *clang_arch_flags,
            "-std=c++17",
            "-xcuda",
            f'--cuda-path="{cuda_home}"',
        ]

        if args.compile_time_trace:
            if ctk_version >= (12, 8):
                # "@filename@" is substituted per source file below
                nvcc_opts.append("--fdevice-time-trace=_build/build_lib_@filename@_compile-time-trace")
            else:
                print("Warp warning: CUDA version is less than 12.8, compile_time_trace is not supported")

        if args.fast_math:
            nvcc_opts.append("--use_fast_math")

    # is the library being built with CUDA enabled?
    # NOTE(review): this tests `is not None` while CUDA compilation tests truthiness,
    # so an empty `cu_paths` list defines WP_ENABLE_CUDA=1 without compiling or
    # linking any CUDA code. Confirm that callers never pass [].
    cuda_enabled = "WP_ENABLE_CUDA=1" if (cu_paths is not None) else "WP_ENABLE_CUDA=0"

    if args.libmathdx_path:
        libmathdx_includes = f' -I"{args.libmathdx_path}/include"'
        mathdx_enabled = "WP_ENABLE_MATHDX=1"
    else:
        libmathdx_includes = ""
        mathdx_enabled = "WP_ENABLE_MATHDX=0"

    if os.name == "nt":
        if args.host_compiler:
            # link.exe is expected to sit next to the host compiler
            host_linker = os.path.join(os.path.dirname(args.host_compiler), "link.exe")
        else:
            raise RuntimeError("Warp build error: No host compiler was found")

        cpp_includes = f' /I"{warp_home_path.parent}/external/llvm-project/out/install/{mode}-{arch}/include"'
        cpp_includes += f' /I"{warp_home_path.parent}/_build/host-deps/llvm-project/release-{arch}/include"'
        cuda_includes = f' /I"{cuda_home}/include"' if cu_paths else ""
        includes = cpp_includes + cuda_includes

        # nvrtc_static.lib is built with /MT and _ITERATOR_DEBUG_LEVEL=0 so if we link it in we must match these options
        if cu_paths or mode != "debug":
            runtime = "/MT"
            iter_dbg = "_ITERATOR_DEBUG_LEVEL=0"
            debug = "NDEBUG"
        else:
            runtime = "/MTd"
            iter_dbg = "_ITERATOR_DEBUG_LEVEL=2"
            debug = "_DEBUG"

        cpp_flags = f'/nologo /std:c++17 /GR- {runtime} /D "{debug}" /D "{cuda_enabled}" /D "{mathdx_enabled}" /D "{cuda_compat_enabled}" /D "{iter_dbg}" /I"{native_dir}" {includes} '

        # Use the resolved `mode` (not `args.mode`) so that an explicit `mode=`
        # override applies to C++ flags as well as to the runtime and CUDA flags.
        if mode == "debug":
            cpp_flags += "/FS /Zi /Od /D WP_ENABLE_DEBUG=1"
            linkopts = ["/DLL", "/DEBUG"]
        else:  # release
            cpp_flags += "/Ox /D WP_ENABLE_DEBUG=0"
            linkopts = ["/DLL"]

        if args.verify_fp:
            cpp_flags += ' /D "WP_VERIFY_FP"'

        if args.fast_math:
            cpp_flags += " /fp:fast"

        with concurrent.futures.ThreadPoolExecutor(max_workers=args.jobs) as executor:
            futures, wall_clock = [], time.perf_counter_ns()

            cpp_cmds = []
            for cpp_path in cpp_paths:
                cpp_out = cpp_path + ".obj"
                linkopts.append(quote(cpp_out))
                cpp_cmd = f'"{args.host_compiler}" {cpp_flags} -c "{cpp_path}" /Fo"{cpp_out}"'
                cpp_cmds.append(cpp_cmd)

            # Run serially when jobs <= 1, otherwise submit to the thread pool
            if args.jobs <= 1:
                with ScopedTimer("build", active=args.verbose):
                    for cpp_cmd in cpp_cmds:
                        run_cmd(cpp_cmd)
            else:
                futures = [executor.submit(run_cmd, cmd=cpp_cmd) for cpp_cmd in cpp_cmds]

            cuda_cmds = []
            if cu_paths:
                for cu_path in cu_paths:
                    cu_out = cu_path + ".o"

                    _nvcc_opts = [
                        opt.replace("@filename@", os.path.basename(cu_path).replace(".", "_")) for opt in nvcc_opts
                    ]

                    if mode == "debug":
                        cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 --compiler-options=/MT,/Zi,/Od -g -G -O0 -DNDEBUG -D_ITERATOR_DEBUG_LEVEL=0 -I"{native_dir}" -line-info {" ".join(_nvcc_opts)} -DWP_ENABLE_CUDA=1 -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
                    else:  # release
                        cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 {" ".join(_nvcc_opts)} -I"{native_dir}" -DNDEBUG -DWP_ENABLE_CUDA=1 -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'

                    cuda_cmds.append(cuda_cmd)

                    linkopts.append(quote(cu_out))

                linkopts.append(
                    f'cudart_static.lib nvrtc_static.lib nvrtc-builtins_static.lib nvptxcompiler_static.lib ws2_32.lib user32.lib /LIBPATH:"{cuda_home}/lib/x64"'
                )

                if args.libmathdx_path:
                    linkopts.append(f'nvJitLink_static.lib /LIBPATH:"{args.libmathdx_path}/lib/x64" mathdx_static.lib')

            if args.jobs <= 1:
                with ScopedTimer("build_cuda", active=args.verbose):
                    for cuda_cmd in cuda_cmds:
                        run_cmd(cuda_cmd)
            else:
                futures.extend([executor.submit(run_cmd, cmd=cuda_cmd) for cuda_cmd in cuda_cmds])

            if futures:
                # Stop at the first failing compile, cancel queued jobs, and re-raise
                done, pending = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION)
                for d in done:
                    if e := d.exception():
                        for f in pending:
                            f.cancel()
                        raise e
                elapsed = (time.perf_counter_ns() - wall_clock) / 1000000.0
                print(f"build took {elapsed:.2f} ms ({args.jobs:d} workers)")

        with ScopedTimer("link", active=args.verbose):
            link_cmd = f'"{host_linker}" {" ".join(linkopts + libs)} /out:"{dll_path}"'
            run_cmd(link_cmd)

    else:
        # Unix compilation
        cuda_compiler = "clang++" if getattr(args, "clang_build_toolchain", False) else "nvcc"
        cpp_compiler = "clang++" if getattr(args, "clang_build_toolchain", False) else "g++"

        cpp_includes = f' -I"{warp_home_path.parent}/external/llvm-project/out/install/{mode}-{arch}/include"'
        cpp_includes += f' -I"{warp_home_path.parent}/_build/host-deps/llvm-project/release-{arch}/include"'
        cuda_includes = f' -I"{cuda_home}/include"' if cu_paths else ""
        includes = cpp_includes + cuda_includes

        if sys.platform == "darwin":
            version = f"--target={arch}-apple-macos11"
        else:
            if cpp_compiler == "g++":
                version = "-fabi-version=13"  # GCC 8.2+
            else:
                version = ""

        cpp_flags = f'-Werror -Wuninitialized {version} --std=c++17 -fno-rtti -D{cuda_enabled} -D{mathdx_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden -D_GLIBCXX_USE_CXX11_ABI=0 -I"{native_dir}" {includes} '

        if mode == "debug":
            cpp_flags += "-O0 -g -D_DEBUG -DWP_ENABLE_DEBUG=1 -fkeep-inline-functions"
        else:  # release
            cpp_flags += "-O3 -DNDEBUG -DWP_ENABLE_DEBUG=0"

        if args.verify_fp:
            cpp_flags += " -DWP_VERIFY_FP"

        if args.fast_math:
            cpp_flags += " -ffast-math"

        ld_inputs = []

        with concurrent.futures.ThreadPoolExecutor(max_workers=args.jobs) as executor:
            futures, wall_clock = [], time.perf_counter_ns()

            cpp_cmds = []
            for cpp_path in cpp_paths:
                cpp_out = cpp_path + ".o"
                ld_inputs.append(quote(cpp_out))
                cpp_cmd = f'{cpp_compiler} {cpp_flags} -c "{cpp_path}" -o "{cpp_out}"'
                cpp_cmds.append(cpp_cmd)

            # Run serially when jobs <= 1, otherwise submit to the thread pool
            if args.jobs <= 1:
                with ScopedTimer("build", active=args.verbose):
                    for cpp_cmd in cpp_cmds:
                        run_cmd(cpp_cmd)
            else:
                futures = [executor.submit(run_cmd, cmd=cpp_cmd) for cpp_cmd in cpp_cmds]

            cuda_cmds = []
            if cu_paths:
                for cu_path in cu_paths:
                    cu_out = cu_path + ".o"

                    _nvcc_opts = [
                        opt.replace("@filename@", os.path.basename(cu_path).replace(".", "_")) for opt in nvcc_opts
                    ]

                    if cuda_compiler == "nvcc":
                        if mode == "debug":
                            cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -g -G -O0 --compiler-options -fPIC,-fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -line-info {" ".join(_nvcc_opts)} -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
                        else:  # release
                            cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 --compiler-options -fPIC,-fvisibility=hidden {" ".join(_nvcc_opts)} -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
                    else:
                        # Use Clang compiler
                        if mode == "debug":
                            cuda_cmd = f'clang++ -Werror -Wuninitialized -Wno-unknown-cuda-version {" ".join(clang_opts)} -g -O0 -fPIC -fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
                        else:  # release
                            cuda_cmd = f'clang++ -Werror -Wuninitialized -Wno-unknown-cuda-version {" ".join(clang_opts)} -O3 -fPIC -fvisibility=hidden -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'

                    cuda_cmds.append(cuda_cmd)

                    ld_inputs.append(quote(cu_out))

                ld_inputs.append(
                    f'-L"{cuda_home}/lib64" -lcudart_static -lnvrtc_static -lnvrtc-builtins_static -lnvptxcompiler_static -lpthread -ldl -lrt'
                )

                if args.libmathdx_path:
                    # Quote the library directory like the CUDA -L path above so paths with spaces work
                    ld_inputs.append(f'-lnvJitLink_static -L"{args.libmathdx_path}/lib" -lmathdx_static')

            if args.jobs <= 1:
                with ScopedTimer("build_cuda", active=args.verbose):
                    for cuda_cmd in cuda_cmds:
                        run_cmd(cuda_cmd)
            else:
                futures.extend([executor.submit(run_cmd, cmd=cuda_cmd) for cuda_cmd in cuda_cmds])

            if futures:
                # Stop at the first failing compile, cancel queued jobs, and re-raise
                done, pending = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION)
                for d in done:
                    if e := d.exception():
                        for f in pending:
                            f.cancel()
                        raise e
                elapsed = (time.perf_counter_ns() - wall_clock) / 1000000.0
                print(f"build took {elapsed:.2f} ms ({args.jobs:d} workers)")

        if sys.platform == "darwin":
            opt_no_undefined = "-Wl,-undefined,error"
            opt_exclude_libs = ""
        else:
            opt_no_undefined = "-Wl,--no-undefined"
            opt_exclude_libs = "-Wl,--exclude-libs,ALL"

        with ScopedTimer("link", active=args.verbose):
            origin = "@loader_path" if (sys.platform == "darwin") else "$ORIGIN"
            link_cmd = f"{cpp_compiler} {version} -shared -Wl,-rpath,'{origin}' {opt_no_undefined} {opt_exclude_libs} -o '{dll_path}' {' '.join(ld_inputs + libs)}"
            run_cmd(link_cmd)

        # Strip symbols to reduce the binary size.
        # The output path is quoted the same way as in the link command above.
        if mode == "release":
            if sys.platform == "darwin":
                run_cmd(f"strip -x '{dll_path}'")  # Strip all local symbols
            else:  # Linux
                # Strip all symbols except for those needed to support debugging JIT-compiled code
                run_cmd(
                    f"strip --strip-all --keep-symbol=__jit_debug_register_code --keep-symbol=__jit_debug_descriptor '{dll_path}'"
                )
635
+
636
+
637
def build_dll(args, dll_path, cpp_paths, cu_paths, libs=None):
    """Build the Warp native library for this machine's target architecture.

    On macOS the library is always built for ``aarch64``, which may mean
    cross-compiling from an Intel Mac. On every other platform the
    architecture reported by ``machine_architecture()`` is used. All other
    arguments are forwarded unchanged to ``build_dll_for_arch``.
    """
    target_arch = "aarch64" if sys.platform == "darwin" else machine_architecture()
    build_dll_for_arch(args, dll_path, cpp_paths, cu_paths, target_arch, libs)