warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,212 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any, NamedTuple
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import add_function_test, assert_np_equal, get_test_devices
23
+
24
+
25
class ScalarFloatValues(NamedTuple):
    """Pair of per-function values for the ``wp.degrees`` / ``wp.radians`` scalar test.

    ``test_scalar_math`` uses this container both for kernel inputs (each field a
    1-tuple of floats) and for expected outputs/gradients (each field a plain
    float). The fields are therefore annotated as ``Any`` rather than a single
    scalar type. Both default to ``None``.
    """

    # value for the ``wp.degrees`` case (kernel selector i == 0)
    degrees: Any = None
    # value for the ``wp.radians`` case (kernel selector i == 1)
    radians: Any = None
28
+
29
+
30
# Apply one scalar angle-conversion builtin to x[0] and write the result to out[0].
# The builtin is chosen by the integer selector `i`:
#   i == 0 -> wp.degrees
#   i == 1 -> wp.radians
# Any other selector leaves `out` untouched.
@wp.kernel
def scalar_float_kernel(
    i: int,
    x: wp.array(dtype=wp.float32),
    out: wp.array(dtype=wp.float32),
):
    if i == 0:
        out[0] = wp.degrees(x[0])
    elif i == 1:
        out[0] = wp.radians(x[0])
40
+
41
+
42
def test_scalar_math(test, device):
    """Check forward values and gradients of ``wp.degrees`` and ``wp.radians``.

    Each case launches ``scalar_float_kernel`` on a single float32 input under a
    tape and compares the output. It then back-propagates from ``out`` and
    compares the gradient the tape recorded for the input.
    """
    inputs = ScalarFloatValues(degrees=(0.123,), radians=(123.0,))
    expected_values = ScalarFloatValues(degrees=7.047381, radians=2.146755)
    # The derivatives are the conversion factors 180/pi and pi/180.
    expected_grads = ScalarFloatValues(degrees=57.29578, radians=0.017453)

    cases = zip(inputs, expected_values, expected_grads)
    for selector, (case_input, want_value, want_grad) in enumerate(cases):
        x = wp.array([case_input[0]], dtype=wp.float32, requires_grad=True, device=device)
        out = wp.array([0.0], dtype=wp.float32, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(scalar_float_kernel, dim=1, inputs=[selector, x, out], device=device)

        assert_np_equal(out.numpy(), np.array([want_value]), tol=1e-6)

        # Seed the backward pass from `out`; the tape then holds d(out)/d(x).
        tape.backward(out)
        assert_np_equal(tape.gradients[x].numpy(), np.array([want_grad]), tol=1e-6)
59
+
60
+
61
# Apply a different error-function builtin per thread, element-wise:
#   tid 0 -> wp.erf, 1 -> wp.erfc, 2 -> wp.erfinv, 3 -> wp.erfcinv
# Each thread reads only x[tid] and writes only out[tid]. Threads with tid >= 4
# write nothing. The arrays are generic (dtype=Any); test_erf_math launches the
# kernel with float16, float32 and float64 arrays.
@wp.kernel
def erf_kernel(x: wp.array(dtype=Any), out: wp.array(dtype=Any)):
    i = wp.tid()

    if i == 0:
        out[i] = wp.erf(x[i])
    elif i == 1:
        out[i] = wp.erfc(x[i])
    elif i == 2:
        out[i] = wp.erfinv(x[i])
    elif i == 3:
        out[i] = wp.erfcinv(x[i])
73
+
74
+
75
def test_erf_math(test, device):
    """Check values and gradients of erf, erfc, erfinv and erfcinv at x = 0.123.

    ``erf_kernel`` applies a different function per thread. The test runs it for
    float16, float32 and float64, each with a dtype-appropriate tolerance. The
    backward pass is seeded with ones, so each input gradient equals the
    derivative of the corresponding function.
    """
    # Reference values at x = 0.123, in erf_kernel's thread order
    # [erf, erfc, erfinv, erfcinv], followed by their derivatives.
    # These are loop-invariant, so they are built once.
    out_true = np.array([0.13809388, 0.86190612, 0.10944129, 1.09057285])
    adj_x_true = np.array([1.11143641, -1.11143641, 0.89690544, -2.91120449])

    # float16 carries only ~3 significant decimal digits, hence the looser tolerance.
    for dtype, tol in ((wp.float16, 1e-3), (wp.float32, 1e-6), (wp.float64, 1e-6)):
        x = wp.full(4, value=0.123, dtype=dtype, requires_grad=True, device=device)
        out = wp.zeros(4, dtype=dtype, requires_grad=True, device=device)

        with wp.Tape() as tape:
            wp.launch(erf_kernel, dim=4, inputs=[x], outputs=[out], device=device)

        # Seed every output adjoint with 1 so that x.grad[i] = f_i'(x[i]).
        out.grad = wp.ones_like(out)

        tape.backward()

        # Both checks use the same (actual, expected) argument order.
        assert_np_equal(out.numpy(), out_true, tol=tol)
        assert_np_equal(x.grad.numpy(), adj_x_true, tol=tol)
92
+
93
+
94
# For input vector vs[tid], write four norms into row `tid` of `out`:
#   column 0 = wp.norm_l1, 1 = wp.norm_l2, 2 = wp.norm_huber, 3 = wp.norm_pseudo_huber
# The Huber variants are called without an explicit delta argument.
# `out` therefore needs at least 4 columns. `vs` is generic (dtype=Any);
# test_vec_norm launches the kernel with wp.vec3 inputs.
@wp.kernel
def test_vec_norm_kernel(vs: wp.array(dtype=Any), out: wp.array(dtype=float, ndim=2)):
    tid = wp.tid()
    out[tid, 0] = wp.norm_l1(vs[tid])
    out[tid, 1] = wp.norm_l2(vs[tid])
    out[tid, 2] = wp.norm_huber(vs[tid])
    out[tid, 3] = wp.norm_pseudo_huber(vs[tid])
101
+
102
+
103
def test_vec_norm(test, device):
    """Compare Warp's vector norms against NumPy / SciPy-style reference values.

    Runs ``test_vec_norm_kernel`` on three vec3 inputs (including the zero
    vector). For each, it checks the L1, L2, Huber and pseudo-Huber norms.
    """

    # Ground-truth implementations following SciPy's definitions
    # (scipy.special.huber / scipy.special.pseudo_huber). `x` is a vector
    # length here, so it is never negative and no abs() is needed.
    def huber(delta, x):
        if x <= delta:
            return 0.5 * x**2
        else:
            return delta * (x - 0.5 * delta)

    def pseudo_huber(delta, x):
        return delta**2 * (np.sqrt(1 + (x / delta) ** 2) - 1)

    v0 = wp.vec3(-2.0, -1.0, -3.0)
    v1 = wp.vec3(2.0, 1.0, 3.0)
    v2 = wp.vec3(0.0, 0.0, 0.0)
    inputs = [v0, v1, v2]

    xs = wp.array(inputs, dtype=wp.vec3, requires_grad=True, device=device)
    out = wp.empty((len(xs), 4), dtype=wp.float32, requires_grad=True, device=device)

    wp.launch(test_vec_norm_kernel, dim=len(xs), inputs=[xs], outputs=[out], device=device)

    # Copy the results to the host once, rather than once per checked row.
    out_np = out.numpy()

    for i, x in enumerate(inputs):
        length = wp.length(x)
        assert_np_equal(
            out_np[i],
            np.array(
                [
                    np.linalg.norm(x, ord=1),
                    np.linalg.norm(x, ord=2),
                    huber(1.0, length),
                    # Warp's pseudo-Huber value differs from SciPy's definition by a
                    # constant offset in this test (+1.0 at delta = 1.0), so add it back.
                    pseudo_huber(1.0, length) + 1.0,
                ]
            ),
            tol=1e-6,
        )
137
+
138
+
139
# Devices that the device-parametrized test functions are registered for.
# Evaluated once, at import time.
devices = get_test_devices()
140
+
141
+
142
class TestMath(unittest.TestCase):
    """Host-side (pure Python, no kernel launch) checks of Warp's generic vector and matrix types."""

    def test_vec_type(self):
        """Exercise construction, element/attribute/slice assignment, equality and str() of a length-5 vector."""
        vec5 = wp.vec(length=5, dtype=float)
        v = vec5()
        w = vec5()
        a = vec5(1.0)
        b = vec5(0.0, 0.0, 0.0, 0.0, 0.0)
        c = vec5(0.0)

        v[0] = 1.0
        v.x = 0.0  # `.x` aliases index 0, overwriting the value set just above
        v[1:] = [1.0, 1.0, 1.0, 1.0]

        w[0] = 1.0
        w[1:] = [0.0, 0.0, 0.0, 0.0]

        self.assertEqual(v[0], w[1], "vec setter error")
        self.assertEqual(v.x, w.y, "vec setter error")

        for x in v[1:]:
            self.assertEqual(x, 1.0, "vec slicing error")

        # A single scalar argument fills every component.
        # Previously `a` was built but never checked.
        self.assertEqual(a, vec5(1.0, 1.0, 1.0, 1.0, 1.0), "vec fill constructor error")
        self.assertEqual(b, c, "vec equality error")

        self.assertEqual(str(v), "[0.0, 1.0, 1.0, 1.0, 1.0]", "vec to string error")

    def test_mat_type(self):
        """Exercise element/row assignment, element/row access, equality and str() of a 5x5 matrix."""
        mat55 = wp.mat(shape=(5, 5), dtype=float)
        m1 = mat55()
        m2 = mat55()

        # Fill m1 element by element with the identity matrix.
        for i in range(5):
            for j in range(5):
                m1[i, j] = 1.0 if i == j else 0.0

        # Fill m2 row by row with ones.
        for i in range(5):
            m2[i] = [1.0, 1.0, 1.0, 1.0, 1.0]

        # A single scalar argument fills every element.
        a = mat55(1.0)
        # fmt: off
        b = mat55(
            1.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 1.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 1.0,
        )
        # fmt: on

        self.assertEqual(m1, b, "mat element setting error")
        self.assertEqual(m2, a, "mat row setting error")
        self.assertEqual(m1[0, 0], 1.0, "mat element getting error")
        self.assertEqual(m2[0], [1.0, 1.0, 1.0, 1.0, 1.0], "mat row getting error")
        self.assertEqual(
            str(b),
            "[[1.0, 0.0, 0.0, 0.0, 0.0],\n [0.0, 1.0, 0.0, 0.0, 0.0],\n [0.0, 0.0, 1.0, 0.0, 0.0],\n [0.0, 0.0, 0.0, 1.0, 0.0],\n [0.0, 0.0, 0.0, 0.0, 1.0]]",
            "mat to string error",
        )
203
+
204
+
205
+ add_function_test(TestMath, "test_scalar_math", test_scalar_math, devices=devices)
206
+ add_function_test(TestMath, "test_erf_math", test_erf_math, devices=devices)
207
+ add_function_test(TestMath, "test_vec_norm", test_vec_norm, devices=devices)
208
+
209
+
210
if __name__ == "__main__":
    # When the file is run directly, clear Warp's kernel cache first. Presumably
    # this forces kernels to be recompiled rather than loaded from the cache.
    # Then run the suite with verbose output.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,287 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib
17
+ import importlib.util
18
+ import os
19
+ import shutil
20
+ import unittest
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+
25
+ import warp as wp
26
+ import warp.tests.aux_test_module_aot
27
+ from warp.tests.unittest_utils import *
28
+
29
# Initial source the tests write to aux_test_module_aot.py: `add_kernel` already has its
# final signature but an empty body, so launching it leaves `res` unchanged.
+ ADD_KERNEL_START = """import warp as wp
30
+
31
+
32
+ @wp.kernel
33
+ def add_kernel(a: wp.array(dtype=wp.int32), b: wp.array(dtype=wp.int32), res: wp.array(dtype=wp.int32)):
34
+ pass
35
+ """
36
+
37
# Modified source: same kernel name and signature, but the body writes res[i] = a[i] + b[i].
# Because the source differs from ADD_KERNEL_START, the module hash changes. The tests also
# write this version back to the aux file when they finish, to restore it.
+ ADD_KERNEL_FINAL = """import warp as wp
38
+
39
+
40
+ @wp.kernel
41
+ def add_kernel(a: wp.array(dtype=wp.int32), b: wp.array(dtype=wp.int32), res: wp.array(dtype=wp.int32)):
42
+ i = wp.tid()
43
+ res[i] = a[i] + b[i]
44
+ """
45
+
46
+
47
def reload_module(module):
    """Reload ``module`` from its source file, discarding any cached bytecode first.

    Clearing the ``.pyc`` file associated with a module is a necessary workaround
    for ``importlib.reload`` to work as expected when run from within Kit.
    Python validates a ``.pyc`` against the source's mtime and size. A rewrite that
    keeps the same size within the same second could otherwise be served from the
    stale cache.

    Args:
        module: An already-imported module whose ``__file__`` points at its source.

    Returns:
        The reloaded module object, as returned by :func:`importlib.reload`.
    """
    cache_file = importlib.util.cache_from_source(module.__file__)
    # EAFP: remove directly instead of exists()+remove(), which could race.
    try:
        os.remove(cache_file)
    except FileNotFoundError:
        pass  # No cached bytecode yet; nothing to invalidate.
    return importlib.reload(module)
54
+
55
+
56
# Absolute path of the scratch directory, next to this file, that the AOT tests
# compile modules into and remove afterwards.
TEST_CACHE_DIR = Path(os.path.abspath(__file__)).parent / "test_module_aot_cache"
57
+
58
+
59
def test_disable_hashing(test, device):
    """Test that module hashing can be disabled.

    A module is run, modified, and run again. The second run should not trigger
    a recompilation since the hash will not be used to detect changes.

    Args:
        test: The ``unittest.TestCase`` instance this test is registered on (unused here).
        device: Warp device the auxiliary module is compiled for and launched on.
    """
    # Source file of ``warp.tests.aux_test_module_aot``; it is rewritten in place to simulate edits.
    aux_module_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "aux_test_module_aot.py"))

    def write_aux_module(source):
        # ``source`` is a single string: use ``write`` (``writelines`` would iterate it char by char).
        with open(aux_module_path, "w", encoding="utf-8") as f:
            f.write(source)

    try:
        shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
        TEST_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        wp.set_module_options(
            {"block_dim": 1 if device.is_cpu else 256},
            warp.tests.aux_test_module_aot,
        )

        a = wp.ones(10, dtype=wp.int32, device=device)
        b = wp.ones(10, dtype=wp.int32, device=device)
        res = wp.zeros((10,), dtype=wp.int32, device=device)

        # Write out the module and import it
        write_aux_module(ADD_KERNEL_START)
        reload_module(warp.tests.aux_test_module_aot)

        # First launch, cold compile, expect res to be unchanged since kernel is empty
        wp.compile_aot_module(warp.tests.aux_test_module_aot, device, module_dir=TEST_CACHE_DIR, strip_hash=True)
        wp.load_aot_module(warp.tests.aux_test_module_aot, device, module_dir=TEST_CACHE_DIR, strip_hash=True)

        wp.launch(
            warp.tests.aux_test_module_aot.add_kernel,
            dim=a.shape,
            inputs=[a, b],
            outputs=[res],
            device=device,
        )

        assert_np_equal(res.numpy(), np.zeros((10,), dtype=np.int32))

        res.zero_()

        # Write out the modified module and import it
        write_aux_module(ADD_KERNEL_FINAL)
        reload_module(warp.tests.aux_test_module_aot)

        # This time, the hash checks will be skipped so the previously compiled module will be loaded
        wp.load_aot_module(warp.tests.aux_test_module_aot, device, module_dir=TEST_CACHE_DIR, strip_hash=True)

        # Kernel is executed with the ADD_KERNEL_START code, not the ADD_KERNEL_FINAL code
        wp.launch(
            warp.tests.aux_test_module_aot.add_kernel,
            dim=a.shape,
            inputs=[a, b],
            outputs=[res],
            device=device,
        )

        assert_np_equal(res.numpy(), np.zeros((10,), dtype=np.int32))
    finally:
        # Clear the cache directory
        shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
        # Revert the module default options and auxiliary file to the original states
        wp.set_module_options({"cuda_output": None, "strip_hash": False}, warp.tests.aux_test_module_aot)

        write_aux_module(ADD_KERNEL_FINAL)
