warp-lang 1.7.0 (warp_lang-1.7.0-py3-none-manylinux_2_28_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; the original page links to more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,588 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import contextlib
17
+ import inspect
18
+ import io
19
+ import unittest
20
+
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
def test_array_scan(test, device):
    """Check inclusive and exclusive ``wp.utils.array_scan`` against ``np.cumsum``.

    Runs once with ``int`` and once with ``float`` inputs of 100000 random
    values drawn with a fixed seed. Integer scans must match exactly. Float
    scans may differ from the host-side NumPy reference by a relative error of
    at most 1e-3.

    Args:
        test: The ``unittest.TestCase`` used for assertions and sub-tests.
        device: Warp device on which the arrays are allocated and scanned.
    """
    rng = np.random.default_rng(123)

    for dtype in (int, float):
        # subTest labels any failure with the dtype that produced it and lets
        # the remaining dtype still run.
        with test.subTest(dtype=dtype):
            if dtype == int:
                values = rng.integers(-1e6, high=1e6, size=100000, dtype=dtype)
            else:
                values = rng.uniform(low=-1e6, high=1e6, size=100000)

            expected = np.cumsum(values)

            values = wp.array(values, dtype=dtype, device=device)
            result_inc = wp.zeros_like(values)
            result_exc = wp.zeros_like(values)

            wp.utils.array_scan(values, result_inc, True)
            wp.utils.array_scan(values, result_exc, False)

            tolerance = 0 if dtype == int else 1e-3

            result_inc = result_inc.numpy().squeeze()
            result_exc = result_exc.numpy().squeeze()

            # Normalize by the magnitude of the final scan value so the float
            # tolerance acts as a relative error bound.
            error_inc = np.max(np.abs(result_inc - expected)) / abs(expected[-1])
            # The exclusive scan is the inclusive one shifted right by one with
            # a leading zero: result_exc[1:] must match expected[:-1] and
            # result_exc[0] must be 0. Its last entry corresponds to expected[-2].
            error_exc = max(np.max(np.abs(result_exc[1:] - expected[:-1])), abs(result_exc[0])) / abs(expected[-2])

            # assertLessEqual reports both operands on failure; assertTrue would not.
            test.assertLessEqual(error_inc, tolerance)
            test.assertLessEqual(error_exc, tolerance)
51
+
52
+
53
def test_array_scan_empty(test, device):
    """Scanning zero-length int arrays must complete without raising."""
    empty_in = wp.array((), dtype=int, device=device)
    empty_out = wp.array((), dtype=int, device=device)
    # There is no result to compare; the call succeeding is the check.
    wp.utils.array_scan(empty_in, empty_out)
57
+
58
+
59
def test_array_scan_error_sizes_mismatch(test, device):
    """array_scan must raise when the input and output storage sizes differ (123 vs 234)."""
    src = wp.zeros(123, dtype=int, device=device)
    dst = wp.zeros(234, dtype=int, device=device)
    expected_msg = r"Array storage sizes do not match$"
    with test.assertRaisesRegex(RuntimeError, expected_msg):
        wp.utils.array_scan(src, dst, True)
67
+
68
+
69
def test_array_scan_error_dtypes_mismatch(test, device):
    """array_scan must raise when the input is int but the output is float."""
    src = wp.zeros(123, dtype=int, device=device)
    dst = wp.zeros(123, dtype=float, device=device)
    expected_msg = r"Array data types do not match$"
    with test.assertRaisesRegex(RuntimeError, expected_msg):
        wp.utils.array_scan(src, dst, True)
77
+
78
+
79
def test_array_scan_error_unsupported_dtype(test, device):
    """array_scan must reject ``wp.vec3`` arrays as an unsupported data type."""
    vec_dtype = wp.vec3
    src = wp.zeros(123, dtype=vec_dtype, device=device)
    dst = wp.zeros(123, dtype=vec_dtype, device=device)
    expected_msg = r"Unsupported data type$"
    with test.assertRaisesRegex(RuntimeError, expected_msg):
        wp.utils.array_scan(src, dst, True)
87
+
88
+
89
def test_radix_sort_pairs(test, device):
    """Sort the first 8 key/value pairs for int, float32 and int64 keys.

    The arrays hold 16 entries while only ``count=8`` are sorted. The zero
    padding gives the arrays the ``2*count`` elements of storage that
    ``radix_sort_pairs`` requires. Only the first 8 entries are checked.
    """
    key_types = (int, wp.float32, wp.int64)

    for key_type in key_types:
        # Label each iteration so a failure names the offending key type.
        with test.subTest(key_type=key_type):
            keys = wp.array((7, 2, 8, 4, 1, 6, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0), dtype=key_type, device=device)
            values = wp.array((1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0), dtype=int, device=device)
            wp.utils.radix_sort_pairs(keys, values, 8)
            assert_np_equal(keys.numpy()[:8], np.array((1, 2, 3, 4, 5, 6, 7, 8)))
            # Values follow their keys: e.g. key 1 carried value 5.
            assert_np_equal(values.numpy()[:8], np.array((5, 2, 8, 4, 7, 6, 1, 3)))
98
+
99
+
100
def test_segmented_sort_pairs(test, device):
    """Sort two independent segments, [0, 4) and [4, 8), for int and float32 keys.

    Each segment is sorted on its own, so keys from the first segment never
    move into the second. The arrays are padded to 16 entries (``2*count``).
    """
    key_types = (int, wp.float32)

    for key_type in key_types:
        # Label each iteration so a failure names the offending key type.
        with test.subTest(key_type=key_type):
            keys = wp.array((7, 2, 8, 4, 1, 6, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0), dtype=key_type, device=device)
            values = wp.array((1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0), dtype=int, device=device)
            wp.utils.segmented_sort_pairs(
                keys,
                values,
                8,
                wp.array((0, 4), dtype=int, device=device),  # segment starts
                wp.array((4, 8), dtype=int, device=device),  # segment ends (exclusive)
            )
            assert_np_equal(keys.numpy()[:8], np.array((2, 4, 7, 8, 1, 3, 5, 6)))
            assert_np_equal(values.numpy()[:8], np.array((2, 4, 1, 3, 5, 8, 7, 6)))
115
+
116
+
117
def test_radix_sort_pairs_empty(test, device):
    """radix_sort_pairs must accept zero-length arrays with ``count=0`` for every key type."""
    key_types = (int, wp.float32, wp.int64)

    for key_type in key_types:
        # Label each iteration so a failure names the offending key type.
        with test.subTest(key_type=key_type):
            keys = wp.array((), dtype=key_type, device=device)
            values = wp.array((), dtype=int, device=device)
            # Smoke test: the call succeeding is the check.
            wp.utils.radix_sort_pairs(keys, values, 0)
124
+
125
+
126
def test_segmented_sort_pairs_empty(test, device):
    """segmented_sort_pairs must accept empty data and segment arrays with ``count=0``."""
    key_types = (int, wp.float32)

    for key_type in key_types:
        # Label each iteration so a failure names the offending key type.
        with test.subTest(key_type=key_type):
            keys = wp.array((), dtype=key_type, device=device)
            values = wp.array((), dtype=int, device=device)
            # Smoke test: the call succeeding is the check.
            wp.utils.segmented_sort_pairs(
                keys, values, 0, wp.array((), dtype=int, device=device), wp.array((), dtype=int, device=device)
            )
135
+
136
+
137
def test_radix_sort_pairs_error_insufficient_storage(test, device):
    """radix_sort_pairs must raise when the arrays hold fewer than ``2*count`` elements.

    The arrays hold 3 elements and ``count`` is 3, so 6 would be required.
    """
    key_types = (int, wp.float32, wp.int64)

    for key_type in key_types:
        # Label each iteration so a failure names the offending key type.
        with test.subTest(key_type=key_type):
            keys = wp.array((1, 2, 3), dtype=key_type, device=device)
            values = wp.array((1, 2, 3), dtype=int, device=device)
            with test.assertRaisesRegex(
                RuntimeError,
                r"Array storage must be large enough to contain 2\*count elements$",
            ):
                wp.utils.radix_sort_pairs(keys, values, 3)
148
+
149
+
150
def test_segmented_sort_pairs_error_insufficient_storage(test, device):
    """segmented_sort_pairs must raise when the arrays hold fewer than ``2*count`` elements.

    The arrays hold 3 elements and ``count`` is 3 (a single segment [0, 3)),
    so 6 would be required.
    """
    key_types = (int, wp.float32)

    for key_type in key_types:
        # Label each iteration so a failure names the offending key type.
        with test.subTest(key_type=key_type):
            keys = wp.array((1, 2, 3), dtype=key_type, device=device)
            values = wp.array((1, 2, 3), dtype=int, device=device)
            with test.assertRaisesRegex(
                RuntimeError,
                r"Array storage must be large enough to contain 2\*count elements$",
            ):
                wp.utils.segmented_sort_pairs(
                    keys,
                    values,
                    3,
                    wp.array((0,), dtype=int, device=device),
                    wp.array((3,), dtype=int, device=device),
                )
167
+
168
+
169
def test_radix_sort_pairs_error_unsupported_dtype(test, device):
    """radix_sort_pairs must raise "Unsupported data type" when ``values`` is a float array.

    The key types are the same ones test_radix_sort_pairs sorts successfully
    with int values, so the error is presumably triggered by the float
    ``values`` array rather than by the keys.
    """
    key_types = (int, wp.float32, wp.int64)

    for key_type in key_types:
        # Label each iteration so a failure names the offending key type.
        with test.subTest(key_type=key_type):
            keys = wp.array((1.0, 2.0, 3.0), dtype=key_type, device=device)
            values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device)
            with test.assertRaisesRegex(
                RuntimeError,
                r"Unsupported data type$",
            ):
                wp.utils.radix_sort_pairs(keys, values, 1)
