warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,312 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import unittest
16
+
17
+ import numpy as np
18
+
19
+ import warp as wp
20
+ from warp.tests.unittest_utils import *
21
+
22
# Module-level cache of wp.Kernel objects keyed by "<function name>_<suffix>",
# so each (function, suffix) pair is wrapped in a wp.Kernel only once.
kernel_cache = {}


def getkernel(func, suffix=""):
    """Return the cached ``wp.Kernel`` built from ``func``, creating it on first use.

    Args:
        func: Python function to wrap as a Warp kernel.
        suffix: String joined to ``func.__name__`` with an underscore to form
            both the cache key and the kernel key (e.g. a dtype name, so that
            per-type variants of the same function get distinct kernels).

    Returns:
        The ``wp.Kernel`` stored under ``"<func.__name__>_<suffix>"``.
    """
    cache_key = "_".join((func.__name__, suffix))
    cached = kernel_cache.get(cache_key)
    if cached is None:
        cached = wp.Kernel(func=func, key=cache_key)
        kernel_cache[cache_key] = cached
    return cached
30
+
31
+
32
def test_atomic_cas(test, device, dtype, register_kernels=False):
    """Check a ``wp.atomic_cas`` spin lock on a 1D lock array.

    ``n`` threads each acquire the lock, increment a shared counter inside the
    critical section, and release the lock. If the lock provides mutual
    exclusion, the counter ends at exactly ``n``.

    Args:
        test: The ``unittest.TestCase`` instance running this test (unused here).
        device: Warp device used for the arrays and the launch.
        dtype: NumPy scalar type; mapped to the matching Warp scalar type.
        register_kernels: If True, build the cached kernel and return before
            launching it.
    """
    warp_type = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    n = 100
    counter = wp.array([0], dtype=warp_type, device=device)
    lock = wp.array([0], dtype=warp_type, device=device)

    @wp.func
    def spinlock_acquire_1d(lock: wp.array(dtype=warp_type)):
        # Try to acquire the lock by swapping 0 -> 1. Keep spinning while
        # atomic_cas reports the lock as already held (value 1). Compare against
        # warp_type(1) so both operands share the same type; this test also runs
        # with float and unsigned dtypes.
        while wp.atomic_cas(lock, 0, warp_type(0), warp_type(1)) == warp_type(1):
            pass

    @wp.func
    def spinlock_release_1d(lock: wp.array(dtype=warp_type)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, warp_type(0))

    @wp.func
    def volatile_read_1d(ptr: wp.array(dtype=warp_type), index: int):
        # Read ptr[index] through atomic_exch (swap in 0, then restore the value)
        # instead of a plain load. The caller holds the lock, so no other thread
        # observes the transient 0.
        value = wp.atomic_exch(ptr, index, warp_type(0))
        wp.atomic_exch(ptr, index, value)
        return value

    def test_spinlock_counter_1d(counter: wp.array(dtype=warp_type), lock: wp.array(dtype=warp_type)):
        # Try to acquire the lock
        spinlock_acquire_1d(lock)

        # Critical section - increment counter
        # counter[0] = counter[0] + 1  # This gives wrong results - counter should be marked as volatile

        # Work around since warp arrays cannot be marked as volatile
        value = volatile_read_1d(counter, 0)
        counter[0] = value + warp_type(1)

        # Release the lock
        spinlock_release_1d(lock)

    # Wrapped through getkernel (not @wp.kernel) so each dtype gets one cached kernel.
    kernel = getkernel(test_spinlock_counter_1d, suffix=dtype.__name__)

    if register_kernels:
        return

    wp.launch(kernel, dim=n, inputs=[counter, lock], device=device)

    # Verify counter reached n
    counter_np = counter.numpy()
    expected = np.array([n], dtype=dtype)

    if not np.array_equal(counter_np, expected):
        print(f"Counter mismatch: expected {expected}, got {counter_np}")

    assert_np_equal(counter_np, expected)
