warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,573 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
# NumPy floating-point dtypes: half, single and double precision.
np_float_types = [np.float16, np.float32, np.float64]

# Kernels already built by getkernel(), keyed by "<function name>_<suffix>".
kernel_cache = {}
26
+
27
+
28
def getkernel(func, suffix=""):
    """Return the ``wp.Kernel`` for ``func``, building it on first request.

    Kernels are memoized in the module-level ``kernel_cache`` under the key
    ``"<func.__name__>_<suffix>"``; a later call with the same function name
    and suffix returns the stored kernel instead of creating a new one.
    """
    cache_key = "_".join((func.__name__, suffix))

    if cache_key in kernel_cache:
        return kernel_cache[cache_key]

    # first request for this name/suffix pair: build the kernel and keep it
    kernel_cache[cache_key] = wp.Kernel(func=func, key=cache_key)
    return kernel_cache[cache_key]
33
+
34
+
35
def get_select_kernel(dtype):
    """Return a kernel that copies ``input[index]`` into ``out[0]``.

    Both arrays are annotated with ``dtype``. The kernel is obtained through
    ``getkernel`` using ``dtype.__name__`` as the cache suffix.
    """

    # NOTE: the inner function's name is part of the cache key, keep it as is.
    def output_select_kernel_fn(input: wp.array(dtype=dtype), index: int, out: wp.array(dtype=dtype)):
        out[0] = input[index]

    dtype_suffix = dtype.__name__
    return getkernel(output_select_kernel_fn, suffix=dtype_suffix)
40
+
41
+
42
def test_anon_constructor_error_shape_arg_missing(test, device):
    """Expect a ``RuntimeError`` when ``wp.matrix`` gets values but no ``shape``.

    The error has to be raised while launching the kernel on ``device``.
    """
    expected_error = r"the `shape` argument must be specified when initializing a matrix by value$"

    @wp.kernel
    def kernel():
        wp.matrix(1.0, 2.0, 3.0)

    with test.assertRaisesRegex(RuntimeError, expected_error):
        wp.launch(kernel, dim=1, inputs=[], device=device)
52
+
53
+
54
def test_anon_constructor_error_shape_mismatch(test, device):
    """Expect a ``RuntimeError`` when copy constructing with a different shape.

    The kernel passes a (1, 2) matrix to ``wp.matrix`` together with
    ``shape=(3, 4)``; launching it on ``device`` must raise.
    """
    expected_error = (
        r"incompatible matrix of shape \(3, 4\) given when copy constructing a matrix of shape \(1, 2\)$"
    )

    @wp.kernel
    def kernel():
        wp.matrix(wp.matrix(shape=(1, 2), dtype=float), shape=(3, 4), dtype=float)

    with test.assertRaisesRegex(RuntimeError, expected_error):
        wp.launch(kernel, dim=1, inputs=[], device=device)
64
+
65
+
66
def test_anon_constructor_error_type_mismatch(test, device):
    """Expect a ``RuntimeError`` when the fill value does not match ``dtype``.

    The kernel fills a ``wp.float16`` matrix with the literal ``1.0``;
    launching it on ``device`` must raise.
    """
    expected_error = r"the value used to fill this matrix is expected to be of the type `float16`$"

    @wp.kernel
    def kernel():
        wp.matrix(1.0, shape=(3, 2), dtype=wp.float16)

    with test.assertRaisesRegex(RuntimeError, expected_error):
        wp.launch(kernel, dim=1, inputs=[], device=device)
76
+
77
+
78
def test_anon_constructor_error_invalid_arg_count(test, device):
    """Expect a ``RuntimeError`` when ``wp.matrix`` gets too few values.

    The kernel supplies three values for ``shape=(2, 2)``; launching it on
    ``device`` must raise.
    """
    expected_error = r"incompatible number of values given \(3\) when constructing a matrix of shape \(2, 2\)$"

    @wp.kernel
    def kernel():
        wp.matrix(1.0, 2.0, 3.0, shape=(2, 2), dtype=float)

    with test.assertRaisesRegex(RuntimeError, expected_error):
        wp.launch(kernel, dim=1, inputs=[], device=device)
88
+
89
+
90
def test_tpl_constructor_error_incompatible_sizes(test, device):
    """Expect a ``RuntimeError`` when ``wp.mat33`` is built from a ``wp.mat22``.

    Launching the kernel on ``device`` must raise with the shape-mismatch
    message checked below.
    """
    expected_error = (
        r"incompatible matrix of shape \(3, 3\) given when copy constructing a matrix of shape \(2, 2\)$"
    )

    @wp.kernel
    def kernel():
        wp.mat33(wp.mat22(1.0, 2.0, 3.0, 4.0))

    with test.assertRaisesRegex(RuntimeError, expected_error):
        wp.launch(kernel, dim=1, inputs=[], device=device)
