warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,242 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
# Kernel: writes twice the value of each element of x into the matching
# slot of y. Each launched thread handles the element at its own index.
@wp.kernel
def mul_constant(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()

    doubled = x[i] * 2.0
    y[i] = doubled
29
+
30
+
31
# Warp struct bundling the two float arrays that mul_variable multiplies.
# Used so the tape tests also cover struct-typed kernel arguments.
@wp.struct
class Multiplicands:
    # first factor
    x: wp.array(dtype=float)
    # second factor
    y: wp.array(dtype=float)
35
+
36
+
37
# Kernel: element-wise product of the two arrays held by the struct
# argument, stored into z. Each thread handles the element at its own index.
# (The parameter spelling "mutiplicands" is kept as-is; it is part of the
# kernel's signature.)
@wp.kernel
def mul_variable(mutiplicands: Multiplicands, z: wp.array(dtype=float)):
    i = wp.tid()

    lhs = mutiplicands.x[i]
    rhs = mutiplicands.y[i]
    z[i] = lhs * rhs
42
+
43
+
44
# Kernel: each thread multiplies its pair of elements from x and y and
# atomically adds the product into z[0], so z[0] accumulates the dot product.
@wp.kernel
def dot_product(x: wp.array(dtype=float), y: wp.array(dtype=float), z: wp.array(dtype=float)):
    i = wp.tid()

    term = x[i] * y[i]
    wp.atomic_add(z, 0, term)
49
+
50
+
51
# Chains `mul_constant` launches on a tape and checks that the gradient
# reaching the very first input equals 2**(number of launches).
def test_tape_mul_constant(test, device):
    num_elems = 8
    num_doublings = 16
    tape = wp.Tape()

    # everything launched inside this block is recorded on the tape
    with tape:
        x0 = wp.array(np.zeros(num_elems), dtype=wp.float32, device=device, requires_grad=True)

        # feed each launch's output into the next launch
        current = x0
        for _ in range(num_doublings):
            doubled = wp.empty_like(current, requires_grad=True)
            wp.launch(kernel=mul_constant, dim=num_elems, inputs=[current], outputs=[doubled], device=device)
            current = doubled

    # seed the final output with a gradient of ones (as if loss = sum(output))
    current.grad = wp.array(np.ones(num_elems), device=device, dtype=wp.float32)

    # replay the recorded launches in reverse
    tape.backward()

    # each launch contributes a factor of 2.0
    expected = np.ones(num_elems) * (2**num_doublings)
    assert_np_equal(tape.gradients[x0].numpy(), expected)
75
+
76
+
77
# Records one `mul_variable` launch whose inputs live in a struct, then checks:
#   * after one backward pass, d(z)/dx == y and d(z)/dy == x,
#   * a second backward pass accumulates onto the first (gradients double),
#   * tape.reset() zeroes the gradients and clears the recorded launches.
def test_tape_mul_variable(test, device):
    dim = 8
    tape = wp.Tape()

    with tape:
        # inputs are deliberately passed through a struct to exercise
        # struct handling in the tape
        factors = Multiplicands()
        factors.x = wp.array(np.ones(dim) * 16.0, dtype=wp.float32, device=device, requires_grad=True)
        factors.y = wp.array(np.ones(dim) * 32.0, dtype=wp.float32, device=device, requires_grad=True)
        z = wp.zeros_like(factors.x)

        wp.launch(kernel=mul_variable, dim=dim, inputs=[factors], outputs=[z], device=device)

    # seed the output with a gradient of ones (as if loss = sum(z))
    z.grad = wp.array(np.ones(dim), device=device, dtype=wp.float32)

    # first backward pass
    tape.backward()

    x_values = factors.x.numpy()
    y_values = factors.y.numpy()

    # product rule: grad_x = y, grad_y = x
    assert_np_equal(tape.gradients[factors].x.numpy(), y_values)
    assert_np_equal(tape.gradients[factors].y.numpy(), x_values)

    # second backward pass with the same seed: contributions accumulate,
    # so every gradient is now twice its previous value
    tape.backward()

    assert_np_equal(tape.gradients[factors].x.numpy(), y_values * 2.0)
    assert_np_equal(tape.gradients[factors].y.numpy(), x_values * 2.0)

    # reset() must zero the gradients and drop the recorded launches
    tape.reset()
    grad_x_after_reset = tape.gradients[factors].x.numpy()
    assert_np_equal(grad_x_after_reset, np.zeros_like(grad_x_after_reset))
    test.assertFalse(tape.launches)
113
+
114
+
115
# Records a single `dot_product` launch and back-propagates from the scalar
# result, checking d(x.y)/dx == y and d(x.y)/dy == x.
def test_tape_dot_product(test, device):
    dim = 8
    tape = wp.Tape()

    with tape:
        x = wp.array(np.ones(dim) * 16.0, dtype=wp.float32, device=device, requires_grad=True)
        y = wp.array(np.ones(dim) * 32.0, dtype=wp.float32, device=device, requires_grad=True)
        # single-element array receiving the accumulated dot product
        z = wp.zeros(n=1, dtype=wp.float32, device=device, requires_grad=True)

        wp.launch(kernel=dot_product, dim=dim, inputs=[x, y], outputs=[z], device=device)

    # z is passed as the scalar loss to differentiate
    tape.backward(loss=z)

    grad_wrt_x = tape.gradients[x].numpy()
    grad_wrt_y = tape.gradients[y].numpy()

    assert_np_equal(grad_wrt_x, y.numpy())
    assert_np_equal(grad_wrt_y, x.numpy())
134
+
135
+
136
# Kernel: copies x into y, then copies the freshly written y into z, so both
# y and z are outputs that depend on x. The second statement intentionally
# reads y back after writing it.
@wp.kernel
def assign_chain_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float), z: wp.array(dtype=float)):
    i = wp.tid()
    y[i] = x[i]
    z[i] = y[i]