84
+
85
+
86
def test_atomic_cas_2d(test, device, dtype, register_kernels=False):
    """Check a ``wp.atomic_cas`` spin lock whose lock lives in a 2D array.

    ``n`` threads each acquire the lock at ``lock[0, 0]``, increment a shared
    1D counter inside the critical section, and release the lock. If the lock
    provides mutual exclusion, the counter ends at exactly ``n``.

    Args:
        test: The ``unittest.TestCase`` instance running this test (unused here).
        device: Warp device used for the arrays and the launch.
        dtype: NumPy scalar type; mapped to the matching Warp scalar type.
        register_kernels: If True, build the cached kernel and return before
            launching it.
    """
    warp_type = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    n = 100
    counter = wp.array([0], dtype=warp_type, device=device)
    lock = wp.zeros(shape=(1, 1), dtype=warp_type, device=device)

    @wp.func
    def spinlock_acquire_2d(lock: wp.array2d(dtype=warp_type)):
        # Try to acquire the lock by swapping 0 -> 1. Keep spinning while
        # atomic_cas reports the lock as already held (value 1). Compare against
        # warp_type(1) so both operands share the same type; this test also runs
        # with float and unsigned dtypes.
        while wp.atomic_cas(lock, 0, 0, warp_type(0), warp_type(1)) == warp_type(1):
            pass

    @wp.func
    def spinlock_release_2d(lock: wp.array2d(dtype=warp_type)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, 0, warp_type(0))

    @wp.func
    def volatile_read_2d(ptr: wp.array(dtype=warp_type), index: int):
        # Read ptr[index] of the 1D counter through atomic_exch (swap in 0, then
        # restore the value) instead of a plain load. The caller holds the lock,
        # so no other thread observes the transient 0.
        value = wp.atomic_exch(ptr, index, warp_type(0))
        wp.atomic_exch(ptr, index, value)
        return value

    def test_spinlock_counter_2d(counter: wp.array(dtype=warp_type), lock: wp.array2d(dtype=warp_type)):
        # Try to acquire the lock
        spinlock_acquire_2d(lock)

        # Critical section - increment counter
        # counter[0] = counter[0] + 1  # This gives wrong results - counter should be marked as volatile

        # Work around since warp arrays cannot be marked as volatile
        value = volatile_read_2d(counter, 0)
        counter[0] = value + warp_type(1)

        # Release the lock
        spinlock_release_2d(lock)

    # Wrapped through getkernel (not @wp.kernel) so each dtype gets one cached kernel.
    kernel = getkernel(test_spinlock_counter_2d, suffix=dtype.__name__)

    if register_kernels:
        return

    wp.launch(kernel, dim=n, inputs=[counter, lock], device=device)

    # Verify counter reached n
    counter_np = counter.numpy()
    expected = np.array([n], dtype=dtype)

    if not np.array_equal(counter_np, expected):
        print(f"Counter mismatch: expected {expected}, got {counter_np}")

    assert_np_equal(counter_np, expected)
138
+
139
+
140
def test_atomic_cas_3d(test, device, dtype, register_kernels=False):
    """Check a ``wp.atomic_cas`` spin lock whose lock lives in a 3D array.

    ``n`` threads each acquire the lock at ``lock[0, 0, 0]``, increment a
    shared 1D counter inside the critical section, and release the lock. If
    the lock provides mutual exclusion, the counter ends at exactly ``n``.

    Args:
        test: The ``unittest.TestCase`` instance running this test (unused here).
        device: Warp device used for the arrays and the launch.
        dtype: NumPy scalar type; mapped to the matching Warp scalar type.
        register_kernels: If True, build the cached kernel and return before
            launching it.
    """
    warp_type = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    n = 100
    counter = wp.array([0], dtype=warp_type, device=device)
    lock = wp.zeros(shape=(1, 1, 1), dtype=warp_type, device=device)

    @wp.func
    def spinlock_acquire_3d(lock: wp.array3d(dtype=warp_type)):
        # Try to acquire the lock by swapping 0 -> 1. Keep spinning while
        # atomic_cas reports the lock as already held (value 1). Compare against
        # warp_type(1) so both operands share the same type; this test also runs
        # with float and unsigned dtypes.
        while wp.atomic_cas(lock, 0, 0, 0, warp_type(0), warp_type(1)) == warp_type(1):
            pass

    @wp.func
    def spinlock_release_3d(lock: wp.array3d(dtype=warp_type)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, 0, 0, warp_type(0))

    @wp.func
    def volatile_read_3d(ptr: wp.array(dtype=warp_type), index: int):
        # Read ptr[index] of the 1D counter through atomic_exch (swap in 0, then
        # restore the value) instead of a plain load. The caller holds the lock,
        # so no other thread observes the transient 0.
        value = wp.atomic_exch(ptr, index, warp_type(0))
        wp.atomic_exch(ptr, index, value)
        return value

    def test_spinlock_counter_3d(counter: wp.array(dtype=warp_type), lock: wp.array3d(dtype=warp_type)):
        # Try to acquire the lock
        spinlock_acquire_3d(lock)

        # Critical section - increment counter
        # counter[0] = counter[0] + 1  # This gives wrong results - counter should be marked as volatile

        # Work around since warp arrays cannot be marked as volatile
        value = volatile_read_3d(counter, 0)
        counter[0] = value + warp_type(1)

        # Release the lock
        spinlock_release_3d(lock)

    # Wrapped through getkernel (not @wp.kernel) so each dtype gets one cached kernel.
    kernel = getkernel(test_spinlock_counter_3d, suffix=dtype.__name__)

    if register_kernels:
        return

    wp.launch(kernel, dim=n, inputs=[counter, lock], device=device)

    # Verify counter reached n
    counter_np = counter.numpy()
    expected = np.array([n], dtype=dtype)

    if not np.array_equal(counter_np, expected):
        print(f"Counter mismatch: expected {expected}, got {counter_np}")

    assert_np_equal(counter_np, expected)