100
+
101
+
102
def test_tpl_constructor_error_invalid_arg_count(test, device):
    """Expect a ``RuntimeError`` when ``wp.mat22`` gets only three values.

    Launching the kernel on ``device`` must raise with the value-count
    message checked below.
    """
    expected_error = r"incompatible number of values given \(3\) when constructing a matrix of shape \(2, 2\)$"

    @wp.kernel
    def kernel():
        wp.mat22(1.0, 2.0, 3.0)

    with test.assertRaisesRegex(RuntimeError, expected_error):
        wp.launch(kernel, dim=1, inputs=[], device=device)
112
+
113
+
114
def _assert_matrix_layout(test, mat, rows, row_vec, col_vec):
    """Assert that ``mat`` matches ``rows`` entry by entry, by row and by column.

    Args:
        test: Test case object providing ``assertEqual``.
        mat: Matrix under test.
        rows: Expected contents as a tuple of row tuples.
        row_vec: Constructor used to build each expected row vector.
        col_vec: Constructor used to build each expected column vector.
    """
    for i, row in enumerate(rows):
        for j, value in enumerate(row):
            test.assertEqual(mat[i, j], value)

    for i, row in enumerate(rows):
        test.assertEqual(mat.get_row(i), row_vec(*row))

    # transposing the expected rows yields the expected columns
    for j, col in enumerate(zip(*rows)):
        test.assertEqual(mat.get_col(j), col_vec(*col))