180
+
181
+
182
def test_segmented_sort_pairs_error_unsupported_dtype(test, device):
    """segmented_sort_pairs must raise "Unsupported data type" when ``values`` is a float array.

    The key types match those test_segmented_sort_pairs sorts successfully
    with int values, so the error is presumably caused by the float values.
    """
    key_types = (int, wp.float32)

    for key_type in key_types:
        # Label each iteration so a failure names the offending key type.
        with test.subTest(key_type=key_type):
            keys = wp.array((1.0, 2.0, 3.0), dtype=key_type, device=device)
            values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device)
            with test.assertRaisesRegex(
                RuntimeError,
                r"Unsupported data type$",
            ):
                wp.utils.segmented_sort_pairs(
                    keys,
                    values,
                    1,
                    wp.array((0,), dtype=int, device=device),
                    wp.array((3,), dtype=int, device=device),
                )
199
+
200
+
201
def test_array_sum(test, device):
    """Sum [1, 2, 3] in float32 and float64, both as a return value and via ``out``."""
    expected_total = 6.0

    for dtype in (wp.float32, wp.float64):
        with test.subTest(dtype=dtype):
            # Form 1: the sum is returned to the caller.
            src = wp.array((1.0, 2.0, 3.0), dtype=dtype, device=device)
            test.assertEqual(wp.utils.array_sum(src), expected_total)

            # Form 2: the sum is written into a caller-provided 1-element array.
            src = wp.array((1.0, 2.0, 3.0), dtype=dtype, device=device)
            out_arr = wp.empty(shape=(1,), dtype=dtype, device=device)
            wp.utils.array_sum(src, out=out_arr)
            test.assertEqual(out_arr.numpy()[0], expected_total)