125
+
126
+
127
def test_enable_hashing(test, device):
    """Ensure that the logic of test_disable_hashing is sound.

    This test sets "strip_hash" to False, so normal module hashing rules
    should be in effect.

    Args:
        test: The ``unittest.TestCase`` instance; used for ``assertRaises``.
        device: Warp device the auxiliary module is compiled for and launched on.
    """
    # Source file of ``warp.tests.aux_test_module_aot``; it is rewritten in place to simulate edits.
    aux_module_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "aux_test_module_aot.py"))

    def write_aux_module(source):
        # ``source`` is a single string: use ``write`` (``writelines`` would iterate it char by char).
        with open(aux_module_path, "w", encoding="utf-8") as f:
            f.write(source)

    try:
        shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
        TEST_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        wp.set_module_options(
            {"block_dim": 1 if device.is_cpu else 256},
            warp.tests.aux_test_module_aot,
        )

        a = wp.ones(10, dtype=wp.int32, device=device)
        b = wp.ones(10, dtype=wp.int32, device=device)
        res = wp.zeros((10,), dtype=wp.int32, device=device)

        # Write out the module and import it
        write_aux_module(ADD_KERNEL_START)
        reload_module(warp.tests.aux_test_module_aot)

        # First launch, cold compile, expect no-op result
        wp.compile_aot_module(warp.tests.aux_test_module_aot, device, module_dir=TEST_CACHE_DIR, strip_hash=False)
        wp.load_aot_module(warp.tests.aux_test_module_aot, device, module_dir=TEST_CACHE_DIR, strip_hash=False)
        wp.launch(
            warp.tests.aux_test_module_aot.add_kernel,
            dim=a.shape,
            inputs=[a, b],
            outputs=[res],
            device=device,
        )

        assert_np_equal(res.numpy(), np.zeros((10,), dtype=np.int32))

        # Write out the modified module (results in a different hash) and import it
        write_aux_module(ADD_KERNEL_FINAL)
        reload_module(warp.tests.aux_test_module_aot)

        # Trying to load the module should fail since a compiled module with the expected hash does not exist
        with test.assertRaises(FileNotFoundError):
            wp.load_aot_module("warp.tests.aux_test_module_aot", device, module_dir=TEST_CACHE_DIR, strip_hash=False)
    finally:
        # Clear the cache directory
        shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
        # Revert the module default options and auxiliary file to the original states
        wp.set_module_options({"cuda_output": None, "strip_hash": False}, warp.tests.aux_test_module_aot)

        write_aux_module(ADD_KERNEL_FINAL)
