warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,376 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
+ # checks that we can configure shared memory to the expected size
25
def test_tile_shared_mem_size(test, device):
    """Check the shared-memory sizes reported for a kernel using 32x32 shared tiles.

    The kernel creates two float tiles with ``storage="shared"``, adds them and
    stores the sum. After one launch the stored values are verified, then the
    forward/backward shared-memory byte counts exposed by the kernel hooks are
    compared with sizes derived from the tile dimensions.

    Args:
        test: Test case this function is registered on; used for its
            ``assertEqual`` method.
        device: Device used for the allocation, the launch and the module load.
    """
    DIM_M = 32
    DIM_N = 32

    BLOCK_DIM = 256

    @wp.kernel(module="unique")
    def compute(out: wp.array2d(dtype=float)):
        a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
        b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0

        c = a + b
        wp.tile_store(out, c)

    out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)

    wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)

    # check output: ones + ones * 2.0
    assert_np_equal(out.numpy(), np.ones((DIM_M, DIM_N)) * 3.0)

    # check required shared memory: 4 bytes per element, times 2
    # (presumably one allocation per shared tile); backward is expected to need twice as much
    expected_forward_bytes = DIM_M * DIM_N * 4 * 2
    expected_backward_bytes = expected_forward_bytes * 2

    # check shared memory for kernel on the device
    module_exec = compute.module.load(device, BLOCK_DIM)
    hooks = module_exec.get_kernel_hooks(compute)

    # unittest assertions survive ``python -O`` and print both values on failure
    test.assertEqual(hooks.forward_smem_bytes, expected_forward_bytes)
    test.assertEqual(hooks.backward_smem_bytes, expected_backward_bytes)
56
+
57
+
58
+ # checks that we can configure shared memory > 48kb default
59
def test_tile_shared_mem_large(test, device):
    """Check that a kernel needing 64 KiB of shared memory in the forward pass runs.

    Uses 64x128 float tiles so that the expected forward size is exactly
    ``2**16`` bytes. Backward code generation is disabled for this kernel, so
    the expected backward size is zero.

    Args:
        test: Test case this function is registered on; used for its
            ``assertEqual`` method.
        device: Device used for the allocation, the launch and the module load.
    """
    # set dimensions that require 64kb for the forward kernel
    DIM_M = 64
    DIM_N = 128

    BLOCK_DIM = 256

    # we disable backward kernel gen since 128k is not supported on most architectures
    @wp.kernel(enable_backward=False, module="unique")
    def compute(out: wp.array2d(dtype=float)):
        a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
        b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0

        c = a + b
        wp.tile_store(out, c)

    out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)

    wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)

    # check output: ones + ones * 2.0
    assert_np_equal(out.numpy(), np.ones((DIM_M, DIM_N)) * 3.0)

    # check required shared memory
    expected_forward_bytes = DIM_M * DIM_N * 4 * 2
    expected_backward_bytes = 0

    # sanity check on the test's own constants: the forward size must be exactly 64 KiB
    test.assertEqual(expected_forward_bytes, 2**16)

    # check shared memory for kernel on the device
    module_exec = compute.module.load(device, BLOCK_DIM)
    hooks = module_exec.get_kernel_hooks(compute)

    # unittest assertions survive ``python -O`` and print both values on failure
    test.assertEqual(hooks.forward_smem_bytes, expected_forward_bytes)
    test.assertEqual(hooks.backward_smem_bytes, expected_backward_bytes)
94
+
95
+
96
+ # checks that we can configure dynamic shared memory during graph capture
97
def test_tile_shared_mem_graph(test, device):
    """Check the shared-memory kernel when its launch is recorded in a graph.

    The module is loaded up front, the tiled launch is recorded inside
    ``wp.ScopedCapture`` (with ``force_module_load=False``) and the captured
    graph is replayed with ``wp.capture_launch``. Output values and the
    shared-memory byte counts reported by the kernel hooks are then verified.

    Args:
        test: Test case this function is registered on; used for its
            ``assertEqual`` method.
        device: Device used for the allocation, the capture and the module load.
    """
    DIM_M = 32
    DIM_N = 32

    BLOCK_DIM = 256

    @wp.kernel(module="unique")
    def compute(out: wp.array2d(dtype=float)):
        a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
        b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0

        c = a + b
        wp.tile_store(out, c)

    out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)

    # preload the unique module so that nothing has to be loaded during capture
    wp.load_module(compute.module, device=device, block_dim=BLOCK_DIM)

    with wp.ScopedCapture(device, force_module_load=False) as capture:
        wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)

    wp.capture_launch(capture.graph)

    # check output: ones + ones * 2.0
    assert_np_equal(out.numpy(), np.ones((DIM_M, DIM_N)) * 3.0)

    # check required shared memory
    expected_forward_bytes = DIM_M * DIM_N * 4 * 2
    expected_backward_bytes = expected_forward_bytes * 2

    # check shared memory for kernel on the device
    module_exec = compute.module.load(device, BLOCK_DIM)
    hooks = module_exec.get_kernel_hooks(compute)

    # unittest assertions survive ``python -O`` and print both values on failure
    test.assertEqual(hooks.forward_smem_bytes, expected_forward_bytes)
    test.assertEqual(hooks.backward_smem_bytes, expected_backward_bytes)