211
+
212
+
213
def test_array_sum_error_out_dtype_mismatch(test, device):
    """array_sum must reject a float64 ``out`` array for float32 input."""
    src = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device=device)
    out_arr = wp.empty(shape=(1,), dtype=wp.float64, device=device)
    expected_msg = r"out array should have type float32$"
    with test.assertRaisesRegex(RuntimeError, expected_msg):
        wp.utils.array_sum(src, out=out_arr)
221
+
222
+
223
def test_array_sum_error_out_shape_mismatch(test, device):
    """array_sum must reject an ``out`` array of shape (2,); it expects (1,)."""
    src = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device=device)
    out_arr = wp.empty(shape=(2,), dtype=wp.float32, device=device)
    expected_msg = r"out array should have shape \(1,\)$"
    with test.assertRaisesRegex(RuntimeError, expected_msg):
        wp.utils.array_sum(src, out=out_arr)
231
+
232
+
233
def test_array_sum_error_unsupported_dtype(test, device):
    """array_sum must reject an ``int`` array as an unsupported data type."""
    int_values = wp.array((1, 2, 3), dtype=int, device=device)
    expected_msg = r"Unsupported data type$"
    with test.assertRaisesRegex(RuntimeError, expected_msg):
        wp.utils.array_sum(int_values)