180
+
181
+
182
def test_module_load_resolution(test, device):
    """Check that AOT compile/load accept a module given as an object or as its dotted name.

    Each form is exercised in turn: the aux module is reloaded, compiled and loaded
    through that reference, and its ``add_kernel`` must produce ``1 + 1`` everywhere.
    """
    module = warp.tests.aux_test_module_aot
    wp.set_module_options({"block_dim": 1 if device.is_cpu else 256}, module)

    lhs = wp.ones(10, dtype=wp.int32, device=device)
    rhs = wp.ones(10, dtype=wp.int32, device=device)
    out = wp.zeros((10,), dtype=wp.int32, device=device)
    expected = np.full((10,), 2, dtype=np.int32)

    # First resolve by module object, then by fully-qualified module name.
    for attempt, module_ref in enumerate((module, "warp.tests.aux_test_module_aot")):
        # importlib.reload re-executes into the same module object, so `module` stays valid.
        reload_module(module)
        if attempt > 0:
            # Discard the previous launch's output so this launch is checked on its own.
            out.zero_()
        wp.compile_aot_module(module_ref, device)
        wp.load_aot_module(module_ref, device)

        wp.launch(module.add_kernel, dim=lhs.shape, inputs=[lhs, rhs], outputs=[out], device=device)
        assert_np_equal(out.numpy(), expected)