def test_matrix_from_vecs_runtime(test, device):
    """Check ``wp.matrix_from_cols`` / ``wp.matrix_from_rows`` outside of a kernel.

    For square (3x3) and non-square (3x2, 2x3) matrices this verifies element
    access, ``get_row`` / ``get_col``, and ``set_row`` / ``set_col`` (with
    vector arguments and, for the last matrix, scalar arguments).

    Checks go through ``test.assertEqual`` rather than bare ``assert`` so they
    are not stripped under ``python -O`` and report both operands on failure.

    Args:
        test: Test case object providing ``assertEqual``.
        device: Not referenced by the body; kept for the common test signature.
    """
    # 3x3 matrix assembled from column vectors
    m1 = wp.matrix_from_cols(
        wp.vec3(1.0, 2.0, 3.0),
        wp.vec3(4.0, 5.0, 6.0),
        wp.vec3(7.0, 8.0, 9.0),
    )
    _assert_matrix_layout(
        test,
        m1,
        ((1.0, 4.0, 7.0), (2.0, 5.0, 8.0), (3.0, 6.0, 9.0)),
        wp.vec3,
        wp.vec3,
    )

    m1.set_row(0, wp.vec3(8.0, 9.0, 10.0))
    m1.set_row(1, wp.vec3(11.0, 12.0, 13.0))
    m1.set_row(2, wp.vec3(14.0, 15.0, 16.0))

    test.assertEqual(
        m1,
        wp.matrix_from_rows(
            wp.vec3(8.0, 9.0, 10.0),
            wp.vec3(11.0, 12.0, 13.0),
            wp.vec3(14.0, 15.0, 16.0),
        ),
    )

    m1.set_col(0, wp.vec3(8.0, 9.0, 10.0))
    m1.set_col(1, wp.vec3(11.0, 12.0, 13.0))
    m1.set_col(2, wp.vec3(14.0, 15.0, 16.0))

    test.assertEqual(
        m1,
        wp.matrix_from_cols(
            wp.vec3(8.0, 9.0, 10.0),
            wp.vec3(11.0, 12.0, 13.0),
            wp.vec3(14.0, 15.0, 16.0),
        ),
    )

    # 3x3 matrix assembled from row vectors
    m2 = wp.matrix_from_rows(
        wp.vec3(1.0, 2.0, 3.0),
        wp.vec3(4.0, 5.0, 6.0),
        wp.vec3(7.0, 8.0, 9.0),
    )
    _assert_matrix_layout(
        test,
        m2,
        ((1.0, 2.0, 3.0), (4.0, 5.0, 6.0), (7.0, 8.0, 9.0)),
        wp.vec3,
        wp.vec3,
    )

    m2.set_row(0, wp.vec3(8.0, 9.0, 10.0))
    m2.set_row(1, wp.vec3(11.0, 12.0, 13.0))
    m2.set_row(2, wp.vec3(14.0, 15.0, 16.0))

    test.assertEqual(
        m2,
        wp.matrix_from_rows(
            wp.vec3(8.0, 9.0, 10.0),
            wp.vec3(11.0, 12.0, 13.0),
            wp.vec3(14.0, 15.0, 16.0),
        ),
    )

    m2.set_col(0, wp.vec3(8.0, 9.0, 10.0))
    m2.set_col(1, wp.vec3(11.0, 12.0, 13.0))
    m2.set_col(2, wp.vec3(14.0, 15.0, 16.0))

    test.assertEqual(
        m2,
        wp.matrix_from_cols(
            wp.vec3(8.0, 9.0, 10.0),
            wp.vec3(11.0, 12.0, 13.0),
            wp.vec3(14.0, 15.0, 16.0),
        ),
    )

    # 3x2 matrix: two column vectors of length 3
    m3 = wp.matrix_from_cols(
        wp.vec3(1.0, 2.0, 3.0),
        wp.vec3(4.0, 5.0, 6.0),
    )
    _assert_matrix_layout(
        test,
        m3,
        ((1.0, 4.0), (2.0, 5.0), (3.0, 6.0)),
        wp.vec2,
        wp.vec3,
    )

    m3.set_row(0, wp.vec2(7.0, 8.0))
    m3.set_row(1, wp.vec2(9.0, 10.0))
    m3.set_row(2, wp.vec2(11.0, 12.0))

    test.assertEqual(
        m3,
        wp.matrix_from_rows(
            wp.vec2(7.0, 8.0),
            wp.vec2(9.0, 10.0),
            wp.vec2(11.0, 12.0),
        ),
    )

    m3.set_col(0, wp.vec3(7.0, 8.0, 9.0))
    m3.set_col(1, wp.vec3(10.0, 11.0, 12.0))

    test.assertEqual(
        m3,
        wp.matrix_from_cols(
            wp.vec3(7.0, 8.0, 9.0),
            wp.vec3(10.0, 11.0, 12.0),
        ),
    )

    # 2x3 matrix: two row vectors of length 3
    m4 = wp.matrix_from_rows(
        wp.vec3(1.0, 2.0, 3.0),
        wp.vec3(4.0, 5.0, 6.0),
    )
    _assert_matrix_layout(
        test,
        m4,
        ((1.0, 2.0, 3.0), (4.0, 5.0, 6.0)),
        wp.vec3,
        wp.vec2,
    )

    m4.set_row(0, wp.vec3(7.0, 8.0, 9.0))
    m4.set_row(1, wp.vec3(10.0, 11.0, 12.0))

    test.assertEqual(
        m4,
        wp.matrix_from_rows(
            wp.vec3(7.0, 8.0, 9.0),
            wp.vec3(10.0, 11.0, 12.0),
        ),
    )

    m4.set_col(0, wp.vec2(7.0, 8.0))
    m4.set_col(1, wp.vec2(9.0, 10.0))
    m4.set_col(2, wp.vec2(11.0, 12.0))

    test.assertEqual(
        m4,
        wp.matrix_from_cols(
            wp.vec2(7.0, 8.0),
            wp.vec2(9.0, 10.0),
            wp.vec2(11.0, 12.0),
        ),
    )

    # a scalar argument is expected to fill the whole row
    m4.set_row(0, 13.0)

    test.assertEqual(
        m4,
        wp.matrix_from_rows(
            wp.vec3(13.0, 13.0, 13.0),
            wp.vec3(8.0, 10.0, 12.0),
        ),
    )

    # a scalar argument is expected to fill the whole column
    m4.set_col(2, 14.0)

    test.assertEqual(
        m4,
        wp.matrix_from_rows(
            wp.vec3(13.0, 13.0, 14.0),
            wp.vec3(8.0, 10.0, 14.0),
        ),
    )