240
+
241
+
242
def test_array_inner(test, device):
    """Inner product of [1, 2, 3] with itself (= 14) in float32 and float64.

    Checks both the returned value and the value written into a
    caller-provided 1-element ``out`` array.
    """
    for dtype in (wp.float32, wp.float64):
        # Like test_array_sum, label each iteration so a failure names its dtype.
        with test.subTest(dtype=dtype):
            a = wp.array((1.0, 2.0, 3.0), dtype=dtype, device=device)
            b = wp.array((1.0, 2.0, 3.0), dtype=dtype, device=device)
            test.assertEqual(wp.utils.array_inner(a, b), 14.0)

            a = wp.array((1.0, 2.0, 3.0), dtype=dtype, device=device)
            b = wp.array((1.0, 2.0, 3.0), dtype=dtype, device=device)
            result = wp.empty(shape=(1,), dtype=dtype, device=device)
            wp.utils.array_inner(a, b, out=result)
            test.assertEqual(result.numpy()[0], 14.0)
253
+
254
+
255
def test_array_inner_error_sizes_mismatch(test, device):
    """array_inner must raise for operands of different lengths (2 vs 3)."""
    lhs = wp.array((1.0, 2.0), dtype=wp.float32, device=device)
    rhs = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device=device)
    expected_msg = r"Array storage sizes do not match$"
    with test.assertRaisesRegex(RuntimeError, expected_msg):
        wp.utils.array_inner(lhs, rhs)
263
+
264
+
265
def test_array_inner_error_dtypes_mismatch(test, device):
    """array_inner must raise when one operand is float32 and the other float64."""
    lhs = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device=device)
    rhs = wp.array((1.0, 2.0, 3.0), dtype=wp.float64, device=device)
    expected_msg = r"Array data types do not match$"
    with test.assertRaisesRegex(RuntimeError, expected_msg):
        wp.utils.array_inner(lhs, rhs)
273
+
274
+
275
def test_array_inner_error_out_dtype_mismatch(test, device):
    """array_inner must reject a float64 output array (passed positionally) for float32 operands."""
    lhs = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device=device)
    rhs = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device=device)
    out_arr = wp.empty(shape=(1,), dtype=wp.float64, device=device)
    expected_msg = r"out array should have type float32$"
    with test.assertRaisesRegex(RuntimeError, expected_msg):
        wp.utils.array_inner(lhs, rhs, out_arr)
284
+
285
+
286
def test_array_inner_error_out_shape_mismatch(test, device):
    """array_inner must reject an output array of shape (2,); it expects (1,)."""
    lhs = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device=device)
    rhs = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device=device)
    out_arr = wp.empty(shape=(2,), dtype=wp.float32, device=device)
    expected_msg = r"out array should have shape \(1,\)$"
    with test.assertRaisesRegex(RuntimeError, expected_msg):
        wp.utils.array_inner(lhs, rhs, out_arr)
295
+
296
+
297
def test_array_inner_error_unsupported_dtype(test, device):
    """array_inner must reject ``int`` operands as an unsupported data type."""
    lhs = wp.array((1, 2, 3), dtype=int, device=device)
    rhs = wp.array((1, 2, 3), dtype=int, device=device)
    expected_msg = r"Unsupported data type$"
    with test.assertRaisesRegex(RuntimeError, expected_msg):
        wp.utils.array_inner(lhs, rhs)