134
+
135
+
136
+ # checks that stack allocations work for user functions
137
def test_tile_shared_mem_func(test, device):
    """Check shared-memory accounting when tiles are allocated inside user functions.

    The kernel calls two ``@wp.func`` helpers in sequence, one working on
    16x16 shared tiles and one on 64x64 shared tiles, and stores the result of
    the larger one. The test verifies the stored values and that the reported
    shared-memory requirement matches the larger helper only.

    Args:
        test: Test case this function is registered on; used for its
            ``assertEqual`` method.
        device: Device used for the allocation, the launch and the module load.
    """
    DIM_M = 64
    DIM_N = 64

    SMALL_DIM_M = DIM_M // 4
    SMALL_DIM_N = DIM_N // 4

    BLOCK_DIM = 256

    @wp.func
    def add_tile_small():
        a = wp.tile_ones(shape=(SMALL_DIM_M, SMALL_DIM_N), dtype=float, storage="shared")
        b = wp.tile_ones(shape=(SMALL_DIM_M, SMALL_DIM_N), dtype=float, storage="shared") * 2.0

        return a + b

    @wp.func
    def add_tile_big():
        a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
        b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0

        return a + b

    @wp.kernel(module="unique")
    def compute(out: wp.array2d(dtype=float)):
        # the small result is intentionally unused; the call only exercises the allocation
        s = add_tile_small()
        b = add_tile_big()

        wp.tile_store(out, b)

    out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)

    wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)

    # check output: add_tile_big() returns ones + ones * 2.0
    assert_np_equal(out.numpy(), np.ones((DIM_M, DIM_N)) * 3.0)

    # check shared memory for kernel on the device
    module_exec = compute.module.load(device, BLOCK_DIM)
    hooks = module_exec.get_kernel_hooks(compute)

    # ensure that total required dynamic shared is the larger of the two tiles
    expected_required_shared = DIM_M * DIM_N * 4 * 2

    # unittest assertions survive ``python -O`` and print both values on failure
    test.assertEqual(hooks.forward_smem_bytes, expected_required_shared)
    test.assertEqual(hooks.backward_smem_bytes, expected_required_shared * 2)
180
+
181
+
182
def round_up(a, b):
    """Return the smallest multiple of ``b`` that is not below ``a``.

    Intended for non-negative integer ``a`` and positive integer ``b``.
    """
    # ceiling division: number of b-sized chunks needed to cover a
    chunk_count = (a + b - 1) // b
    return chunk_count * b