220
+
221
+
222
+ class TestModuleAOT(unittest.TestCase):
223
+ def test_module_compile_specified_arch_ptx(self):
224
+ """Test that a module can be compiled for a specific architecture or architectures (PTX)."""
225
+
226
+ if wp.get_cuda_device_count() == 0:
227
+ self.skipTest("No CUDA devices found")
228
+
229
+ if len(wp._src.context.runtime.nvrtc_supported_archs) < 2:
230
+ self.skipTest("NVRTC must support at least two architectures to run this test")
231
+
232
+ try:
233
+ shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
234
+ TEST_CACHE_DIR.mkdir(parents=True, exist_ok=True)
235
+
236
+ archs = list(wp._src.context.runtime.nvrtc_supported_archs)[:2]
237
+
238
+ wp.compile_aot_module(warp.tests.aux_test_module_aot, arch=archs, module_dir=TEST_CACHE_DIR, use_ptx=True)
239
+
240
+ # Make sure the expected files exist
241
+ module_identifier = wp.get_module("warp.tests.aux_test_module_aot").get_module_identifier()
242
+ for arch in archs:
243
+ expected_filename = f"{module_identifier}.sm{arch}.ptx"
244
+ expected_path = TEST_CACHE_DIR / expected_filename
245
+ self.assertTrue(expected_path.exists(), f"Expected compiled PTX file not found: {expected_path}")
246
+
247
+ finally:
248
+ # Clear the cache directory
249
+ shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
250
+
251
+ def test_module_compile_specified_arch_cubin(self):
252
+ """Test that a module can be compiled for a specific architecture or architectures (CUBIN)."""
253
+
254
+ if wp.get_cuda_device_count() == 0:
255
+ self.skipTest("No CUDA devices found")
256
+
257
+ if len(wp._src.context.runtime.nvrtc_supported_archs) < 2:
258
+ self.skipTest("NVRTC must support at least two architectures to run this test")
259
+
260
+ try:
261
+ shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
262
+ TEST_CACHE_DIR.mkdir(parents=True, exist_ok=True)
263
+
264
+ archs = list(wp._src.context.runtime.nvrtc_supported_archs)[:2]
265
+
266
+ wp.compile_aot_module(warp.tests.aux_test_module_aot, arch=archs, module_dir=TEST_CACHE_DIR, use_ptx=False)
267
+
268
+ # Make sure the expected files exist
269
+ module_identifier = wp.get_module("warp.tests.aux_test_module_aot").get_module_identifier()
270
+ for arch in archs:
271
+ expected_filename = f"{module_identifier}.sm{arch}.cubin"
272
+ expected_path = TEST_CACHE_DIR / expected_filename
273
+ self.assertTrue(expected_path.exists(), f"Expected compiled CUBIN file not found: {expected_path}")
274
+
275
+ finally:
276
+ # Clear the cache directory
277
+ shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
278
+
279
+
280
# Collect the test devices, then attach the device-parametrized function tests to
# TestModuleAOT (presumably one generated test per device; see unittest_utils).
+ devices = get_test_devices()
281
+ add_function_test(TestModuleAOT, "test_disable_hashing", test_disable_hashing, devices=devices)
282
+ add_function_test(TestModuleAOT, "test_enable_hashing", test_enable_hashing, devices=devices)
283
+ add_function_test(TestModuleAOT, "test_module_load_resolution", test_module_load_resolution, devices=devices)
284
+
285
# Script entry point: clear Warp's kernel cache, then run the tests with verbose output.
+ if __name__ == "__main__":
286
+ wp.clear_kernel_cache()
287
+ unittest.main(verbosity=2)