282
+
283
+
284
+ # Test matrix constructors using explicit type (float16)
285
+ # note that these tests are specifically not using generics / closure
286
+ # args to create kernels dynamically (like the rest of this file)
287
+ # as those use different code paths to resolve arg types which
288
+ # has led to regressions.
289
@wp.kernel
def test_constructors_explicit_precision():
    # Kernel checking 2x2 float16 matrices built three ways: via wp.identity,
    # via wp.matrix with only shape/dtype, and via wp.matrix with four
    # explicit float16 scalars.

    # construction for custom matrix types
    eye = wp.identity(dtype=wp.float16, n=2)
    zeros = wp.matrix(shape=(2, 2), dtype=wp.float16)
    custom = wp.matrix(wp.float16(0.0), wp.float16(1.0), wp.float16(2.0), wp.float16(3.0), shape=(2, 2))

    for i in range(2):
        for j in range(2):
            # identity: one on the diagonal, zero everywhere else
            if i == j:
                wp.expect_eq(eye[i, j], wp.float16(1.0))
            else:
                wp.expect_eq(eye[i, j], wp.float16(0.0))

            # a matrix built without values is expected to be all zeros
            wp.expect_eq(zeros[i, j], wp.float16(0.0))
            # the values 0..3 are expected in row-major order: entry (i, j) == 2*i + j
            wp.expect_eq(custom[i, j], wp.float16(i) * wp.float16(2.0) + wp.float16(j))