184
+
185
+
186
+ # checks that tile sizes that are not 16-byte aligned work
187
def test_tile_shared_non_aligned(test, device):
    """Check shared tiles whose byte size is not a multiple of 16.

    A 1x3 float tile occupies 12 bytes. The kernel calls a ``@wp.func`` that
    allocates two such shared tiles 4096 times in a loop, then allocates one
    more shared tile and stores it. The test verifies the stored values and
    the shared-memory byte counts reported by the kernel hooks.

    Args:
        test: Test case this function is registered on; used for its
            ``assertEqual`` method.
        device: Device used for the allocation, the launch and the module load.
    """
    # Tile size = 4 (float) * 1 * 3 = 12B % 16 != 0
    DIM_M = 1
    DIM_N = 3

    BLOCK_DIM = 256

    @wp.func
    def foo():
        a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0
        b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 3.0
        return a + b

    @wp.kernel(module="unique")
    def compute(out: wp.array2d(dtype=float)):
        # This tests the logic in the stack allocator, which should increment and
        # decrement the stack pointer each time foo() is called.
        # Failing to do so correctly will make b out of bounds and corrupt the results.
        for _ in range(4096):
            foo()
        b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
        wp.tile_store(out, b)

    out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)

    wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)

    # check output: the stored tile is all ones
    assert_np_equal(out.numpy(), np.ones((DIM_M, DIM_N), dtype=float))

    # check shared memory for kernel on the device
    module_exec = compute.module.load(device, BLOCK_DIM)
    hooks = module_exec.get_kernel_hooks(compute)

    # three tile allocations, each padded to a 16-byte boundary
    # (presumably the two tiles in foo() plus the one in compute())
    expected_required_shared = 3 * round_up(DIM_M * DIM_N * 4, 16)

    # unittest assertions survive ``python -O`` and print both values on failure
    test.assertEqual(hooks.forward_smem_bytes, expected_required_shared)
    test.assertEqual(hooks.backward_smem_bytes, expected_required_shared * 2)