141
+
142
+
143
# Back-propagates through `assign_chain_kernel` twice, seeding a different
# output each time with tape.zero() in between, and checks that x.grad is
# all ones after each pass (i.e. nothing from the first pass lingers).
def test_tape_zero_multiple_outputs(test, device):
    count = 3

    x = wp.array(np.arange(count), dtype=float, device=device, requires_grad=True)
    y = wp.zeros_like(x)
    z = wp.zeros_like(x)

    tape = wp.Tape()
    with tape:
        wp.launch(assign_chain_kernel, dim=count, inputs=[x, y, z], device=device)

    # pass 1: seed the intermediate output y
    seed_for_y = {y: wp.ones_like(x)}
    tape.backward(grads=seed_for_y)
    assert_np_equal(x.grad.numpy(), np.ones(count, dtype=float))

    # clear accumulated gradients before the second pass
    tape.zero()

    # pass 2: seed the final output z
    seed_for_z = {z: wp.ones_like(x)}
    tape.backward(grads=seed_for_z)
    assert_np_equal(x.grad.numpy(), np.ones(count, dtype=float))
158
+
159
+
160
# Innermost struct of the nested-struct tape test: wraps one float array.
@wp.struct
class NestedStruct:
    # array read by nested_loss_kernel
    arr: wp.array(dtype=float)
163
+
164
+
165
# Outer struct of the nested-struct tape test: its only member is another
# struct, so kernel arguments reach the array through two levels.
@wp.struct
class WrapperStruct:
    # inner struct holding the array
    nested: NestedStruct
168
+
169
+
170
# Kernel: atomically adds each element of the array stored inside the nested
# struct into loss[0].
@wp.kernel
def nested_loss_kernel(wrapper: WrapperStruct, loss: wp.array(dtype=float)):
    idx = wp.tid()
    contribution = wrapper.nested.arr[idx]
    wp.atomic_add(loss, 0, contribution)