305
+
306
+
307
def test_array_cast(test, device):
    """Exercise ``wp.utils.array_cast`` on four source/destination combinations.

    Cases:
      * int (3,) -> float (3,)
      * int (4,) -> float (2, 2), i.e. same element count with a different shape
      * vec2 (2,) -> float (2,) with ``count=1``: only the first vec2's two
        components are cast, giving (1.0, 2.0)
      * int (2, 2) -> int (2, 2)
    """

    def check_cast(src, dst, expected_dtype, expected_shape, expected_values, **cast_kwargs):
        # Cast src into dst, then verify dst's dtype, shape, and contents in that order.
        wp.utils.array_cast(src, dst, **cast_kwargs)
        test.assertEqual(dst.dtype, expected_dtype)
        test.assertEqual(dst.shape, expected_shape)
        assert_np_equal(dst.numpy(), expected_values)

    check_cast(
        wp.array((1, 2, 3), dtype=int, device=device),
        wp.empty(3, dtype=float, device=device),
        wp.float32,
        (3,),
        np.array((1.0, 2.0, 3.0), dtype=float),
    )

    check_cast(
        wp.array((1, 2, 3, 4), dtype=int, device=device),
        wp.empty((2, 2), dtype=float, device=device),
        wp.float32,
        (2, 2),
        np.array(((1.0, 2.0), (3.0, 4.0)), dtype=float),
    )

    check_cast(
        wp.array(((1, 2), (3, 4)), dtype=wp.vec2, device=device),
        wp.zeros(2, dtype=float, device=device),
        wp.float32,
        (2,),
        np.array((1.0, 2.0), dtype=float),
        count=1,
    )

    check_cast(
        wp.array(((1, 2), (3, 4)), dtype=int, device=device),
        wp.zeros((2, 2), dtype=int, device=device),
        wp.int32,
        (2, 2),
        np.array(((1, 2), (3, 4)), dtype=int),
    )
335
+
336
+
337
def test_array_cast_error_unsupported_partial_cast(test, device):
    """``wp.utils.array_cast`` must reject a ``count`` when the arrays are multi-dimensional."""
    src = wp.array(((1, 2), (3, 4)), dtype=int, device=device)
    dst = wp.zeros((2, 2), dtype=float, device=device)

    pattern = r"Partial cast is not supported for arrays with more than one dimension$"
    with test.assertRaisesRegex(RuntimeError, pattern):
        wp.utils.array_cast(src, dst, count=1)