192
+
193
+
194
def create_spinlock_test_4d(dtype):
    """Build a spin-lock counter kernel over a 4D lock array for ``dtype``.

    NOTE(review): this factory is not called anywhere in this module;
    ``test_atomic_cas_4d`` defines its own kernel instead.

    Args:
        dtype: Scalar type used both as the array dtype and as a constructor
            inside the Warp functions (presumably a Warp scalar type such as
            ``wp.int32``; confirm at any call site).

    Returns:
        The ``@wp.kernel``-decorated ``test_spinlock_counter`` kernel, which
        takes a 1D ``counter`` array and a 4D ``lock`` array.
    """

    @wp.func
    def spinlock_acquire(lock: wp.array(dtype=dtype, ndim=4)):
        # Try to acquire the lock by swapping 0 -> 1. Keep spinning while
        # atomic_cas reports the lock as already held (value 1). Compare against
        # dtype(1) so both operands share the same type.
        while wp.atomic_cas(lock, 0, 0, 0, 0, dtype(0), dtype(1)) == dtype(1):
            pass

    @wp.func
    def spinlock_release(lock: wp.array(dtype=dtype, ndim=4)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, 0, 0, 0, dtype(0))

    @wp.func
    def volatile_read(ptr: wp.array(dtype=dtype), index: int):
        # Read ptr[index] through atomic_exch (swap in 0, then restore the value)
        # instead of a plain load. The caller holds the lock, so no other thread
        # observes the transient 0.
        value = wp.atomic_exch(ptr, index, dtype(0))
        wp.atomic_exch(ptr, index, value)
        return value

    @wp.kernel
    def test_spinlock_counter(counter: wp.array(dtype=dtype), lock: wp.array(dtype=dtype, ndim=4)):
        # Try to acquire the lock
        spinlock_acquire(lock)

        # Critical section - increment counter
        # counter[0] = counter[0] + 1  # This gives wrong results - counter should be marked as volatile

        # Work around since warp arrays cannot be marked as volatile
        value = volatile_read(counter, 0)
        counter[0] = value + dtype(1)

        # Release the lock
        spinlock_release(lock)

    return test_spinlock_counter