174
+
175
+
176
# Checks that gradients flow to (and are zeroed on) an array that is only
# reachable through a struct nested inside another struct.
def test_tape_nested_struct(test, device):
    # build the structs outside-in, assigning the array through the wrapper
    wrapper = WrapperStruct()
    wrapper.nested = NestedStruct()
    wrapper.nested.arr = wp.ones(shape=(1,), dtype=float, requires_grad=True, device=device)

    loss = wp.zeros(shape=(1,), dtype=float, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(nested_loss_kernel, dim=1, inputs=(wrapper, loss), device=device)

    # forward result: the single 1.0 element was accumulated into loss
    assert_np_equal(loss.numpy(), [1.0])

    # backward: d(loss)/d(arr) is 1
    tape.backward(loss)
    assert_np_equal(wrapper.nested.arr.grad.numpy(), [1.0])

    # zero() must also reach the array inside the nested struct
    tape.zero()
    assert_np_equal(wrapper.nested.arr.grad.numpy(), [0.0])
195
+
196
+
197
# Records 16 identical `dot_product` launches inside a named scope and checks
# the GraphViz text produced by tape.visualize(simplify_graph=True): the
# repeated launches are reported once with a repeat count, and the scope
# label is present.
#
# The checks use the TestCase assertion methods on `test` rather than bare
# `assert` statements, so they still run under `python -O` and report the
# offending text on failure.
def test_tape_visualize(test, device):
    dim = 8
    tape = wp.Tape()

    # record onto tape
    with tape:
        # input data
        x = wp.array(np.ones(dim) * 16.0, dtype=wp.float32, device=device, requires_grad=True)
        y = wp.array(np.ones(dim) * 32.0, dtype=wp.float32, device=device, requires_grad=True)
        z = wp.zeros(n=1, dtype=wp.float32, device=device, requires_grad=True)

        # group the launches under a labelled scope
        tape.record_scope_begin("my loop")
        for _ in range(16):
            wp.launch(kernel=dot_product, dim=dim, inputs=[x, y], outputs=[z], device=device)
        tape.record_scope_end()

    # generate GraphViz diagram code
    dot_code = tape.visualize(simplify_graph=True)

    test.assertIn("repeated 16x", dot_code)
    test.assertIn("my loop", dot_code)
    # the 16 launches must appear as a single kernel mention
    test.assertEqual(dot_code.count("dot_product"), 1)
219
+
220
+
221
+ devices = get_test_devices()
222
+
223
+
224
# Test case that the device-parameterized tape tests are attached to; it also
# holds the one test that needs no device.
class TestTape(unittest.TestCase):
    # Entering a second tape while one is already active must raise.
    def test_tape_no_nested_tapes(self):
        with self.assertRaises(RuntimeError), wp.Tape(), wp.Tape():
            pass
230
+
231
+
232
# Attach every device-parameterized test to TestTape, in the same order as
# before; each test is registered under its function's own name.
for _device_test in (
    test_tape_mul_constant,
    test_tape_mul_variable,
    test_tape_dot_product,
    test_tape_zero_multiple_outputs,
    test_tape_nested_struct,
    test_tape_visualize,
):
    add_function_test(TestTape, _device_test.__name__, _device_test, devices=devices)
del _device_test


if __name__ == "__main__":
    # start from an empty kernel cache when the file is run directly
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,93 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import tempfile
18
+ import unittest
19
+ from importlib import util
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import *
23
+
24
+ CODE = """# -*- coding: utf-8 -*-
25
+
26
+ import warp as wp
27
+
28
+ @wp.struct
29
+ class Data:
30
+ x: wp.array(dtype=int)
31
+
32
+ @wp.func
33
+ def increment(x: int):
34
+ # This shouldn't be picked up.
35
+ return x + 123
36
+
37
+ @wp.func
38
+ def increment(x: int):
39
+ return x + 1
40
+
41
+ @wp.kernel
42
+ def compute(data: Data):
43
+ data.x[0] = increment(data.x[0])
44
+ """
45
+
46
+
47
def load_code_as_module(code, name):
    """Import a string of Python source as a throw-away module.

    The source is written to a temporary ``.py`` file, loaded from that file
    under the module name ``name``, and the file is deleted again before
    returning, whether or not the import succeeded.

    Args:
        code: Python source text to execute as the module body.
        name: Name given to the module spec (becomes ``module.__name__``).

    Returns:
        The executed module object. Its ``__file__`` points at the temporary
        path, which no longer exists once this function returns.

    Raises:
        Whatever executing ``code`` raises (e.g. ``SyntaxError``).
    """
    fd, file_path = tempfile.mkstemp(suffix=".py")

    try:
        # Python reads source files as UTF-8 unless told otherwise, so write
        # UTF-8 explicitly instead of relying on the locale's default encoding.
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(code)

        spec = util.spec_from_file_location(name, file_path)
        module = util.module_from_spec(spec)
        spec.loader.exec_module(module)
    finally:
        # the file only has to exist while the module is being loaded
        os.remove(file_path)

    return module
61
+
62
+
63
# Loads the CODE snippet twice as a module with an empty name and checks that
# the resulting Warp module holds exactly one struct and one function (the
# later `increment` definition), that module options round-trip, and that the
# kernel increments 123 to 124.
def test_transient_module(test, device):
    transient = load_code_as_module(CODE, "")
    # A second load of the same code must work as well.
    transient = load_code_as_module(CODE, "")

    warp_module = transient.compute.module
    assert len(warp_module.structs) == 1
    assert len(warp_module.functions) == 1

    payload = transient.Data()
    payload.x = wp.array([123], dtype=int, device=device)

    # options set through the Python module must be visible both through the
    # getter and on the kernel's Warp module
    wp.set_module_options({"foo": "bar"}, module=transient)
    assert wp.get_module_options(module=transient).get("foo") == "bar"
    assert transient.compute.module.options.get("foo") == "bar"

    wp.launch(transient.compute, dim=1, inputs=[payload], device=device)

    # 123 + 1: the second `increment` definition is the one in effect
    expected = np.array([124])
    assert_np_equal(payload.x.numpy(), expected)
80
+
81
+
82
+ devices = get_test_devices()
83
+
84
+
85
# Empty test case; its test methods are attached below via add_function_test.
class TestTransientModule(unittest.TestCase):
    pass
87
+
88
+
89
# Attach the device-parameterized test to the test case.
add_function_test(
    TestTransientModule,
    "test_transient_module",
    test_transient_module,
    devices=devices,
)

if __name__ == "__main__":
    # start from an empty kernel cache when the file is run directly
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,192 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ from warp.tests.unittest_utils import *
19
+
20
+
21
# Returns, as a wp.vec3, the barycentric weights (for a, b and c, in that order)
# of the point of triangle (a, b, c) closest to p. The branches test in turn
# whether that point is vertex a, vertex b, on edge ab, vertex c, on edge ac or
# on edge bc; if none applies, the general weights are returned.
@wp.func
def triangle_closest_point_barycentric(a: wp.vec3, b: wp.vec3, c: wp.vec3, p: wp.vec3):
    ab = b - a
    ac = c - a
    ap = p - a

    # Dot products of ap with the two edges leaving a.
    d1 = wp.dot(ab, ap)
    d2 = wp.dot(ac, ap)

    # Vertex a: all weight on a.
    if d1 <= 0.0 and d2 <= 0.0:
        return wp.vec3(1.0, 0.0, 0.0)

    bp = p - b
    d3 = wp.dot(ab, bp)
    d4 = wp.dot(ac, bp)

    # Vertex b: all weight on b.
    if d3 >= 0.0 and d4 <= d3:
        return wp.vec3(0.0, 1.0, 0.0)

    # Edge ab: weight on c is zero. v is computed before the test but only
    # returned when the test passes.
    vc = d1 * d4 - d3 * d2
    v = d1 / (d1 - d3)
    if vc <= 0.0 and d1 >= 0.0 and d3 <= 0.0:
        return wp.vec3(1.0 - v, v, 0.0)

    cp = p - c
    d5 = wp.dot(ab, cp)
    d6 = wp.dot(ac, cp)

    # Vertex c: all weight on c.
    if d6 >= 0.0 and d5 <= d6:
        return wp.vec3(0.0, 0.0, 1.0)

    # Edge ac: weight on b is zero.
    vb = d5 * d2 - d1 * d6
    w = d2 / (d2 - d6)
    if vb <= 0.0 and d2 >= 0.0 and d6 <= 0.0:
        return wp.vec3(1.0 - w, 0.0, w)

    # Edge bc: weight on a is zero.
    va = d3 * d6 - d5 * d4
    w = (d4 - d3) / ((d4 - d3) + (d5 - d6))
    if va <= 0.0 and (d4 - d3) >= 0.0 and (d5 - d6) >= 0.0:
        return wp.vec3(0.0, 1.0 - w, w)

    # General case: normalize va, vb, vc so that the three weights sum to one.
    denom = 1.0 / (va + vb + vc)
    v = vb * denom
    w = vc * denom

    return wp.vec3(1.0 - v - w, v, w)
67
+
68
+
69
# Checks that p lies in the region whose closest triangle point is on edge a-b
# (c is the remaining vertex). Every comparison is relaxed by the tolerance eps.
@wp.func
def check_edge_feasible_region(p: wp.vec3, a: wp.vec3, b: wp.vec3, c: wp.vec3, eps: float):
    edge = b - a

    # p must not fall before a along the edge direction...
    if wp.dot(p - a, edge) < -eps:
        return False

    # ...nor beyond b.
    if wp.dot(p - b, edge) > eps:
        return False

    # Reject edges whose squared length is below the tolerance.
    edge_len_sq = wp.dot(edge, edge)
    if edge_len_sq < eps:
        return False

    # Foot of the perpendicular dropped from c onto the line through a and b.
    foot = a + (wp.dot(edge, c - a) / edge_len_sq) * edge

    # Seen from that foot, p must not point towards c.
    if wp.dot(c - foot, p - foot) > eps:
        return False

    return True
94
+
95
+
96
# Checks that p lies in the region whose closest triangle point is vertex a
# (b and c are the other two vertices). p is in that region when the vector
# from a to p has a non-negative component along both a - b and a - c; each
# comparison is relaxed by the tolerance eps.
@wp.func
def check_vertex_feasible_region(p: wp.vec3, a: wp.vec3, b: wp.vec3, c: wp.vec3, eps: float):
    ap = p - a
    ba = a - b
    ca = a - c

    if wp.dot(ap, ba) < -eps:
        return False

    # Both half-space tests must be measured from vertex a (ap = p - a), not
    # from the origin, otherwise the result depends on where the triangle sits.
    if wp.dot(ap, ca) < -eps:
        return False

    return True
110
+
111
+
112
# Kernel that checks triangle_closest_point_barycentric() against 1000 random
# query points. tri holds the three triangle vertices; passed[0] is set to
# False (and the kernel returns) as soon as one check fails, and is left
# untouched otherwise.
@wp.kernel
def test_triangle_closest_point_kernel(tri: wp.array(dtype=wp.vec3), passed: wp.array(dtype=wp.bool)):
    # Random state built from the fixed seed 123 and offset 0.
    state = wp.uint32(wp.rand_init(wp.int32(123), wp.int32(0)))
    eps = 1e-5

    a = tri[0]
    b = tri[1]
    c = tri[2]

    for _i in range(1000):
        # Draw a random vector, retrying while its length is below eps so the
        # normalization below does not divide by (almost) zero.
        l = wp.float32(0.0)
        while l < eps:
            p = wp.vec3(wp.randn(state), wp.randn(state), wp.randn(state))
            l = wp.length(p)

        # project to a sphere with r=2
        p = 2.0 * p / l

        bary = triangle_closest_point_barycentric(tri[0], tri[1], tri[2], p)

        # For each vertex index dim, v3 is that vertex and v1, v2 are the two
        # following vertices in cyclic order.
        for dim in range(3):
            v1_index = (dim + 1) % 3
            v2_index = (dim + 2) % 3
            v1 = tri[v1_index]
            v2 = tri[v2_index]
            v3 = tri[dim]

            # on edge
            if bary[dim] == 0.0 and bary[v1_index] != 0.0 and bary[v2_index] != 0.0:
                if not check_edge_feasible_region(p, v1, v2, v3, eps):
                    passed[0] = False
                    return

                # p-closest_p must be perpendicular to v1-v2
                closest_p = a * bary[0] + b * bary[1] + c * bary[2]
                e = v1 - v2
                err = wp.dot(e, closest_p - p)
                if wp.abs(err) > eps:
                    passed[0] = False
                    return

            # Closest point is vertex v3 (both other weights are zero).
            if bary[v1_index] == 0.0 and bary[v2_index] == 0.0:
                if not check_vertex_feasible_region(p, v3, v1, v2, eps):
                    passed[0] = False
                    return

            # All weights non-zero: closest_p - p must be perpendicular to two
            # triangle edges.
            if bary[dim] != 0.0 and bary[v1_index] != 0.0 and bary[v2_index] != 0.0:
                closest_p = a * bary[0] + b * bary[1] + c * bary[2]
                e1 = v1 - v2
                e2 = v1 - v3
                if wp.abs(wp.dot(e1, closest_p - p)) > eps or wp.abs(wp.dot(e2, closest_p - p)) > eps:
                    passed[0] = False
                    return
165
+
166
+
167
def test_triangle_closest_point(test, device):
    """Run the closest-point kernel on a fixed triangle and require it to pass.

    Args:
        test: The ``unittest.TestCase`` instance running this test.
        device: Device on which the arrays are allocated and the kernel launched.
    """
    # Single-element flag that the kernel clears when a check fails.
    ok_flag = wp.array([True], dtype=wp.bool, device=device)

    vertices = [
        wp.vec3(1.0, 0.0, 0.0),
        wp.vec3(0.0, 0.0, 0.0),
        wp.vec3(0.0, 1.0, 0.0),
    ]
    tri = wp.array(vertices, dtype=wp.vec3, device=device)

    wp.launch(test_triangle_closest_point_kernel, dim=1, inputs=[tri, ok_flag], device=device)

    test.assertTrue(ok_flag.numpy().all())
179
+
180
+
181
# Devices passed to add_function_test() below.
devices = get_test_devices()


# Intentionally empty test case: the test function is registered on it through
# add_function_test() below.
class TestTriangleClosestPoint(unittest.TestCase):
    pass


add_function_test(TestTriangleClosestPoint, "test_triangle_closest_point", test_triangle_closest_point, devices=devices)

if __name__ == "__main__":
    # When run as a script, clear Warp's kernel cache before running the tests.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)