345
+
346
+
347
# Devices passed as ``devices=devices`` to the ``add_function_test`` registrations
# on TestUtils in this module.
devices = get_test_devices()
348
+
349
+
350
class TestUtils(unittest.TestCase):
    """Tests for ``wp.utils`` helpers and ``wp.ScopedTimer`` that run once rather than per device.

    Device-parametrized tests are attached to this class at module level via
    ``add_function_test``.
    """

    def test_warn(self):
        # ``wp.utils.warn`` output is captured from stdout and compared verbatim.

        # Multiple warnings get printed out each time.
        with contextlib.redirect_stdout(io.StringIO()) as f:
            wp.utils.warn("hello, world!")
            wp.utils.warn("hello, world!")

        expected = "Warp UserWarning: hello, world!\nWarp UserWarning: hello, world!\n"

        self.assertEqual(f.getvalue(), expected)

        # Test verbose warnings
        saved_verbosity = wp.config.verbose_warnings
        try:
            wp.config.verbose_warnings = True
            # The two warn() calls must stay on the two lines immediately after the
            # getframeinfo() call: the expected output below locates them as
            # frame_info.lineno + 1 and frame_info.lineno + 2.
            with contextlib.redirect_stdout(io.StringIO()) as f:
                frame_info = inspect.getframeinfo(inspect.currentframe())
                wp.utils.warn("hello, world!")
                wp.utils.warn("hello, world!")

            # NOTE(review): each verbose warning also echoes the calling source line;
            # the leading whitespace of that echoed line must match exactly what
            # verbose warnings print for the calls above. Confirm this if the block
            # is ever re-indented.
            expected = (
                f"Warp UserWarning: hello, world! ({frame_info.filename}:{frame_info.lineno + 1})\n"
                ' wp.utils.warn("hello, world!")\n'
                f"Warp UserWarning: hello, world! ({frame_info.filename}:{frame_info.lineno + 2})\n"
                ' wp.utils.warn("hello, world!")\n'
            )

            self.assertEqual(f.getvalue(), expected)

        finally:
            # make sure to restore warning verbosity
            wp.config.verbose_warnings = saved_verbosity

        # Multiple similar deprecation warnings get printed out only once.
        with contextlib.redirect_stdout(io.StringIO()) as f:
            wp.utils.warn("hello, world!", category=DeprecationWarning)
            wp.utils.warn("hello, world!", category=DeprecationWarning)

        expected = "Warp DeprecationWarning: hello, world!\n"

        self.assertEqual(f.getvalue(), expected)

        # Multiple different deprecation warnings get printed out each time.
        with contextlib.redirect_stdout(io.StringIO()) as f:
            wp.utils.warn("foo", category=DeprecationWarning)
            wp.utils.warn("bar", category=DeprecationWarning)

        expected = "Warp DeprecationWarning: foo\nWarp DeprecationWarning: bar\n"

        self.assertEqual(f.getvalue(), expected)

    def test_transform_expand(self):
        # A flat 7-tuple is split into position (first 3 values) and quaternion (last 4 values).
        t = (1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0)
        self.assertEqual(
            wp.utils.transform_expand(t),
            wp.transformf(p=(1.0, 2.0, 3.0), q=(4.0, 3.0, 2.0, 1.0)),
        )

    # The device-mismatch tests below pair a CPU array with a "cuda:0" array,
    # so they are skipped when CUDA is unavailable.

    @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
    def test_array_scan_error_devices_mismatch(self):
        values = wp.zeros(123, dtype=int, device="cpu")
        result = wp.zeros_like(values, device="cuda:0")
        with self.assertRaisesRegex(
            RuntimeError,
            r"Array storage devices do not match$",
        ):
            wp.utils.array_scan(values, result, True)

    @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
    def test_radix_sort_pairs_error_devices_mismatch(self):
        keys = wp.array((1, 2, 3), dtype=int, device="cpu")
        values = wp.array((1, 2, 3), dtype=int, device="cuda:0")
        with self.assertRaisesRegex(
            RuntimeError,
            r"Array storage devices do not match$",
        ):
            wp.utils.radix_sort_pairs(keys, values, 1)

    @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
    def test_array_inner_error_out_device_mismatch(self):
        # Inputs agree (CPU); only the ``out`` array lives on another device.
        a = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device="cpu")
        b = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device="cpu")
        result = wp.empty(shape=(1,), dtype=wp.float32, device="cuda:0")
        with self.assertRaisesRegex(
            RuntimeError,
            r"out storage device should match values array$",
        ):
            wp.utils.array_inner(a, b, result)

    @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
    def test_array_sum_error_out_device_mismatch(self):
        values = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device="cpu")
        result = wp.empty(shape=(1,), dtype=wp.float32, device="cuda:0")
        with self.assertRaisesRegex(
            RuntimeError,
            r"out storage device should match values array$",
        ):
            wp.utils.array_sum(values, out=result)

    @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
    def test_array_inner_error_devices_mismatch(self):
        # Here the two input arrays themselves are on different devices.
        a = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device="cpu")
        b = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device="cuda:0")
        with self.assertRaisesRegex(
            RuntimeError,
            r"Array storage devices do not match$",
        ):
            wp.utils.array_inner(a, b)

    @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
    def test_array_cast_error_devices_mismatch(self):
        values = wp.array((1, 2, 3), dtype=int, device="cpu")
        result = wp.empty(3, dtype=float, device="cuda:0")
        with self.assertRaisesRegex(
            RuntimeError,
            r"Array storage devices do not match$",
        ):
            wp.utils.array_cast(values, result)

    def test_mesh_adjacency(self):
        # Two triangles that share edge (0, 3). Each expected entry is
        # (v0, v1, o0, o1, f0, f1): the edge's vertices, the vertex opposite the
        # edge in each adjacent face, and the adjacent face indices. Edges with
        # only one adjacent face carry -1 for the missing opposite vertex and face.
        triangles = (
            (0, 3, 1),
            (0, 2, 3),
        )
        adj = wp.utils.MeshAdjacency(triangles, len(triangles))
        expected_edges = {
            (0, 3): (0, 3, 1, 2, 0, 1),
            (1, 3): (3, 1, 0, -1, 0, -1),
            (0, 1): (1, 0, 3, -1, 0, -1),
            (0, 2): (0, 2, 3, -1, 1, -1),
            (2, 3): (2, 3, 0, -1, 1, -1),
        }
        edges = {k: (e.v0, e.v1, e.o0, e.o1, e.f0, e.f1) for k, e in adj.edges.items()}
        self.assertDictEqual(edges, expected_edges)

    def test_mesh_adjacency_error_manifold(self):
        # The third triangle gives edge (0, 3) a third adjacent face. MeshAdjacency
        # reports this on stdout rather than raising, so the message is captured.
        triangles = (
            (0, 3, 1),
            (0, 2, 3),
            (3, 0, 1),
        )

        with contextlib.redirect_stdout(io.StringIO()) as f:
            wp.utils.MeshAdjacency(triangles, len(triangles))

        self.assertEqual(f.getvalue(), "Detected non-manifold edge\n")

    def test_scoped_timer(self):
        # Plain timer: a single "<name> took <t> ms" report is printed.
        with contextlib.redirect_stdout(io.StringIO()) as f:
            with wp.ScopedTimer("hello"):
                pass

        self.assertRegex(f.getvalue(), r"^hello took \d+\.\d+ ms$")

        # Detailed timer: profiler statistics precede the timing report.
        # NOTE(review): the first pattern is anchored with ``^``, so its leading
        # spaces must match the profiler header's indentation exactly, and the
        # "4 function calls" count depends on profiler internals. Confirm both.
        with contextlib.redirect_stdout(io.StringIO()) as f:
            with wp.ScopedTimer("hello", detailed=True):
                pass

        self.assertRegex(f.getvalue(), r"^ 4 function calls in \d+\.\d+ seconds")
        self.assertRegex(f.getvalue(), r"hello took \d+\.\d+ ms$")
