warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,447 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
# Tile extents for tile_grouped_sum_kernel, which loads a (TILE_M, TILE_N) tile per batch entry.
# Declared via wp.constant so they can be referenced from inside Warp kernels.
TILE_M = wp.constant(8)
TILE_N = wp.constant(4)
# NOTE(review): TILE_K is not referenced by any code visible in this file.
TILE_K = wp.constant(8)

# num threads per-tile
# Also used as the 1D tile length in the reduction kernels and as block_dim for the launches.
TILE_DIM = 64
29
+
30
+
31
@wp.kernel
def tile_sum_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
    # Accumulate 0.5 * sum(row i of input) chunk by chunk and store it in output[i].
    # Launched via wp.launch_tiled with one tile index per output entry.

    # output tile index
    i = wp.tid()

    n = input.shape[1]
    # number of TILE_DIM-wide chunks in the row
    # NOTE(review): assumes n is a multiple of TILE_DIM; a trailing remainder would not be summed
    count = int(n / TILE_DIM)

    # single-element accumulator tile
    s = wp.tile_zeros(shape=1, dtype=float)

    for j in range(count):
        # load chunk j of row i and add half of its sum to the accumulator
        a = wp.tile_load(input[i], shape=TILE_DIM, offset=j * TILE_DIM)
        s += wp.tile_sum(a) * 0.5

    wp.tile_store(output, s, offset=i)