228
+
229
+
230
def test_atomic_cas_4d(test, device, dtype, register_kernels=False):
    """Check a ``wp.atomic_cas`` spin lock whose lock lives in a 4D array.

    ``n`` threads each acquire the lock at ``lock[0, 0, 0, 0]``, increment a
    shared 1D counter inside the critical section, and release the lock. If
    the lock provides mutual exclusion, the counter ends at exactly ``n``.

    Args:
        test: The ``unittest.TestCase`` instance running this test (unused here).
        device: Warp device used for the arrays and the launch.
        dtype: NumPy scalar type; mapped to the matching Warp scalar type.
        register_kernels: If True, build the cached kernel and return before
            launching it.
    """
    warp_type = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    n = 100
    counter = wp.array([0], dtype=warp_type, device=device)
    lock = wp.zeros(shape=(1, 1, 1, 1), dtype=warp_type, device=device)

    @wp.func
    def spinlock_acquire_4d(lock: wp.array4d(dtype=warp_type)):
        # Try to acquire the lock by swapping 0 -> 1. Keep spinning while
        # atomic_cas reports the lock as already held (value 1). Compare against
        # warp_type(1) so both operands share the same type; this test also runs
        # with float and unsigned dtypes.
        while wp.atomic_cas(lock, 0, 0, 0, 0, warp_type(0), warp_type(1)) == warp_type(1):
            pass

    @wp.func
    def spinlock_release_4d(lock: wp.array4d(dtype=warp_type)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, 0, 0, 0, warp_type(0))

    @wp.func
    def volatile_read_4d(ptr: wp.array(dtype=warp_type), index: int):
        # Read ptr[index] of the 1D counter through atomic_exch (swap in 0, then
        # restore the value) instead of a plain load. The caller holds the lock,
        # so no other thread observes the transient 0.
        value = wp.atomic_exch(ptr, index, warp_type(0))
        wp.atomic_exch(ptr, index, value)
        return value

    def test_spinlock_counter_4d(counter: wp.array(dtype=warp_type), lock: wp.array4d(dtype=warp_type)):
        # Try to acquire the lock
        spinlock_acquire_4d(lock)

        # Critical section - increment counter
        # counter[0] = counter[0] + 1  # This gives wrong results - counter should be marked as volatile

        # Work around since warp arrays cannot be marked as volatile
        value = volatile_read_4d(counter, 0)
        counter[0] = value + warp_type(1)

        # Release the lock
        spinlock_release_4d(lock)

    # Wrapped through getkernel (not @wp.kernel) so each dtype gets one cached kernel.
    kernel = getkernel(test_spinlock_counter_4d, suffix=dtype.__name__)

    if register_kernels:
        return

    wp.launch(kernel, dim=n, inputs=[counter, lock], device=device)

    # Verify counter reached n
    counter_np = counter.numpy()
    expected = np.array([n], dtype=dtype)

    if not np.array_equal(counter_np, expected):
        print(f"Counter mismatch: expected {expected}, got {counter_np}")

    assert_np_equal(counter_np, expected)
282
+
283
+
284
devices = get_test_devices()


# Test methods are attached dynamically by the registration loop below.
class TestAtomicCAS(unittest.TestCase):
    pass


# Test all supported types
np_test_types = (np.int32, np.uint32, np.int64, np.uint64, np.float32, np.float64)

# (test-name prefix, test function) pairs. For each dtype they are registered
# in this order as "<prefix>_<dtype name>", e.g. "test_cas_2d_float32".
_cas_variants = (
    ("test_cas", test_atomic_cas),
    ("test_cas_2d", test_atomic_cas_2d),
    ("test_cas_3d", test_atomic_cas_3d),
    ("test_cas_4d", test_atomic_cas_4d),
)

for dtype in np_test_types:
    type_name = dtype.__name__
    for name_prefix, variant_func in _cas_variants:
        add_function_test_register_kernel(
            TestAtomicCAS, f"{name_prefix}_{type_name}", variant_func, devices=devices, dtype=dtype
        )

if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,220 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
# Warp compile-time bool constant; `check_compile_constant` both branches on it
# and stores it into a bool array.
TRUE_CONSTANT = wp.constant(True)
24
+
25
+
26
# Warp function returning the logical AND of a `wp.bool`-annotated argument and a
# Python-`bool`-annotated argument. It acts as the identity on `input_bool` when
# `plain_bool` is True, which is how `identity_test` calls it.
@wp.func
def identity_function(input_bool: wp.bool, plain_bool: bool):
    return input_bool and plain_bool
29
+
30
+
31
# Kernel that applies a series of boolean operations to `data[i]`, each of which is
# logically a no-op. The result should equal the input; `test_bool_identity_ops`
# asserts exactly that. The point is to exercise code generation for `and` / `or` /
# `not` with Python bool literals, `wp.bool(...)` values, a `@wp.func` call,
# branches on bool values, and `wp.where`.
@wp.kernel
def identity_test(data: wp.array(dtype=wp.bool)):
    i = wp.tid()

    # Expression forms: AND with true literals / constructions, negated false values.
    data[i] = data[i] and True
    data[i] = data[i] and wp.bool(True)
    data[i] = data[i] and not False
    data[i] = data[i] and not wp.bool(False)
    # Pass a wp.bool and a Python bool literal through a Warp function.
    data[i] = identity_function(data[i], True)

    # Branch on the value and write back the same value.
    if data[i]:
        data[i] = True
    else:
        data[i] = False

    # Branch on the negated value; each arm writes back the original value.
    if not data[i]:
        data[i] = False
    else:
        data[i] = True

    # Branch conditions combining the value with bool literals.
    if data[i] and True:
        data[i] = True
    else:
        data[i] = False

    if data[i] or False:
        data[i] = True
    else:
        data[i] = False

    # Select between bool literals based on the value.
    data[i] = wp.where(data[i], True, False)