225
+
226
+
227
def test_tile_shared_vec_accumulation(test, device):
    """Check ``+=`` accumulation into a tile, including the gradients it produces.

    Each of the three blocks loads ``BLOCK_DIM`` random indices in ``[0, 3)``,
    every thread adds the components of the basis vector selected by its index
    into a 1x3 tile, and the tile is stored as one row of ``output``. The
    forward result and the gradient of ``vecs`` (after back-propagating ones
    through the tape) are compared with counts computed in NumPy.
    """
    BLOCK_DIM = 256

    @wp.kernel(module="unique")
    def compute(indices: wp.array(dtype=int), vecs: wp.array(dtype=wp.vec3), output: wp.array2d(dtype=float)):
        i, j = wp.tid()

        idx_tile = wp.tile_load(indices, shape=BLOCK_DIM, offset=i * BLOCK_DIM)
        idx = idx_tile[j]

        s = wp.tile_zeros(shape=(1, 3), dtype=float)

        s[0, 0] += vecs[idx].x
        s[0, 1] += vecs[idx].y
        s[0, 2] += vecs[idx].z

        wp.tile_store(output, s, offset=(i, 0))

    N = BLOCK_DIM * 3

    # rows are the unit vectors e0, e1, e2
    vecs = wp.array(np.eye(3, dtype=np.float32), dtype=wp.vec3, requires_grad=True, device=device)

    # fixed seed keeps the index pattern reproducible
    indices_np = np.random.default_rng(42).integers(0, 3, size=N)
    indices = wp.array(indices_np, dtype=int, requires_grad=True, device=device)

    output = wp.zeros(shape=(3, 3), dtype=float, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(compute, dim=3, inputs=[indices, vecs, output], block_dim=BLOCK_DIM, device=device)

    # seed the backward pass with a gradient of one for every output entry
    output.grad = wp.ones_like(output)
    tape.backward()

    # expected gradient: row k holds how often index k occurs overall, in all three components
    total_counts = [np.count_nonzero(indices_np == k) for k in range(3)]
    true_grads = np.array([[count, count, count] for count in total_counts])

    # expected output: row r holds the per-index counts within block r
    per_block = indices_np.reshape((3, BLOCK_DIM))

    def compute_row(idx):
        block = per_block[idx, :]
        return np.array([np.count_nonzero(block == k) for k in range(3)])

    true_vecs = np.stack([compute_row(row) for row in range(3)])

    assert_np_equal(output.numpy(), true_vecs)
    assert_np_equal(vecs.grad.numpy(), true_grads)
286
+
287
+
288
def test_tile_shared_simple_reduction_add(test, device):
    """Check in-place ``+=`` on tile elements with a halving-stride sum.

    Every block loads ``BLOCK_DIM`` consecutive values of ``x`` into a tile,
    repeatedly adds element ``j + k`` onto element ``j`` for ``j < k`` while
    ``k`` is halved down to 1, and stores element 0 into ``y``. The sum of
    ``y`` must equal the sum of the input.
    """
    BLOCK_DIM = 256

    @wp.kernel(module="unique")
    def compute(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        i, j = wp.tid()

        t = wp.tile_load(x, shape=BLOCK_DIM, offset=BLOCK_DIM * i)

        k = BLOCK_DIM // 2
        while k > 0:
            if j < k:
                t[j] += t[j + k]
            k //= 2

        wp.tile_store(y, wp.tile_view(t, offset=(0,), shape=(1,)), i)

    num_blocks = 4

    # input is 0, 1, ..., N-1 so every block receives distinct values
    host_x = np.arange(BLOCK_DIM * num_blocks, dtype=np.float32)
    x = wp.array(host_x, dtype=float, device=device)
    y = wp.zeros(num_blocks, dtype=float, device=device)

    wp.launch_tiled(compute, dim=num_blocks, inputs=[x], outputs=[y], block_dim=BLOCK_DIM, device=device)

    # the per-block partial sums must add up to the total of the input
    assert_np_equal(y.numpy().sum(), host_x.sum())
313
+
314
+
315
def test_tile_shared_simple_reduction_sub(test, device):
    """Check in-place ``-=`` on tile elements with a halving-stride pass.

    Same structure as the additive variant, but element ``j + k`` is
    subtracted from element ``j``. For the ``arange`` input the first pass
    leaves a constant difference in every surviving element, so later passes
    cancel and each block's element 0 is expected to be zero.
    """
    BLOCK_DIM = 256

    @wp.kernel(module="unique")
    def compute(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        i, j = wp.tid()

        t = wp.tile_load(x, shape=BLOCK_DIM, offset=BLOCK_DIM * i)

        k = BLOCK_DIM // 2
        while k > 0:
            if j < k:
                t[j] -= t[j + k]
            k //= 2

        wp.tile_store(y, wp.tile_view(t, offset=(0,), shape=(1,)), i)

    num_blocks = 4

    # input is 0, 1, ..., N-1
    host_x = np.arange(BLOCK_DIM * num_blocks, dtype=np.float32)
    x = wp.array(host_x, dtype=float, device=device)
    y = wp.zeros(num_blocks, dtype=float, device=device)

    wp.launch_tiled(compute, dim=num_blocks, inputs=[x], outputs=[y], block_dim=BLOCK_DIM, device=device)

    # all per-block results are expected to cancel to zero
    assert_np_equal(y.numpy().sum(), 0.0)
340
+
341
+
342
# device list handed to every add_function_test() call in this module
devices = get_cuda_test_devices()


class TestTileSharedMemory(unittest.TestCase):
    """Test case that this module's ``add_function_test`` calls register the test functions on."""
347
+
348
+
349
# Registration of the module's test functions on TestTileSharedMemory.
# The first two pass check_output=False; the others rely on add_function_test's default.
add_function_test(
    TestTileSharedMemory,
    "test_tile_shared_mem_size",
    test_tile_shared_mem_size,
    devices=devices,
    check_output=False,
)
add_function_test(
    TestTileSharedMemory,
    "test_tile_shared_mem_large",
    test_tile_shared_mem_large,
    devices=devices,
    check_output=False,
)
add_function_test(
    TestTileSharedMemory,
    "test_tile_shared_mem_graph",
    test_tile_shared_mem_graph,
    devices=devices,
)
add_function_test(
    TestTileSharedMemory,
    "test_tile_shared_mem_func",
    test_tile_shared_mem_func,
    devices=devices,
)
add_function_test(
    TestTileSharedMemory,
    "test_tile_shared_non_aligned",
    test_tile_shared_non_aligned,
    devices=devices,
)
add_function_test(
    TestTileSharedMemory,
    "test_tile_shared_vec_accumulation",
    test_tile_shared_vec_accumulation,
    devices=devices,
)
add_function_test(
    TestTileSharedMemory,
    "test_tile_shared_simple_reduction_add",
    test_tile_shared_simple_reduction_add,
    devices=devices,
)
add_function_test(
    TestTileSharedMemory,
    "test_tile_shared_simple_reduction_sub",
    test_tile_shared_simple_reduction_sub,
    devices=devices,
)

if __name__ == "__main__":
    # when run as a script: clear the kernel cache first, then stop at the first failure
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)
@@ -0,0 +1,121 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
def create_sort_kernel(KEY_TYPE, MAX_SORT_LENGTH):
    """Build a tile-sort kernel specialised for one key type and one tile length.

    Args:
        KEY_TYPE: Element type of the key arrays the kernel accepts.
        MAX_SORT_LENGTH: Number of elements loaded into each tile and sorted.

    Returns:
        The ``tile_sort_kernel`` kernel closed over the two arguments.
    """

    @wp.kernel
    def tile_sort_kernel(
        input_keys: wp.array(dtype=KEY_TYPE),
        input_values: wp.array(dtype=wp.int32),
        output_keys: wp.array(dtype=KEY_TYPE),
        output_values: wp.array(dtype=wp.int32),
    ):
        # Bring keys and values into shared-storage tiles
        key_tile = wp.tile_load(input_keys, shape=MAX_SORT_LENGTH, storage="shared")
        value_tile = wp.tile_load(input_values, shape=MAX_SORT_LENGTH, storage="shared")

        # Sort both tiles in place
        wp.tile_sort(key_tile, value_tile)

        # Write the sorted tiles back to the output arrays
        wp.tile_store(output_keys, key_tile)
        wp.tile_store(output_values, value_tile)

    return tile_sort_kernel
44
+
45
+
46
def test_tile_sort(test, device):
    """Check ``wp.tile_sort`` key/value sorting against a NumPy reference.

    Kernels are created for ``int`` and ``float`` keys with lengths ``2**i + 1`` for
    ``i`` in ``0..10``. Each kernel is launched with block sizes ``2**j`` for ``j`` in
    ``5..9`` and its output is compared with ``np.argsort`` applied to the same keys.

    Args:
        test: Test-case object providing ``assertTrue``.
        device: Device on which arrays are allocated and kernels are launched.
    """
    # Forward-declare kernels for more efficient compilation
    lengths = [2**i + 1 for i in range(11)]
    kernels = {
        (dtype, length): create_sort_kernel(dtype, length) for dtype in (int, float) for length in lengths
    }

    for (dtype, length), kernel in kernels.items():
        # The host-side inputs and the NumPy reference depend only on (dtype, length), so they
        # are built once per kernel rather than once per block size. The generator is still
        # re-seeded for every kernel, which yields exactly the data a fresh seeded generator
        # would produce for each individual launch.
        rng = np.random.default_rng(42)

        if dtype is int:
            np_keys = rng.choice(1000000000, size=length, replace=False)
        else:  # dtype is float
            np_keys = rng.uniform(0, 1000000000, size=length).astype(dtype)

        # Iota indexer carried along as the sort payload
        np_values = np.arange(length)

        # Sort using NumPy for validation
        sorted_indices = np.argsort(np_keys)
        np_sorted_keys = np_keys[sorted_indices]
        np_sorted_values = np_values[sorted_indices]

        for j in range(5, 10):
            TILE_DIM = 2**j

            # Fresh device arrays for every launch so no result leaks between block sizes
            input_keys = wp.array(np_keys, dtype=dtype, device=device)
            input_values = wp.array(np_values, dtype=int, device=device)
            output_keys = wp.zeros_like(input_keys, device=device)
            output_values = wp.zeros_like(input_values, device=device)

            # Execute sorting kernel
            wp.launch_tiled(
                kernel,
                dim=1,
                inputs=[input_keys, input_values, output_keys, output_values],
                block_dim=TILE_DIM,
                device=device,
            )
            wp.synchronize()

            if dtype is int:
                keys_match = np.array_equal(output_keys.numpy(), np_sorted_keys)
            else:  # dtype is float
                keys_match = np.allclose(output_keys.numpy(), np_sorted_keys, atol=1e-6)  # Use tolerance for floats

            values_match = np.array_equal(output_values.numpy(), np_sorted_values)

            # Dump both sides before asserting so a failure shows the offending configuration
            if not keys_match or not values_match:
                print(f"Test failed for dtype={dtype}, TILE_DIM={TILE_DIM}, length={length}")
                print("")
                print(output_keys.numpy())
                print(np_sorted_keys)
                print("")
                print(output_values.numpy())
                print(np_sorted_values)
                print("")

            # Validate results
            test.assertTrue(keys_match, f"Key sorting mismatch for dtype={dtype}!")
            test.assertTrue(values_match, f"Value sorting mismatch for dtype={dtype}!")
108
+
109
+
110
# Devices the sort test is registered for.
devices = get_test_devices()


class TestTileSort(unittest.TestCase):
    """Holder class for the tile-sort test attached via ``add_function_test`` below."""


add_function_test(
    TestTileSort,
    "test_tile_sort",
    test_tile_sort,
    devices=devices,
)

if __name__ == "__main__":
    # Clear Warp's kernel cache before running when the file is executed as a script.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)
@@ -0,0 +1,173 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
# Tile extents used by the kernels in this file: 2D tiles are (TILE_M, TILE_N)
# and 3D tiles are (TILE_M, TILE_N, TILE_O).
TILE_M = 16
TILE_N = 32
TILE_O = 8

# NOTE(review): TILE_DIM is not referenced by the code visible in this file
# (the launches pass block_dim=32) — confirm whether it is still needed.
TILE_DIM = 64
27
+
28
+
29
@wp.kernel
def test_tile_view_kernel(src: wp.array2d(dtype=float), dst: wp.array2d(dtype=float)):
    # Copies src to dst by storing one row view of the loaded tile at a time.

    # read the complete (TILE_M, TILE_N) source into a tile
    src_tile = wp.tile_load(src, shape=(TILE_M, TILE_N))

    # walk over the rows of the tile
    for row_idx in range(TILE_M):
        # take a view of the current row and write it to the matching row of dst
        src_row = src_tile[row_idx]
        wp.tile_store(dst[row_idx], src_row)
39
+
40
+
41
def test_tile_view(test, device):
    """Launch ``test_tile_view_kernel`` and check the copied output and the source gradient.

    Args:
        test: Test-case object (not referenced in the body).
        device: Device used for array allocation and the kernel launch.
    """
    shape = (TILE_M, TILE_N)
    rng = np.random.default_rng(42)

    src = wp.array(rng.random(shape, dtype=np.float32), requires_grad=True, device=device)
    dst = wp.array(np.zeros(shape, dtype=np.float32), requires_grad=True, device=device)

    # Record the launch so it can be differentiated afterwards.
    tape = wp.Tape()
    with tape:
        wp.launch_tiled(test_tile_view_kernel, dim=[1], inputs=[src, dst], block_dim=32, device=device)

    # Forward pass: the destination must equal the source.
    assert_np_equal(dst.numpy(), src.numpy())

    # Backward pass: seed the output gradient with ones and propagate.
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    expected_grad = np.ones_like(src.numpy())
    assert_np_equal(src.grad.numpy(), expected_grad)
55
+
56
+
57
@wp.kernel
def test_tile_assign_1d_kernel(src: wp.array2d(dtype=float), dst: wp.array2d(dtype=float)):
    # Copies src into a zero-initialised tile row by row, then stores that tile to dst.

    # read the complete (TILE_M, TILE_N) source into a tile
    src_tile = wp.tile_load(src, shape=(TILE_M, TILE_N))
    dst_tile = wp.tile_zeros(dtype=float, shape=(TILE_M, TILE_N))

    # walk over the rows of both tiles
    for row_idx in range(int(TILE_M)):
        # row views of the source and destination tiles
        src_row = src_tile[row_idx]
        dst_row = dst_tile[row_idx]

        # write the source row into the destination row
        wp.tile_assign(dst_row, src_row)

    wp.tile_store(dst, dst_tile)
73
+
74
+
75
def test_tile_assign_1d(test, device):
    """Launch ``test_tile_assign_1d_kernel`` and check the copied output and the source gradient.

    Args:
        test: Test-case object (not referenced in the body).
        device: Device used for array allocation and the kernel launch.
    """
    shape = (TILE_M, TILE_N)
    rng = np.random.default_rng(42)

    src = wp.array(rng.random(shape, dtype=np.float32), requires_grad=True, device=device)
    dst = wp.array(np.zeros(shape, dtype=np.float32), requires_grad=True, device=device)

    # Record the launch so it can be differentiated afterwards.
    tape = wp.Tape()
    with tape:
        wp.launch_tiled(test_tile_assign_1d_kernel, dim=[1], inputs=[src, dst], block_dim=32, device=device)

    # Forward pass: the destination must equal the source.
    assert_np_equal(dst.numpy(), src.numpy())

    # Backward pass: seed the output gradient with ones and propagate.
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    expected_grad = np.ones_like(src.numpy())
    assert_np_equal(src.grad.numpy(), expected_grad)
89
+
90
+
91
@wp.kernel
def test_tile_assign_2d_kernel(src: wp.array3d(dtype=float), dst: wp.array3d(dtype=float)):
    # Copies src into a zero-initialised 3D tile one leading-index slice at a time,
    # then stores that tile to dst.

    # read the complete (TILE_M, TILE_N, TILE_O) source into a tile
    src_tile = wp.tile_load(src, shape=(TILE_M, TILE_N, TILE_O))
    dst_tile = wp.tile_zeros(dtype=float, shape=(TILE_M, TILE_N, TILE_O))

    # walk over the leading dimension of both tiles
    for slice_idx in range(TILE_M):
        # slice views of the source and destination tiles
        src_slice = src_tile[slice_idx]
        dst_slice = dst_tile[slice_idx]

        # write the source slice into the destination slice
        wp.tile_assign(dst_slice, src_slice)

    wp.tile_store(dst, dst_tile)
107
+
108
+
109
def test_tile_assign_2d(test, device):
    """Launch ``test_tile_assign_2d_kernel`` and check the copied output and the source gradient.

    Args:
        test: Test-case object (not referenced in the body).
        device: Device used for array allocation and the kernel launch.
    """
    shape = (TILE_M, TILE_N, TILE_O)
    rng = np.random.default_rng(42)

    src = wp.array(rng.random(shape, dtype=np.float32), requires_grad=True, device=device)
    dst = wp.array(np.zeros(shape, dtype=np.float32), requires_grad=True, device=device)

    # Record the launch so it can be differentiated afterwards.
    tape = wp.Tape()
    with tape:
        wp.launch_tiled(test_tile_assign_2d_kernel, dim=[1], inputs=[src, dst], block_dim=32, device=device)

    # Forward pass: the destination must equal the source.
    assert_np_equal(dst.numpy(), src.numpy())

    # Backward pass: seed the output gradient with ones and propagate.
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    expected_grad = np.ones_like(src.numpy())
    assert_np_equal(src.grad.numpy(), expected_grad)
123
+
124
+
125
@wp.kernel
def test_tile_view_offset_kernel(src: wp.array2d(dtype=float), dst: wp.array2d(dtype=float)):
    # Copies src into a zero-initialised tile in blocks of four rows using offset views,
    # then stores that tile to dst.

    # read the complete (TILE_M, TILE_N) source into a tile
    src_tile = wp.tile_load(src, shape=(TILE_M, TILE_N))
    dst_tile = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=float)

    # walk over the tile in groups of four rows
    for block_idx in range(TILE_M // 4):
        # view covering four full rows of the source tile, starting at row block_idx * 4
        row_block = wp.tile_view(src_tile, offset=(block_idx * 4, 0), shape=(4, TILE_N))

        # write that block into the destination tile at the same row offset
        wp.tile_assign(dst_tile, row_block, offset=(block_idx * 4, 0))

    wp.tile_store(dst, dst_tile)
140
+
141
+
142
def test_tile_view_offset(test, device):
    """Launch ``test_tile_view_offset_kernel`` and check the copied output and the source gradient.

    Args:
        test: Test-case object (not referenced in the body).
        device: Device used for array allocation and the kernel launch.
    """
    shape = (TILE_M, TILE_N)
    rng = np.random.default_rng(42)

    src = wp.array(rng.random(shape, dtype=np.float32), requires_grad=True, device=device)
    dst = wp.array(np.zeros(shape, dtype=np.float32), requires_grad=True, device=device)

    # Record the launch so it can be differentiated afterwards.
    tape = wp.Tape()
    with tape:
        wp.launch_tiled(test_tile_view_offset_kernel, dim=[1], inputs=[src, dst], block_dim=32, device=device)

    # Forward pass: the destination must equal the source.
    assert_np_equal(dst.numpy(), src.numpy())

    # Backward pass: seed the output gradient with ones and propagate.
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    expected_grad = np.ones_like(src.numpy())
    assert_np_equal(src.grad.numpy(), expected_grad)
156
+
157
+
158
# Devices the tile-view tests are registered for.
devices = get_test_devices()


class TestTileView(unittest.TestCase):
    """Holder class for the tile view/assign tests attached via ``add_function_test`` below."""


# Each entry is (test name, test function); registration order matches the original file.
for _test_name, _test_func in (
    ("test_tile_view", test_tile_view),
    ("test_tile_view_offset", test_tile_view_offset),
    ("test_tile_assign_1d", test_tile_assign_1d),
    ("test_tile_assign_2d", test_tile_assign_2d),
):
    add_function_test(TestTileView, _test_name, _test_func, devices=devices)

# Do not leave the loop temporaries behind in the module namespace.
del _test_name, _test_func


if __name__ == "__main__":
    # Clear Warp's kernel cache before running when the file is executed as a script.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)