46
+
47
+
48
def test_tile_reduce_sum(test, device):
    """Check tile_sum_kernel's forward result and gradient against NumPy.

    Every row holds three TILE_DIM-wide chunks; the kernel writes half of each row's sum,
    so the gradient of every input element is 0.5.
    """
    batch_count = 56
    row_len = TILE_DIM * 3

    generator = np.random.default_rng(42)
    host_values = generator.random((batch_count, row_len), dtype=np.float32)

    values_wp = wp.array(host_values, requires_grad=True, device=device)
    sums_wp = wp.zeros(batch_count, requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(
            tile_sum_kernel,
            dim=[batch_count],
            inputs=[values_wp, sums_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward: each output is half of its row sum
    computed = sums_wp.numpy()
    for row, got in zip(host_values, computed):
        test.assertAlmostEqual(got, np.sum(row) * 0.5, places=4)

    # backward: seed the output gradient with ones and propagate
    sums_wp.grad.fill_(1.0)
    tape.backward()

    assert_np_equal(values_wp.grad.numpy(), np.full_like(host_values, 0.5), tol=1.0e-4)
74
+
75
+
76
@wp.kernel
def tile_min_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
    # Store the minimum of the first TILE_DIM entries of row i in output[i].

    # output tile index
    i = wp.tid()

    a = wp.tile_load(input[i], shape=TILE_DIM)
    m = wp.tile_min(a)

    wp.tile_store(output, m, offset=i)
85
+
86
+
87
def test_tile_reduce_min(test, device):
    """Compare each row minimum produced by tile_min_kernel with np.min."""
    rows, cols = 56, TILE_DIM

    data = np.random.default_rng(42).random((rows, cols), dtype=np.float32)

    data_wp = wp.array(data, requires_grad=True, device=device)
    result_wp = wp.zeros(rows, requires_grad=True, device=device)

    with wp.Tape():
        wp.launch_tiled(
            tile_min_kernel,
            dim=[rows],
            inputs=[data_wp, result_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    result = result_wp.numpy()
    for row_index, row in enumerate(data):
        test.assertAlmostEqual(result[row_index], np.min(row), places=4)
107
+
108
+
109
@wp.kernel
def tile_max_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
    # Store the maximum of the first TILE_DIM entries of row i in output[i].

    # output tile index
    i = wp.tid()

    a = wp.tile_load(input[i], shape=TILE_DIM)
    m = wp.tile_max(a)

    wp.tile_store(output, m, offset=i)
118
+
119
+
120
def test_tile_reduce_max(test, device):
    """Compare each row maximum produced by tile_max_kernel with np.max."""
    rows, cols = 56, TILE_DIM

    data = np.random.default_rng(42).random((rows, cols), dtype=np.float32)

    data_wp = wp.array(data, requires_grad=True, device=device)
    result_wp = wp.zeros(rows, requires_grad=True, device=device)

    with wp.Tape():
        wp.launch_tiled(
            tile_max_kernel,
            dim=[rows],
            inputs=[data_wp, result_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    result = result_wp.numpy()
    for row_index, row in enumerate(data):
        test.assertAlmostEqual(result[row_index], np.max(row), places=4)
140
+
141
+
142
@wp.kernel
def tile_reduce_custom_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
    # Reduce the first TILE_DIM entries of row i with the operator wp.mul (their product)
    # via the generic wp.tile_reduce, and store the result in output[i].

    # output tile index
    i = wp.tid()

    a = wp.tile_load(input[i], shape=TILE_DIM)
    m = wp.tile_reduce(wp.mul, a)

    wp.tile_store(output, m, offset=i)
151
+
152
+
153
def test_tile_reduce_custom(test, device):
    """Check tile_reduce_custom_kernel (row product via wp.tile_reduce(wp.mul, ...)).

    Inputs are drawn from [0.5, 1.5) so the product of TILE_DIM values stays well within
    float32 range. With inputs in [0, 1) the product of 64 values is typically around
    1e-28, and an absolute-tolerance check (assertAlmostEqual, places=4) would accept any
    result close to zero, including a kernel that returns 0.
    """
    batch_count = 56

    N = TILE_DIM

    rng = np.random.default_rng(42)
    # shift into [0.5, 1.5) so the products are meaningful in float32
    values_np = rng.random((batch_count, N), dtype=np.float32) + np.float32(0.5)

    input_wp = wp.array(values_np, requires_grad=True, device=device)
    output_wp = wp.zeros(batch_count, requires_grad=True, device=device)

    with wp.Tape():
        wp.launch_tiled(
            tile_reduce_custom_kernel,
            dim=[batch_count],
            inputs=[input_wp, output_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    prod_wp = output_wp.numpy()
    # float64 reference; products span several orders of magnitude, so compare relatively
    prod_np = np.prod(values_np.astype(np.float64), axis=1)
    np.testing.assert_allclose(prod_wp, prod_np, rtol=1.0e-4)
177
+
178
+
179
@wp.struct
class KeyValue:
    # A (key, value) pair. In this file, initialize_key_value fills key with the element's
    # column index and value with the element itself, so a reduction over pairs can
    # report where the winning value came from.
    key: wp.int32  # column index within the row
    value: wp.float32  # element value at that column
183
+
184
+
185
@wp.func
def kv_max(a: KeyValue, b: KeyValue) -> KeyValue:
    # Binary reduction operator: return the pair with the larger value.
    # When a.value == b.value the condition is false, so a is returned.
    return wp.where(a.value < b.value, b, a)
188
+
189
+
190
@wp.kernel
def initialize_key_value(values: wp.array2d(dtype=wp.float32), keyvalues: wp.array2d(dtype=KeyValue)):
    # Pack values[batch, idx] into a KeyValue whose key is its column index idx.
    batch, idx = wp.tid()
    keyvalues[batch, idx] = KeyValue(idx, values[batch, idx])
194
+
195
+
196
@wp.kernel(enable_backward=False)
def tile_reduce_custom_struct_kernel(values: wp.array2d(dtype=KeyValue), res: wp.array(dtype=KeyValue)):
    # Reduce row i of KeyValue pairs with kv_max (a user-defined struct reduction) and
    # store the winning pair in res[i]. Backward code generation is disabled.

    # output tile index
    i = wp.tid()

    # load row i as a (1, TILE_DIM) tile
    t = wp.tile_load(values, shape=(1, TILE_DIM), offset=(i, 0))

    max_el = wp.tile_reduce(kv_max, t)
    wp.tile_store(res, max_el, offset=i)
205
+
206
+
207
def test_tile_reduce_custom_struct(test, device):
    """Check that reducing KeyValue pairs with kv_max yields each row's argmax."""
    batch_count = 56
    row_len = TILE_DIM

    samples = np.random.default_rng(42).random((batch_count, row_len), dtype=np.float32)

    values_wp = wp.array(samples, dtype=wp.float32, device=device)
    pairs_wp = wp.empty(values_wp.shape, dtype=KeyValue, device=device)

    # tag every element with its column index
    wp.launch(initialize_key_value, dim=[batch_count, row_len], inputs=[values_wp], outputs=[pairs_wp], device=device)

    winners_wp = wp.empty(batch_count, dtype=KeyValue, device=device)

    wp.launch_tiled(
        tile_reduce_custom_struct_kernel,
        dim=[batch_count],
        inputs=[pairs_wp],
        outputs=[winners_wp],
        block_dim=TILE_DIM,
        device=device,
    )

    # the key of each winning pair must be the column of that row's maximum
    winner_keys = np.array([key for key, _value in winners_wp.numpy()])
    assert_np_equal(winner_keys, np.argmax(samples, axis=1))
235
+
236
+
237
@wp.kernel
def tile_grouped_sum_kernel(input: wp.array3d(dtype=float), output: wp.array(dtype=float)):
    # Load a 2D (TILE_M, TILE_N) tile from slice input[i], sum all of its elements,
    # and store half of the total in output[i].

    # output tile index
    i = wp.tid()

    a = wp.tile_load(input[i], shape=(TILE_M, TILE_N))
    s = wp.tile_sum(a) * 0.5

    wp.tile_store(output, s, offset=i)
246
+
247
+
248
def test_tile_reduce_grouped_sum(test, device):
    """Check tile_grouped_sum_kernel (2D tile sum) and its gradient against NumPy.

    Each batch entry is a (TILE_M, TILE_N) slice; the kernel writes half of its sum,
    so the gradient of every input element is 0.5.
    """
    batch_count = 56

    M = TILE_M
    N = TILE_N

    rng = np.random.default_rng(42)
    values_np = rng.random((batch_count, M, N), dtype=np.float32)

    input_wp = wp.array(values_np, requires_grad=True, device=device)
    output_wp = wp.zeros(batch_count, requires_grad=True, device=device)

    with wp.Tape() as tape:
        # the grouped kernel takes the 3D input; tile_sum_kernel expects a 2D array
        wp.launch_tiled(
            tile_grouped_sum_kernel,
            dim=[batch_count],
            inputs=[input_wp, output_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    sum_wp = output_wp.numpy()
    for i in range(batch_count):
        sum_np = np.sum(values_np[i]) * 0.5
        test.assertAlmostEqual(sum_wp[i], sum_np, places=4)

    # backward: seed the output gradient with ones and propagate
    output_wp.grad.fill_(1.0)

    tape.backward()

    assert_np_equal(input_wp.grad.numpy(), np.ones_like(values_np) * 0.5, tol=1.0e-4)
275
+
276
+
277
@wp.kernel
def tile_reduce_simt_kernel(output: wp.array(dtype=int)):
    # Launched with wp.launch (not launch_tiled): each thread contributes its thread
    # index to a block-wide tile, the tile is summed, and the block total is added to
    # output with tile_atomic_add.

    # thread index
    i = wp.tid()

    t = wp.tile(i)  # convert to block wide tile
    s = wp.tile_sum(t)  # sum over block

    # update global sum
    wp.tile_atomic_add(output, s)
287
+
288
+
289
def test_tile_reduce_simt(test, device):
    """Check a block-wide tile_sum + tile_atomic_add issued from a non-tiled launch.

    The grand total over all threads must equal 0 + 1 + ... + (thread_count - 1).
    """
    # thread count deliberately not a multiple of the block size
    thread_count = TILE_DIM * 4 + 5

    total = wp.zeros(shape=1, dtype=int, requires_grad=True, device=device)

    with wp.Tape():
        wp.launch(tile_reduce_simt_kernel, dim=thread_count, inputs=[total], block_dim=TILE_DIM, device=device)

    expected = np.arange(thread_count).sum()
    test.assertEqual(total.numpy()[0], expected)
299
+
300
+
301
@wp.kernel
def tile_untile_kernel(output: wp.array(dtype=int)):
    # Round-trip per-thread values through a block-wide tile: pack i with wp.tile,
    # double every element, unpack with wp.untile and store the result.
    # The calling test expects output[i] == 2 * i.

    # thread index
    i = wp.tid()

    # convert to block wide tile
    t = wp.tile(i) * 2
    s = wp.untile(t)

    output[i] = s
311
+
312
+
313
def test_tile_untile(test, device):
    """Check that tile_untile_kernel writes 2 * thread index for every thread."""
    # thread count deliberately not a multiple of the block size
    thread_count = TILE_DIM * 4 + 5

    result = wp.zeros(shape=thread_count, dtype=int, requires_grad=True, device=device)

    with wp.Tape():
        wp.launch(tile_untile_kernel, dim=thread_count, inputs=[result], block_dim=TILE_DIM, device=device)

    assert_np_equal(result.numpy(), 2 * np.arange(thread_count))
323
+
324
+
325
@wp.kernel
def tile_untile_scalar_kernel(output: wp.array(dtype=int)):
    # Scalar tile/untile round trip: pack i into a block-wide tile, double it, unpack it
    # and store the result in output[i].
    # NOTE(review): the body is identical to tile_untile_kernel; confirm whether a
    # distinct scalar code path was meant to be exercised here.

    # thread index
    i = wp.tid()

    # convert to block wide tile
    t = wp.tile(i) * 2
    s = wp.untile(t)

    output[i] = s
335
+
336
+
337
def test_tile_untile_scalar(test, device):
    """Check that tile_untile_scalar_kernel writes 2 * thread index for every thread."""
    # use an unaligned grid dimension
    N = TILE_DIM * 4 + 5

    output = wp.zeros(shape=N, dtype=int, requires_grad=True, device=device)

    with wp.Tape():
        # exercise the scalar kernel this test is named after (previously tile_untile_kernel)
        wp.launch(tile_untile_scalar_kernel, dim=N, inputs=[output], block_dim=TILE_DIM, device=device)

    assert_np_equal(output.numpy(), np.arange(N) * 2)
347
+
348
+
349
@wp.kernel
def test_untile_vector_kernel(input: wp.array(dtype=wp.vec3), output: wp.array(dtype=wp.vec3)):
    # Round-trip one vec3 per thread through a block-wide tile: halve the input, pack it
    # with wp.tile, unpack it with wp.untile, then double it. The calling test expects
    # output to equal input and the input gradient to be all ones.
    # NOTE(review): despite the test_ prefix this is a Warp kernel, not a test case.
    i = wp.tid()

    v = input[i] * 0.5

    t = wp.tile(v)
    u = wp.untile(t)

    output[i] = u * 2.0
359
+
360
+
361
def test_tile_untile_vector(test, device):
    """Round-trip vec3 values through wp.tile/wp.untile and check values and gradient."""
    count = 16

    src = wp.full(count, wp.vec3(1.0, 2.0, 3.0), requires_grad=True, device=device)
    dst = wp.zeros_like(src, device=device)

    with wp.Tape() as tape:
        wp.launch(test_untile_vector_kernel, dim=count, inputs=[src, dst], block_dim=count, device=device)

    # seed the output gradient with ones and propagate back to src
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    # the kernel halves then doubles, so values and gradients pass through unchanged
    assert_np_equal(dst.numpy(), src.numpy())
    assert_np_equal(src.grad.numpy(), np.ones((count, 3)))
373
+
374
+
375
@wp.kernel
def tile_ones_kernel(out: wp.array(dtype=float)):
    # Sum a 16x16 tile of ones and store the total in out (no offset is given);
    # the calling test expects 256.0 in out[0].
    i = wp.tid()  # not used by this kernel

    t = wp.tile_ones(dtype=float, shape=(16, 16))
    s = wp.tile_sum(t)

    wp.tile_store(out, s)
383
+
384
+
385
def test_tile_ones(test, device):
    """A 16x16 tile of ones built with wp.tile_ones should sum to 256."""
    result = wp.zeros(1, dtype=float, device=device)

    with wp.Tape():
        wp.launch_tiled(tile_ones_kernel, dim=[1], inputs=[result], block_dim=TILE_DIM, device=device)

    total = result.numpy()[0]
    test.assertAlmostEqual(total, 256.0)
392
+
393
+
394
@wp.kernel
def tile_arange_kernel(out: wp.array2d(dtype=int)):
    # Exercise the wp.tile_arange call forms (stop; start, stop; start, stop, step;
    # negative start; negative step) and store each result in its own row of out.
    # The calling test compares the rows with the matching np.arange calls.
    i = wp.tid()  # not used by this kernel

    a = wp.tile_arange(17, dtype=int)
    b = wp.tile_arange(5, 23, dtype=int)
    c = wp.tile_arange(0, 34, 2, dtype=int)
    d = wp.tile_arange(-1, 16, dtype=int)
    e = wp.tile_arange(17, 0, -1, dtype=int)

    # NOTE(review): with np.arange-like semantics b has 18 values (5..22), but
    # test_tile_arange allocates rows of 17 and expects np.arange(5, 22) in row 1.
    # This relies on how tile_store handles a tile longer than the destination row;
    # confirm this is intentional.
    wp.tile_store(out[0], a)
    wp.tile_store(out[1], b)
    wp.tile_store(out[2], c)
    wp.tile_store(out[3], d)
    wp.tile_store(out[4], e)
409
+
410
+
411
def test_tile_arange(test, device):
    """Compare each row written by tile_arange_kernel with the equivalent np.arange."""
    row_len = 17

    result_wp = wp.zeros(shape=(5, row_len), dtype=int, device=device)

    with wp.Tape():
        wp.launch_tiled(tile_arange_kernel, dim=[1], inputs=[result_wp], block_dim=TILE_DIM, device=device)

    expected_rows = (
        np.arange(17),
        np.arange(5, 22),
        np.arange(0, 34, 2),
        np.arange(-1, 16),
        np.arange(17, 0, -1),
    )

    result = result_wp.numpy()
    for row, expected in zip(result, expected_rows):
        assert_np_equal(row, expected)
424
+
425
+
426
devices = get_test_devices()


class TestTileReduce(unittest.TestCase):
    pass


# Register each test function under its own __name__ so the reported test id always
# matches the function that actually runs.
for _test_func in (
    test_tile_reduce_sum,
    test_tile_reduce_min,
    test_tile_reduce_max,
    test_tile_reduce_custom,
    test_tile_reduce_custom_struct,
    test_tile_reduce_grouped_sum,
    test_tile_reduce_simt,
    test_tile_ones,
    test_tile_arange,
    test_tile_untile,
    test_tile_untile_scalar,
    test_tile_untile_vector,
):
    add_function_test(TestTileReduce, _test_func.__name__, _test_func, devices=devices)
# don't leave the loop variable behind as a module attribute
del _test_func

if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)
@@ -0,0 +1,247 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
def test_tile_shared_mem_size(test, device):
    """Check that shared memory is configured to the expected size for a small tiled kernel.

    The kernel makes two ``wp.tile_ones`` calls with ``storage="shared"`` on
    ``DIM_M x DIM_N`` float32 tiles. The forward kernel is expected to need two
    tiles' worth of bytes, and the backward kernel twice that.

    Args:
        test: The ``unittest.TestCase`` instance this test is attached to.
        device: The Warp device to run on.
    """
    DIM_M = 32
    DIM_N = 32

    BLOCK_DIM = 256

    @wp.kernel
    def compute(out: wp.array2d(dtype=float)):
        a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
        b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0

        c = a + b
        wp.tile_store(out, c)

    out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)

    wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)

    # check output: 1.0 + 2.0 everywhere
    assert_np_equal(out.numpy(), np.ones((DIM_M, DIM_N)) * 3.0)

    # two float32 (4-byte) tiles for the forward kernel, twice that for backward
    expected_forward_bytes = DIM_M * DIM_N * 4 * 2
    expected_backward_bytes = expected_forward_bytes * 2

    # look up the shared memory sizes recorded for the kernel on this device
    module_exec = compute.module.load(device, BLOCK_DIM)
    hooks = module_exec.get_kernel_hooks(compute)

    # TestCase assertions (not bare ``assert``) so the checks still run under
    # ``python -O`` and report both values on failure.
    test.assertEqual(hooks.forward_smem_bytes, expected_forward_bytes)
    test.assertEqual(hooks.backward_smem_bytes, expected_backward_bytes)
56
+
57
+
58
def test_tile_shared_mem_large(test, device):
    """Check that a kernel can be configured with more than the 48 KB default of shared memory.

    Two ``DIM_M x DIM_N`` float32 shared tiles total exactly 64 KB for the
    forward kernel. Backward generation is disabled, so no backward shared
    memory is expected.

    Args:
        test: The ``unittest.TestCase`` instance this test is attached to.
        device: The Warp device to run on.
    """
    # set dimensions that require 64kb for the forward kernel
    DIM_M = 64
    DIM_N = 128

    BLOCK_DIM = 256

    # we disable backward kernel gen since 128k is not supported on most architectures
    @wp.kernel(enable_backward=False)
    def compute(out: wp.array2d(dtype=float)):
        a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
        b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0

        c = a + b
        wp.tile_store(out, c)

    out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)

    wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)

    # check output: 1.0 + 2.0 everywhere
    assert_np_equal(out.numpy(), np.ones((DIM_M, DIM_N)) * 3.0)

    # two float32 (4-byte) tiles for the forward kernel; no backward kernel
    expected_forward_bytes = DIM_M * DIM_N * 4 * 2
    expected_backward_bytes = 0

    # sanity-check the test's own constants: exactly 64 KB
    test.assertEqual(expected_forward_bytes, 2**16)

    # look up the shared memory sizes recorded for the kernel on this device
    module_exec = compute.module.load(device, BLOCK_DIM)
    hooks = module_exec.get_kernel_hooks(compute)

    # TestCase assertions (not bare ``assert``) so the checks still run under ``python -O``.
    test.assertEqual(hooks.forward_smem_bytes, expected_forward_bytes)
    test.assertEqual(hooks.backward_smem_bytes, expected_backward_bytes)
94
+
95
+
96
def test_tile_shared_mem_graph(test, device):
    """Check that dynamic shared memory is configured correctly for a launch recorded in a graph capture.

    Args:
        test: The ``unittest.TestCase`` instance this test is attached to.
        device: The Warp device to run on.
    """
    DIM_M = 32
    DIM_N = 32

    BLOCK_DIM = 256

    @wp.kernel
    def compute(out: wp.array2d(dtype=float)):
        a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
        b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0

        c = a + b
        wp.tile_store(out, c)

    out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)

    # Load the module before capturing; the capture below uses force_module_load=False.
    wp.load_module(device=device)

    wp.capture_begin(device, force_module_load=False)
    try:
        wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)
    finally:
        # Always end the capture, so a failing launch cannot leave the device in
        # capture mode and break every test that runs after this one.
        graph = wp.capture_end(device)

    wp.capture_launch(graph)

    # check output: 1.0 + 2.0 everywhere
    assert_np_equal(out.numpy(), np.ones((DIM_M, DIM_N)) * 3.0)

    # two float32 (4-byte) tiles for the forward kernel, twice that for backward
    expected_forward_bytes = DIM_M * DIM_N * 4 * 2
    expected_backward_bytes = expected_forward_bytes * 2

    # look up the shared memory sizes recorded for the kernel on this device
    module_exec = compute.module.load(device, BLOCK_DIM)
    hooks = module_exec.get_kernel_hooks(compute)

    # TestCase assertions (not bare ``assert``) so the checks still run under ``python -O``.
    test.assertEqual(hooks.forward_smem_bytes, expected_forward_bytes)
    test.assertEqual(hooks.backward_smem_bytes, expected_backward_bytes)
134
+
135
+
136
def test_tile_shared_mem_func(test, device):
    """Check shared-memory stack allocation for tiles created inside user ``@wp.func`` calls.

    The kernel calls a function allocating two small shared tiles, then one
    allocating two large shared tiles. The expected requirement is the larger
    function's two tiles, not the sum over both calls.

    Args:
        test: The ``unittest.TestCase`` instance this test is attached to.
        device: The Warp device to run on.
    """
    DIM_M = 64
    DIM_N = 64

    SMALL_DIM_M = DIM_M // 4
    SMALL_DIM_N = DIM_N // 4

    BLOCK_DIM = 256

    @wp.func
    def add_tile_small():
        a = wp.tile_ones(shape=(SMALL_DIM_M, SMALL_DIM_N), dtype=float, storage="shared")
        b = wp.tile_ones(shape=(SMALL_DIM_M, SMALL_DIM_N), dtype=float, storage="shared") * 2.0

        return a + b

    @wp.func
    def add_tile_big():
        a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
        b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0

        return a + b

    @wp.kernel
    def compute(out: wp.array2d(dtype=float)):
        # The small result is intentionally unused: the call exists only to
        # exercise the allocator before the big allocation.
        s = add_tile_small()
        b = add_tile_big()

        wp.tile_store(out, b)

    out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)

    wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)

    # check output: 1.0 + 2.0 everywhere (catches tiles corrupted by a bad allocation)
    assert_np_equal(out.numpy(), np.ones((DIM_M, DIM_N)) * 3.0)

    # look up the shared memory sizes recorded for the kernel on this device
    module_exec = compute.module.load(device, BLOCK_DIM)
    hooks = module_exec.get_kernel_hooks(compute)

    # required dynamic shared memory is the larger function's two float32 tiles
    expected_required_shared = DIM_M * DIM_N * 4 * 2

    # TestCase assertions (not bare ``assert``) so the checks still run under ``python -O``.
    test.assertEqual(hooks.forward_smem_bytes, expected_required_shared)
    test.assertEqual(hooks.backward_smem_bytes, expected_required_shared * 2)
180
+
181
+
182
def round_up(a, b):
    """Return ``a`` rounded up to the next multiple of ``b`` (``a`` itself if already a multiple)."""
    multiples = (a + b - 1) // b
    return multiples * b
184
+
185
+
186
def test_tile_shared_non_aligned(test, device):
    """Check shared-memory allocation for tiles whose byte size is not a multiple of 16.

    A 1x3 float32 tile is 12 bytes (12 % 16 != 0). The expected requirement
    pads each tile allocation up to a 16-byte boundary.

    Args:
        test: The ``unittest.TestCase`` instance this test is attached to.
        device: The Warp device to run on.
    """
    # Tile size = 4 (float) * 1 * 3 = 12B % 16 != 0
    DIM_M = 1
    DIM_N = 3

    BLOCK_DIM = 256

    @wp.func
    def foo():
        a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0
        b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 3.0
        return a + b

    @wp.kernel
    def compute(out: wp.array2d(dtype=float)):
        # This tests the logic in the stack allocator, which should increment and
        # decrement the stack pointer each time foo() is called.
        # Failing to do so correctly would push b out of bounds and corrupt the results.
        for _ in range(4096):
            foo()
        b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
        wp.tile_store(out, b)

    out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)

    wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)

    assert_np_equal(out.numpy(), np.ones((DIM_M, DIM_N), dtype=float))

    # look up the shared memory sizes recorded for the kernel on this device
    module_exec = compute.module.load(device, BLOCK_DIM)
    hooks = module_exec.get_kernel_hooks(compute)

    # expect three tile allocations, each padded up to a 16-byte boundary
    expected_required_shared = 3 * round_up(DIM_M * DIM_N * 4, 16)

    # TestCase assertions (not bare ``assert``) so the checks still run under ``python -O``.
    test.assertEqual(hooks.forward_smem_bytes, expected_required_shared)
    test.assertEqual(hooks.backward_smem_bytes, expected_required_shared * 2)
225
+
226
+
227
devices = get_cuda_test_devices()


class TestTileSharedMemory(unittest.TestCase):
    """Container for the tile shared-memory tests; test methods are attached below via ``add_function_test``."""


# (test name, test function, extra add_function_test keyword arguments), registered in this order.
for _test_name, _test_func, _extra_kwargs in (
    ("test_tile_shared_mem_size", test_tile_shared_mem_size, {"check_output": False}),
    ("test_tile_shared_mem_large", test_tile_shared_mem_large, {"check_output": False}),
    ("test_tile_shared_mem_graph", test_tile_shared_mem_graph, {}),
    ("test_tile_shared_mem_func", test_tile_shared_mem_func, {}),
    ("test_tile_shared_non_aligned", test_tile_shared_non_aligned, {}),
):
    add_function_test(TestTileSharedMemory, _test_name, _test_func, devices=devices, **_extra_kwargs)

# Do not leave the loop variables behind as module globals.
del _test_name, _test_func, _extra_kwargs
243
+
244
+
245
if __name__ == "__main__":
    # Clear Warp's kernel cache, then run the suite verbosely, stopping at the first failure.
    wp.clear_kernel_cache()
    unittest.main(failfast=True, verbosity=2)