305
+
306
+
307
+ # Same as above but with a default (float/int) type
308
+ # which tests some different code paths that
309
+ # need to ensure types are correctly canonicalized
310
+ # during codegen
311
+ @wp.kernel
312
+ def test_constructors_default_precision():  # kernel: same checks as the float16 variant, using the builtin float type
313
+ # construction for default (float) matrix types
314
+ eye = wp.identity(dtype=float, n=2)  # expected below: 1.0 on the diagonal, 0.0 elsewhere
315
+ zeros = wp.matrix(shape=(2, 2), dtype=float)  # no component values passed; expected below to be all 0.0
316
+ custom = wp.matrix(0.0, 1.0, 2.0, 3.0, shape=(2, 2))  # no dtype keyword; components are plain Python float literals
317
+
318
+ for i in range(2):
319
+ for j in range(2):
320
+ if i == j:
321
+ wp.expect_eq(eye[i, j], 1.0)
322
+ else:
323
+ wp.expect_eq(eye[i, j], 0.0)
324
+
325
+ wp.expect_eq(zeros[i, j], 0.0)
326
+ wp.expect_eq(custom[i, j], float(i) * 2.0 + float(j))  # the scalars 0..3 are expected row by row: element (i, j) == 2*i + j
327
+
328
+
329
+ # NOTE: Compile time is highly sensitive to shape so we use small values now
330
+ CONSTANT_SHAPE_ROWS = wp.constant(2)  # row count; used as a shape entry and loop bound in test_constructors_constant_shape
331
+ CONSTANT_SHAPE_COLS = wp.constant(2)  # column count; used the same way
332
+
333
+
334
+ # tests that we can use global constants in shape keyword argument
335
+ # for matrix constructor
336
+ @wp.kernel
337
+ def test_constructors_constant_shape():  # NOTE(review): contains no expect_* checks; presumably passing means the kernel builds and launches - confirm
338
+ m = wp.matrix(shape=(CONSTANT_SHAPE_ROWS, CONSTANT_SHAPE_COLS), dtype=float)  # shape comes from the module-level wp.constant values
339
+
340
+ for i in range(CONSTANT_SHAPE_ROWS):  # the same constants also serve as loop bounds
341
+ for j in range(CONSTANT_SHAPE_COLS):
342
+ m[i, j] = float(i * j)  # writes every element; the result is never read back
343
+
344
+
345
+ @wp.kernel
346
+ def test_matrix_from_vecs():  # kernel: checks element placement for matrix_from_cols / matrix_from_rows
347
+ m1 = wp.matrix_from_cols(  # three vec3 arguments; the checks below expect m1[r, c] == component r of argument c
348
+ wp.vec3(1.0, 2.0, 3.0),
349
+ wp.vec3(4.0, 5.0, 6.0),
350
+ wp.vec3(7.0, 8.0, 9.0),
351
+ )
352
+ wp.expect_eq(m1[0, 0], 1.0)
353
+ wp.expect_eq(m1[0, 1], 4.0)
354
+ wp.expect_eq(m1[0, 2], 7.0)
355
+ wp.expect_eq(m1[1, 0], 2.0)
356
+ wp.expect_eq(m1[1, 1], 5.0)
357
+ wp.expect_eq(m1[1, 2], 8.0)
358
+ wp.expect_eq(m1[2, 0], 3.0)
359
+ wp.expect_eq(m1[2, 1], 6.0)
360
+ wp.expect_eq(m1[2, 2], 9.0)
361
+
362
+ m2 = wp.matrix_from_rows(  # same three vectors; the checks below expect m2[r, c] == component c of argument r
363
+ wp.vec3(1.0, 2.0, 3.0),
364
+ wp.vec3(4.0, 5.0, 6.0),
365
+ wp.vec3(7.0, 8.0, 9.0),
366
+ )
367
+ wp.expect_eq(m2[0, 0], 1.0)
368
+ wp.expect_eq(m2[0, 1], 2.0)
369
+ wp.expect_eq(m2[0, 2], 3.0)
370
+ wp.expect_eq(m2[1, 0], 4.0)
371
+ wp.expect_eq(m2[1, 1], 5.0)
372
+ wp.expect_eq(m2[1, 2], 6.0)
373
+ wp.expect_eq(m2[2, 0], 7.0)
374
+ wp.expect_eq(m2[2, 1], 8.0)
375
+ wp.expect_eq(m2[2, 2], 9.0)
376
+
377
+ m3 = wp.matrix_from_cols(  # non-square case: two vec3 columns, indexed below over rows 0..2 and columns 0..1
378
+ wp.vec3(1.0, 2.0, 3.0),
379
+ wp.vec3(4.0, 5.0, 6.0),
380
+ )
381
+ wp.expect_eq(m3[0, 0], 1.0)
382
+ wp.expect_eq(m3[0, 1], 4.0)
383
+ wp.expect_eq(m3[1, 0], 2.0)
384
+ wp.expect_eq(m3[1, 1], 5.0)
385
+ wp.expect_eq(m3[2, 0], 3.0)
386
+ wp.expect_eq(m3[2, 1], 6.0)
387
+
388
+ m4 = wp.matrix_from_rows(  # non-square case: two vec3 rows, indexed below over rows 0..1 and columns 0..2
389
+ wp.vec3(1.0, 2.0, 3.0),
390
+ wp.vec3(4.0, 5.0, 6.0),
391
+ )
392
+ wp.expect_eq(m4[0, 0], 1.0)
393
+ wp.expect_eq(m4[0, 1], 2.0)
394
+ wp.expect_eq(m4[0, 2], 3.0)
395
+ wp.expect_eq(m4[1, 0], 4.0)
396
+ wp.expect_eq(m4[1, 1], 5.0)
397
+ wp.expect_eq(m4[1, 2], 6.0)
398
+
399
+
400
+ mat32d = wp.mat(shape=(3, 2), dtype=wp.float64)  # 3x2 float64 matrix type used by test_matrix_constructor_value_func
401
+
402
+
403
+ @wp.kernel
404
+ def test_matrix_constructor_value_func():  # NOTE(review): no expect_* checks and every local is unused; presumably a build/launch smoke test of the constructor overloads - confirm
405
+ a = wp.mat22()  # predefined type, no arguments
406
+ b = wp.matrix(a, shape=(2, 2))  # generic constructor given an existing matrix plus a shape
407
+ c = mat32d()  # module-level 3x2 float64 type, no arguments
408
+ d = mat32d(c, shape=(3, 2))  # same type given an existing matrix plus a shape
409
+ e = mat32d(wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0))  # six float64 scalars, one per component of a 3x2 matrix
410
+ f = wp.matrix(1.0, 2.0, 3.0, 4.0, shape=(2, 2), dtype=float)  # scalars with both shape and dtype given explicitly
411
+
412
+
413
+ def test_quat_constructor(test, device, dtype, register_kernels=False):  # compares wp.transform_compose with a hand-assembled 4x4 matrix: values first, then gradients
414
+ rng = np.random.default_rng(123)  # fixed seed
415
+
416
+ tol = {  # per-dtype tolerance used for the gradient comparisons at the end; 0 for any dtype not listed
417
+ np.float16: 1.0e-3,
418
+ np.float32: 1.0e-6,
419
+ np.float64: 1.0e-8,
420
+ }.get(dtype, 0)
421
+
422
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]  # Warp scalar type looked up from the NumPy dtype
423
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
424
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
425
+ quat = wp._src.types.quaternion(dtype=wptype)
426
+
427
+ output_select_kernel = get_select_kernel(wptype)  # helper defined outside this chunk; launched below with (array, idx) -> 1-element output, presumably copying element idx - confirm
428
+
429
+ def check_mat_quat_constructor(  # kernel body; turned into a kernel by getkernel() below
430
+ p: wp.array(dtype=vec3),
431
+ r: wp.array(dtype=quat),
432
+ s: wp.array(dtype=vec3),
433
+ outcomponents: wp.array(dtype=wptype),
434
+ outcomponents_alt: wp.array(dtype=wptype),
435
+ ):
436
+ m = wp.transform_compose(p[0], r[0], s[0])  # matrix under test
437
+
438
+ R = wp.transpose(wp.quat_to_matrix(r[0]))  # transposed so that R[k] below is column k of the rotation matrix
439
+ c0 = s[0][0] * R[0]  # each rotation column is scaled by the matching component of s
440
+ c1 = s[0][1] * R[1]
441
+ c2 = s[0][2] * R[2]
442
+ m_alt = wp.matrix_from_cols(  # reference: three scaled columns padded with 0, then p padded with 1
443
+ vec4(c0[0], c0[1], c0[2], wptype(0.0)),
444
+ vec4(c1[0], c1[1], c1[2], wptype(0.0)),
445
+ vec4(c2[0], c2[1], c2[2], wptype(0.0)),
446
+ vec4(p[0][0], p[0][1], p[0][2], wptype(1.0)),
447
+ )
448
+
449
+ idx = 0
450
+ for i in range(4):  # write both 4x4 matrices out as 16 scalars each
451
+ for j in range(4):
452
+ outcomponents[idx] = m[i, j]
453
+ outcomponents_alt[idx] = m_alt[i, j]
454
+ idx = idx + 1
455
+
456
+ kernel = getkernel(check_mat_quat_constructor, suffix=dtype.__name__)  # getkernel is defined outside this chunk
457
+
458
+ if register_kernels:  # when set, stop once the kernel has been created; nothing is launched
459
+ return
460
+
461
+ # translation:
462
+ p = wp.array(rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
463
+
464
+ # generate a normalized quaternion for the rotation:
465
+ r = rng.standard_normal(size=(1, 4))
466
+ r /= np.linalg.norm(r)
467
+ r = wp.array(r.astype(dtype), dtype=quat, requires_grad=True, device=device)
468
+
469
+ # scale:
470
+ s = wp.array(rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
471
+
472
+ # just going to generate the matrix using the constructor, then
473
+ # more manually, and make sure the values/gradients are the same:
474
+ outcomponents = wp.zeros(4 * 4, dtype=wptype, requires_grad=True, device=device)
475
+ outcomponents_alt = wp.zeros(4 * 4, dtype=wptype, requires_grad=True, device=device)
476
+ wp.launch(kernel, dim=1, inputs=[p, r, s], outputs=[outcomponents, outcomponents_alt], device=device)
477
+ assert_np_equal(outcomponents.numpy(), outcomponents_alt.numpy(), tol=1.0e-6)  # NOTE(review): fixed 1.0e-6 here, not the per-dtype tol defined above - confirm intentional
478
+
479
+ idx = 0
480
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
481
+ out_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
482
+ for _i in range(4):  # gradient check; loop variables are unused, idx selects the component
483
+ for _j in range(4):
484
+ tape = wp.Tape()
485
+ with tape:  # record the forward pass plus the selection of component idx from each output
486
+ wp.launch(kernel, dim=1, inputs=[p, r, s], outputs=[outcomponents, outcomponents_alt], device=device)
487
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
488
+ wp.launch(
489
+ output_select_kernel, dim=1, inputs=[outcomponents_alt, idx], outputs=[out_alt], device=device
490
+ )
491
+
492
+ tape.backward(loss=out)  # gradients through the transform_compose path
493
+ p_grad = 1.0 * tape.gradients[p].numpy()[0]  # the 1.0 * presumably forces a copy before tape.zero() below - confirm
494
+ r_grad = 1.0 * tape.gradients[r].numpy()[0]
495
+ s_grad = 1.0 * tape.gradients[s].numpy()[0]
496
+ tape.zero()
497
+
498
+ tape.backward(loss=out_alt)  # gradients through the hand-assembled path
499
+ p_grad_alt = 1.0 * tape.gradients[p].numpy()[0]
500
+ r_grad_alt = 1.0 * tape.gradients[r].numpy()[0]
501
+ s_grad_alt = 1.0 * tape.gradients[s].numpy()[0]
502
+ tape.zero()
503
+
504
+ assert_np_equal(p_grad, p_grad_alt, tol=tol)
505
+ assert_np_equal(r_grad, r_grad_alt, tol=tol)
506
+ assert_np_equal(s_grad, s_grad_alt, tol=tol)
507
+
508
+ idx = idx + 1
509
+
510
+
511
+ devices = get_test_devices()  # device list passed to every registration call in this module
512
+
513
+
514
+ class TestMatConstructors(unittest.TestCase):  # intentionally empty; passed as the target class to the add_*_test calls in this module
515
+ pass
516
+
517
+
518
+ add_function_test(  # host-side test functions; the callables registered here are defined earlier in the module
519
+ TestMatConstructors,
520
+ "test_anon_constructor_error_shape_arg_missing",
521
+ test_anon_constructor_error_shape_arg_missing,
522
+ devices=devices,
523
+ )
524
+ add_function_test(
525
+ TestMatConstructors,
526
+ "test_anon_constructor_error_shape_mismatch",
527
+ test_anon_constructor_error_shape_mismatch,
528
+ devices=devices,
529
+ )
530
+ add_function_test(
531
+ TestMatConstructors,
532
+ "test_anon_constructor_error_type_mismatch",
533
+ test_anon_constructor_error_type_mismatch,
534
+ devices=devices,
535
+ )
536
+ add_function_test(
537
+ TestMatConstructors,
538
+ "test_anon_constructor_error_invalid_arg_count",
539
+ test_anon_constructor_error_invalid_arg_count,
540
+ devices=devices,
541
+ )
542
+ add_function_test(
543
+ TestMatConstructors,
544
+ "test_tpl_constructor_error_incompatible_sizes",
545
+ test_tpl_constructor_error_incompatible_sizes,
546
+ devices=devices,
547
+ )
548
+ add_function_test(
549
+ TestMatConstructors,
550
+ "test_tpl_constructor_error_invalid_arg_count",
551
+ test_tpl_constructor_error_invalid_arg_count,
552
+ devices=devices,
553
+ )
554
+ add_function_test(TestMatConstructors, "test_matrix_from_vecs_runtime", test_matrix_from_vecs_runtime, devices=devices)
555
+
556
+ add_kernel_test(TestMatConstructors, test_constructors_explicit_precision, dim=1, devices=devices)  # argument-free kernels, each registered with dim=1
557
+ add_kernel_test(TestMatConstructors, test_constructors_default_precision, dim=1, devices=devices)
558
+ add_kernel_test(TestMatConstructors, test_constructors_constant_shape, dim=1, devices=devices)
559
+ add_kernel_test(TestMatConstructors, test_matrix_from_vecs, dim=1, devices=devices)
560
+ add_kernel_test(TestMatConstructors, test_matrix_constructor_value_func, dim=1, devices=devices)
561
+
562
+ for dtype in np_float_types:  # one test_quat_constructor registration per float dtype, named after the dtype
563
+ add_function_test_register_kernel(
564
+ TestMatConstructors,
565
+ f"test_quat_constructor_{dtype.__name__}",
566
+ test_quat_constructor,
567
+ devices=devices,
568
+ dtype=dtype,
569
+ )
570
+
571
+ if __name__ == "__main__":  # direct-execution entry point
572
+ wp.clear_kernel_cache()  # clear the kernel cache before the tests run
573
+ unittest.main(verbosity=2, failfast=True)  # failfast=True stops at the first failure
@@ -0,0 +1,122 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import warp as wp
19
+ from warp.tests.unittest_utils import *
20
+
21
+ mat32d = wp.mat(shape=(3, 2), dtype=wp.float64)  # 3x2 float64 matrix type used by test_matrix_constructor_value_func
22
+
23
+
24
+ @wp.kernel
25
+ def test_matrix_constructor_value_func():  # NOTE(review): no expect_* checks and every local is unused; presumably a build/launch smoke test of the constructor overloads - confirm
26
+ a = wp.mat22()  # predefined type, no arguments
27
+ b = wp.matrix(a, shape=(2, 2))  # generic constructor given an existing matrix plus a shape
28
+ c = mat32d()  # module-level 3x2 float64 type, no arguments
29
+ d = mat32d(c, shape=(3, 2))  # same type given an existing matrix plus a shape
30
+ e = mat32d(wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0))  # six float64 scalars, one per component of a 3x2 matrix
31
+ f = mat32d(  # two 3-component float64 vectors; presumably taken as the two columns of the 3x2 matrix - confirm
32
+ wp.vec3d(wp.float64(1.0), wp.float64(2.0), wp.float64(3.0)),
33
+ wp.vec3d(wp.float64(1.0), wp.float64(2.0), wp.float64(3.0)),
34
+ )
35
+ g = wp.matrix(1.0, shape=(3, 2))  # a single scalar plus a shape; presumably fills every component - confirm
36
+
37
+
38
+ # Test matrix constructors using explicit type (float16)
39
+ # note that these tests are specifically not using generics / closure
40
+ # args to create kernels dynamically (like the rest of this file)
41
+ # as those use different code paths to resolve arg types which
42
+ # has led to regressions.
43
+ @wp.kernel
44
+ def test_constructors_explicit_precision():  # kernel: builds 2x2 float16 matrices three ways and checks every element
45
+ # construction for custom matrix types
46
+ eye = wp.identity(dtype=wp.float16, n=2)  # expected below: 1.0 on the diagonal, 0.0 elsewhere
47
+ zeros = wp.matrix(shape=(2, 2), dtype=wp.float16)  # no component values passed; expected below to be all 0.0
48
+ custom = wp.matrix(wp.float16(0.0), wp.float16(1.0), wp.float16(2.0), wp.float16(3.0), shape=(2, 2))  # no dtype keyword; every component is a wp.float16 value
49
+
50
+ for i in range(2):
51
+ for j in range(2):
52
+ if i == j:
53
+ wp.expect_eq(eye[i, j], wp.float16(1.0))
54
+ else:
55
+ wp.expect_eq(eye[i, j], wp.float16(0.0))
56
+
57
+ wp.expect_eq(zeros[i, j], wp.float16(0.0))
58
+ wp.expect_eq(custom[i, j], wp.float16(i) * wp.float16(2.0) + wp.float16(j))  # the scalars 0..3 are expected row by row: element (i, j) == 2*i + j
59
+
60
+
61
+ # Same as above but with a default (float/int) type
62
+ # which tests some different code paths that
63
+ # need to ensure types are correctly canonicalized
64
+ # during codegen
65
+ @wp.kernel
66
+ def test_constructors_default_precision():  # kernel: same checks as the float16 variant, using the builtin float type
67
+ # construction for default (float) matrix types
68
+ eye = wp.identity(dtype=float, n=2)  # expected below: 1.0 on the diagonal, 0.0 elsewhere
69
+ zeros = wp.matrix(shape=(2, 2), dtype=float)  # no component values passed; expected below to be all 0.0
70
+ custom = wp.matrix(0.0, 1.0, 2.0, 3.0, shape=(2, 2))  # no dtype keyword; components are plain Python float literals
71
+
72
+ for i in range(2):
73
+ for j in range(2):
74
+ if i == j:
75
+ wp.expect_eq(eye[i, j], 1.0)
76
+ else:
77
+ wp.expect_eq(eye[i, j], 0.0)
78
+
79
+ wp.expect_eq(zeros[i, j], 0.0)
80
+ wp.expect_eq(custom[i, j], float(i) * 2.0 + float(j))  # the scalars 0..3 are expected row by row: element (i, j) == 2*i + j
81
+
82
+
83
+ @wp.kernel
84
+ def test_matrix_mutation(expected: wp._src.types.matrix(shape=(10, 3), dtype=float)):  # kernel: fills a 10x3 matrix in place and compares it with the caller-supplied expected matrix
85
+ m = wp.matrix(shape=(10, 3), dtype=float)
86
+
87
+ # test direct element indexing
88
+ m[0, 0] = 1.0  # row 0 becomes (1, 2, 3)
89
+ m[0, 1] = 2.0
90
+ m[0, 2] = 3.0
91
+
92
+ # The nested indexing (matrix->vector->scalar) below does not
93
+ # currently modify m because m[0] returns row vector by
94
+ # value rather than reference, this is different from NumPy
95
+ # which always returns by ref. Not clear how we can support
96
+ # this as well as auto-diff.
97
+
98
+ # m[0][1] = 2.0
99
+ # m[0][2] = 3.0
100
+
101
+ # test setting rows
102
+ for i in range(1, 10):
103
+ m[i] = m[i - 1] + wp.vec3(1.0, 2.0, 3.0)  # each row is the previous row plus (1, 2, 3), so row i == (i + 1) * (1, 2, 3)
104
+
105
+ wp.expect_eq(m, expected)  # whole-matrix comparison
106
+
107
+
108
+ devices = get_test_devices()  # device list passed to every registration call in this module
109
+
110
+
111
+ class TestMatLite(unittest.TestCase):  # intentionally empty; passed as the target class to the add_kernel_test calls in this module
112
+ pass
113
+
114
+
115
+ add_kernel_test(TestMatLite, test_matrix_constructor_value_func, dim=1, devices=devices)  # argument-free kernels, each registered with dim=1
116
+ add_kernel_test(TestMatLite, test_constructors_explicit_precision, dim=1, devices=devices)
117
+ add_kernel_test(TestMatLite, test_constructors_default_precision, dim=1, devices=devices)  # NOTE(review): test_matrix_mutation is defined in this module but never registered - confirm intentional
118
+
119
+
120
+ if __name__ == "__main__":  # direct-execution entry point
121
+ wp.clear_kernel_cache()  # clear the kernel cache before the tests run
122
+ unittest.main(verbosity=2, failfast=True)  # failfast=True stops at the first failure