62
+
63
+
64
def test_bool_identity_ops(test, device):
    """Run ``identity_test`` on seeded random bools and check the data comes back unchanged."""
    num_values = 10
    generator = np.random.default_rng(123)
    expected = generator.random(num_values) > 0.5

    bool_arr = wp.array(data=expected, device=device)

    # A NumPy bool array should be mapped onto Warp's bool dtype.
    test.assertEqual(bool_arr.dtype, wp.bool)

    wp.launch(identity_test, dim=bool_arr.shape, inputs=[bool_arr], device=device)

    assert_np_equal(bool_arr.numpy(), expected)
78
+
79
+
80
# Kernel that branches on the compile-time constant `TRUE_CONSTANT`. It writes the
# constant itself into `result[0]` when the branch is taken, and False otherwise.
@wp.kernel
def check_compile_constant(result: wp.array(dtype=wp.bool)):
    if TRUE_CONSTANT:
        result[0] = TRUE_CONSTANT
    else:
        result[0] = False
86
+
87
+
88
def test_bool_constant(test, device):
    """Check that ``check_compile_constant`` takes the ``TRUE_CONSTANT`` branch.

    The check runs once with a ``wp.bool`` output array and once with the Python
    builtin ``bool`` as the array dtype.
    """
    for out_dtype in (wp.bool, bool):
        flag = wp.zeros(1, dtype=out_dtype, device=device)
        wp.launch(check_compile_constant, 1, inputs=[flag], device=device)
        test.assertTrue(flag.numpy()[0])
97
+
98
+
99
# Three-component vector type whose components are `wp.bool`.
vec3bool = wp.vec(length=3, dtype=wp.bool)
# Constant vector (True, False, True); read by the `sum_from_bool_vec` kernel.
bool_selector_vec = wp.constant(vec3bool([True, False, True]))
101
+
102
+
103
# Kernel that adds the weight 1, 2 or 4 to `sum_array[i]` for each component of
# `bool_selector_vec` that reads as true. The weights are distinct powers of two,
# so the final value records which components were taken as true.
@wp.kernel
def sum_from_bool_vec(sum_array: wp.array(dtype=wp.int32)):
    i = wp.tid()

    # Plain read-modify-write: each thread only touches its own element `i`.
    if bool_selector_vec[0]:
        sum_array[i] = sum_array[i] + 1
    if bool_selector_vec[1]:
        sum_array[i] = sum_array[i] + 2
    if bool_selector_vec[2]:
        sum_array[i] = sum_array[i] + 4
113
+
114
+
115
def test_bool_constant_vec(test, device):
    """Check that ``sum_from_bool_vec`` yields 1 + 4 = 5 in every element.

    ``bool_selector_vec`` is (True, False, True), so only the weights 1 and 4 apply.
    """
    totals = wp.zeros(10, dtype=wp.int32, device=device)

    wp.launch(sum_from_bool_vec, totals.shape, inputs=[totals], device=device)

    expected_totals = np.full(totals.shape, 5)
    assert_np_equal(totals.numpy(), expected_totals)
121
+
122
+
123
# 2x2 matrix type whose entries are `wp.bool`.
mat22bool = wp.mat((2, 2), dtype=wp.bool)
# Constant matrix with True on the diagonal and False off it; read by `sum_from_bool_mat`.
bool_selector_mat = wp.constant(mat22bool([True, False, False, True]))
125
+
126
+
127
# Kernel that adds the weight 1, 2, 4 or 8 to `sum_array[i]` for entries [0, 0],
# [0, 1], [1, 0] and [1, 1] of `bool_selector_mat`, respectively, when the entry
# reads as true. The power-of-two weights make the final value identify which
# entries were taken as true.
@wp.kernel
def sum_from_bool_mat(sum_array: wp.array(dtype=wp.int32)):
    i = wp.tid()

    # Plain read-modify-write: each thread only touches its own element `i`.
    if bool_selector_mat[0, 0]:
        sum_array[i] = sum_array[i] + 1
    if bool_selector_mat[0, 1]:
        sum_array[i] = sum_array[i] + 2
    if bool_selector_mat[1, 0]:
        sum_array[i] = sum_array[i] + 4
    if bool_selector_mat[1, 1]:
        sum_array[i] = sum_array[i] + 8