510
+
511
+
512
# Device-parametrized function tests attached to TestUtils. Each test is
# registered under its own function name (``__name__``), so a registration can
# never run a different function than the one it is named after.
_DEVICE_FUNCTION_TESTS = (
    test_array_scan,
    test_array_scan_empty,
    test_array_scan_error_sizes_mismatch,
    test_array_scan_error_dtypes_mismatch,
    test_array_scan_error_unsupported_dtype,
    test_radix_sort_pairs,
    test_radix_sort_pairs_empty,
    test_radix_sort_pairs_error_insufficient_storage,
    test_radix_sort_pairs_error_unsupported_dtype,
    test_segmented_sort_pairs,
    test_segmented_sort_pairs_empty,
    test_segmented_sort_pairs_error_insufficient_storage,
    test_segmented_sort_pairs_error_unsupported_dtype,
    test_array_sum,
    test_array_sum_error_out_dtype_mismatch,
    test_array_sum_error_out_shape_mismatch,
    test_array_sum_error_unsupported_dtype,
    test_array_inner,
    test_array_inner_error_sizes_mismatch,
    test_array_inner_error_dtypes_mismatch,
    test_array_inner_error_out_dtype_mismatch,
    test_array_inner_error_out_shape_mismatch,
    test_array_inner_error_unsupported_dtype,
    test_array_cast,
    test_array_cast_error_unsupported_partial_cast,
)

for _test_fn in _DEVICE_FUNCTION_TESTS:
    add_function_test(TestUtils, _test_fn.__name__, _test_fn, devices=devices)
584
+
585
+
586
if __name__ == "__main__":
    # When run as a script, clear Warp's kernel cache first (presumably so kernels
    # are rebuilt for this run), then run all tests with verbose output.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)