139
+
140
+
141
def test_bool_constant_mat(test, device):
    """Check that ``sum_from_bool_mat`` yields 1 + 8 = 9 in every element.

    Only the diagonal entries of ``bool_selector_mat`` are True, contributing the
    weights 1 (entry [0, 0]) and 8 (entry [1, 1]).
    """
    num_elements = 10
    accum = wp.zeros(num_elements, dtype=wp.int32, device=device)

    wp.launch(sum_from_bool_mat, accum.shape, inputs=[accum], device=device)

    assert_np_equal(accum.numpy(), np.full(accum.shape, 9))
147
+
148
+
149
# Three-component vector type built from the `warp._src.types.vector` factory with the
# Python builtin `bool` as its dtype (unlike `vec3bool`, which uses `wp.bool`).
vec3bool_type = wp._src.types.vector(length=3, dtype=bool)
150
+
151
+
152
# Kernel that checks, with `wp.expect_eq`, the three `vec3bool_type` constructor forms
# (no arguments, a single scalar, one value per component) against `wp.vector(...)`
# values built from bool literals.
@wp.kernel
def test_bool_vec_anonymous_typing():
    # Zero initialize
    wp.expect_eq(vec3bool_type(), wp.vector(False, False, False))
    # Scalar initialize
    wp.expect_eq(vec3bool_type(True), wp.vector(True, True, True))
    # Component-wise initialize
    wp.expect_eq(vec3bool_type(True, False, True), wp.vector(True, False, True))
160
+
161
+
162
def test_bool_vec_typing(test, device):
    """Check ``vec3bool_type`` construction from Python, then run the in-kernel checks.

    Covers the zero, scalar-broadcast and component-wise constructor forms.
    """
    constructor_cases = (
        ((), (False, False, False)),  # zero initialize
        ((True,), (True, True, True)),  # scalar initialize
        ((True, False, True), (True, False, True)),  # component-wise initialize
    )
    for ctor_args, expected_components in constructor_cases:
        built = vec3bool_type(*ctor_args)
        test.assertEqual(tuple(built), expected_components)

    wp.launch(test_bool_vec_anonymous_typing, (1,), inputs=[], device=device)
174
+
175
+
176
# 2x2 matrix type built from the `warp._src.types.matrix` factory with the Python
# builtin `bool` as its dtype (unlike `mat22bool`, which uses `wp.bool`).
mat22bool_type = wp._src.types.matrix((2, 2), dtype=bool)
177
+
178
+
179
# Kernel that checks, with `wp.expect_eq`, the three `mat22bool_type` constructor forms
# (no arguments, a single scalar, one value per entry) against 2x2 `wp.matrix(...)`
# values built from bool literals.
@wp.kernel
def test_bool_mat_anonymous_typing():
    # Zero initialize
    wp.expect_eq(mat22bool_type(), wp.matrix(False, False, False, False, shape=(2, 2)))
    # Scalar initialize
    wp.expect_eq(mat22bool_type(True), wp.matrix(True, True, True, True, shape=(2, 2)))
    # Component-wise initialize
    wp.expect_eq(mat22bool_type(True, False, True, False), wp.matrix(True, False, True, False, shape=(2, 2)))
187
+
188
+
189
def test_bool_mat_typing(test, device):
    """Check ``mat22bool_type`` construction from Python, then run the in-kernel checks.

    Covers the zero, scalar-broadcast and component-wise constructor forms; results
    are compared row by row.
    """
    constructor_cases = (
        ((), ((False, False), (False, False))),  # zero initialize
        ((True,), ((True, True), (True, True))),  # scalar initialize
        ((True, False, True, False), ((True, False), (True, False))),  # component-wise initialize
    )
    for ctor_args, expected_rows in constructor_cases:
        built = mat22bool_type(*ctor_args)
        test.assertEqual(tuple(built), expected_rows)

    wp.launch(test_bool_mat_anonymous_typing, (1,), inputs=[], device=device)
201
+
202
+
203
devices = get_test_devices()


class TestBool(unittest.TestCase):
    """Bool tests; methods are attached by the ``add_function_test`` loop below."""


# Each test is registered under its own function name, in this order.
for _bool_test_func in (
    test_bool_identity_ops,
    test_bool_constant,
    test_bool_constant_vec,
    test_bool_constant_mat,
    test_bool_vec_typing,
    test_bool_mat_typing,
):
    add_function_test(TestBool, _bool_test_func.__name__, _bool_test_func, devices=devices)


if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)