warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401) hide show
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,403 +0,0 @@
1
- # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
-
8
- import unittest
9
-
10
- import numpy as np
11
-
12
- import warp as wp
13
- from warp.tests.unittest_utils import *
14
-
15
- wp.init() # For wp.context.runtime.core.is_cutlass_enabled()
16
-
17
-
18
- class gemm_test_bed_runner:
19
- def __init__(self, dtype, device):
20
- self.dtype = dtype
21
- self.device = device
22
-
23
- def alloc(self, m, n, k, batch_count):
24
- rng = np.random.default_rng(42)
25
- low = -4.5
26
- high = 3.5
27
- if batch_count == 1:
28
- A = wp.array2d(
29
- np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
30
- dtype=self.dtype,
31
- device=self.device,
32
- requires_grad=True,
33
- )
34
- B = wp.array2d(
35
- np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
36
- dtype=self.dtype,
37
- device=self.device,
38
- requires_grad=True,
39
- )
40
- C = wp.array2d(
41
- np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
42
- dtype=self.dtype,
43
- device=self.device,
44
- requires_grad=True,
45
- )
46
- D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
47
- else:
48
- A = wp.array3d(
49
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
50
- dtype=self.dtype,
51
- device=self.device,
52
- requires_grad=True,
53
- )
54
- B = wp.array3d(
55
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
56
- dtype=self.dtype,
57
- device=self.device,
58
- requires_grad=True,
59
- )
60
- C = wp.array3d(
61
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
62
- dtype=self.dtype,
63
- device=self.device,
64
- requires_grad=True,
65
- )
66
- D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
67
- return A, B, C, D
68
-
69
- def run_and_verify(self, m, n, k, batch_count, alpha, beta):
70
- A, B, C, D = self.alloc(m, n, k, batch_count)
71
- ones = wp.zeros_like(D)
72
- ones.fill_(1.0)
73
-
74
- if batch_count == 1:
75
- tape = wp.Tape()
76
- with tape:
77
- wp.matmul(A, B, C, D, alpha, beta, False)
78
- tape.backward(grads={D: ones})
79
-
80
- D_np = alpha * (A.numpy() @ B.numpy()) + beta * C.numpy()
81
- assert_np_equal(D.numpy(), D_np)
82
-
83
- adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose())
84
- adj_B_np = alpha * (A.numpy().transpose() @ ones.numpy())
85
- adj_C_np = beta * ones.numpy()
86
-
87
- else:
88
- tape = wp.Tape()
89
- with tape:
90
- wp.batched_matmul(A, B, C, D, alpha, beta, False)
91
- tape.backward(grads={D: ones})
92
-
93
- D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
94
- assert_np_equal(D.numpy(), D_np)
95
-
96
- adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
97
- adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
98
- adj_C_np = beta * ones.numpy()
99
-
100
- assert_np_equal(A.grad.numpy(), adj_A_np)
101
- assert_np_equal(B.grad.numpy(), adj_B_np)
102
- assert_np_equal(C.grad.numpy(), adj_C_np)
103
-
104
- def run(self):
105
- m = 8
106
- n = 16
107
- k = 32
108
- batch_count = 1
109
- beta = 1.0
110
- alpha = 1.0
111
-
112
- self.run_and_verify(m, n, k, batch_count, alpha, beta)
113
-
114
-
115
- class gemm_test_bed_runner_transpose:
116
- def __init__(self, dtype, device):
117
- self.dtype = dtype
118
- self.device = device
119
-
120
- def alloc(self, m, n, k, batch_count):
121
- rng = np.random.default_rng(42)
122
- low = -4.5
123
- high = 3.5
124
- if batch_count == 1:
125
- A = wp.array2d(
126
- np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
127
- dtype=self.dtype,
128
- device=self.device,
129
- requires_grad=True,
130
- )
131
- B = wp.array2d(
132
- np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
133
- dtype=self.dtype,
134
- device=self.device,
135
- requires_grad=True,
136
- )
137
- C = wp.array2d(
138
- np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
139
- dtype=self.dtype,
140
- device=self.device,
141
- requires_grad=True,
142
- )
143
- D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
144
- AT = wp.array2d(A.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
145
- BT = wp.array2d(B.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
146
- else:
147
- A = wp.array3d(
148
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
149
- dtype=self.dtype,
150
- device=self.device,
151
- requires_grad=True,
152
- )
153
- B = wp.array3d(
154
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
155
- dtype=self.dtype,
156
- device=self.device,
157
- requires_grad=True,
158
- )
159
- C = wp.array3d(
160
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
161
- dtype=self.dtype,
162
- device=self.device,
163
- requires_grad=True,
164
- )
165
- D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
166
- AT = wp.array3d(A.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
167
- BT = wp.array3d(B.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
168
- return A, B, C, D, AT, BT
169
-
170
- def run_and_verify(self, m, n, k, batch_count, alpha, beta):
171
- A, B, C1, D1, AT1, BT1 = self.alloc(m, n, k, batch_count)
172
- C2 = wp.clone(C1)
173
- C3 = wp.clone(C1)
174
- D2 = wp.clone(D1)
175
- D3 = wp.clone(D1)
176
- AT2 = wp.clone(AT1)
177
- BT2 = wp.clone(BT1)
178
- ones1 = wp.zeros_like(D1)
179
- ones1.fill_(1.0)
180
- ones2 = wp.zeros_like(D2)
181
- ones2.fill_(1.0)
182
- ones3 = wp.zeros_like(D3)
183
- ones3.fill_(1.0)
184
-
185
- if batch_count == 1:
186
- ATT1 = AT1.transpose([1, 0])
187
- BTT1 = BT1.transpose([1, 0])
188
- ATT2 = AT2.transpose([1, 0])
189
- BTT2 = BT2.transpose([1, 0])
190
- tape = wp.Tape()
191
- with tape:
192
- wp.matmul(A, BTT1, C1, D1, alpha, beta, False)
193
- wp.matmul(ATT1, B, C2, D2, alpha, beta, False)
194
- wp.matmul(ATT2, BTT2, C3, D3, alpha, beta, False)
195
- tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
196
-
197
- D_np = alpha * (A.numpy() @ B.numpy()) + beta * C1.numpy()
198
- assert_np_equal(D1.numpy(), D_np)
199
- assert_np_equal(D2.numpy(), D_np)
200
- assert_np_equal(D3.numpy(), D_np)
201
-
202
- adj_A_np = alpha * (ones1.numpy() @ B.numpy().transpose())
203
- adj_B_np = alpha * (A.numpy().transpose() @ ones1.numpy())
204
- adj_C_np = beta * ones1.numpy()
205
-
206
- else:
207
- ATT1 = AT1.transpose([0, 2, 1])
208
- BTT1 = BT1.transpose([0, 2, 1])
209
- ATT2 = AT2.transpose([0, 2, 1])
210
- BTT2 = BT2.transpose([0, 2, 1])
211
- tape = wp.Tape()
212
- with tape:
213
- wp.batched_matmul(A, BTT1, C1, D1, alpha, beta, False)
214
- wp.batched_matmul(ATT1, B, C2, D2, alpha, beta, False)
215
- wp.batched_matmul(ATT2, BTT2, C3, D3, alpha, beta, False)
216
- tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
217
-
218
- D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C1.numpy()
219
- assert_np_equal(D1.numpy(), D_np)
220
- assert_np_equal(D2.numpy(), D_np)
221
- assert_np_equal(D3.numpy(), D_np)
222
-
223
- adj_A_np = alpha * np.matmul(ones1.numpy(), B.numpy().transpose((0, 2, 1)))
224
- adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones1.numpy())
225
- adj_C_np = beta * ones1.numpy()
226
-
227
- assert_np_equal(A.grad.numpy(), adj_A_np)
228
- assert_np_equal(ATT1.grad.numpy(), adj_A_np)
229
- assert_np_equal(ATT2.grad.numpy(), adj_A_np)
230
- assert_np_equal(B.grad.numpy(), adj_B_np)
231
- assert_np_equal(BTT1.grad.numpy(), adj_B_np)
232
- assert_np_equal(BTT2.grad.numpy(), adj_B_np)
233
- assert_np_equal(C1.grad.numpy(), adj_C_np)
234
- assert_np_equal(C2.grad.numpy(), adj_C_np)
235
- assert_np_equal(C3.grad.numpy(), adj_C_np)
236
-
237
- def run(self):
238
- m = 8
239
- n = 16
240
- k = 32
241
- batch_counts = [1, 4]
242
- beta = 1.0
243
- alpha = 1.0
244
-
245
- for batch_count in batch_counts:
246
- self.run_and_verify(m, n, k, batch_count, alpha, beta)
247
-
248
-
249
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
250
- def test_f32(test, device):
251
- gemm_test_bed_runner(wp.float32, device).run()
252
- gemm_test_bed_runner_transpose(wp.float32, device).run()
253
-
254
-
255
- @wp.kernel
256
- def matrix_sum_kernel(arr: wp.array2d(dtype=float), loss: wp.array(dtype=float)):
257
- i, j = wp.tid()
258
- wp.atomic_add(loss, 0, arr[i, j])
259
-
260
-
261
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
262
- def test_tape(test, device):
263
- rng = np.random.default_rng(42)
264
- low = -4.5
265
- high = 3.5
266
- m = 8
267
- n = 16
268
- k = 32
269
- A = wp.array2d(
270
- np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
271
- )
272
- B = wp.array2d(
273
- np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
274
- )
275
- C = wp.array2d(
276
- np.ceil(rng.uniform(low=low, high=high, size=(m, n))), dtype=float, device=device, requires_grad=True
277
- )
278
- D = wp.array2d(np.zeros((m, n)), dtype=float, device=device, requires_grad=True)
279
- loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
280
-
281
- # test tape
282
- tape = wp.Tape()
283
- with tape:
284
- wp.matmul(A, B, C, D)
285
- wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
286
-
287
- tape.backward(loss=loss)
288
- A_grad = A.grad.numpy()
289
- tape.reset()
290
-
291
- # test adjoint
292
- D.grad = wp.ones((m, n), dtype=float, device=device)
293
- wp.adj_matmul(A, B, C, A.grad, B.grad, C.grad, D.grad)
294
- assert_np_equal(A_grad, A.grad.numpy())
295
-
296
- # test zero
297
- tape.zero()
298
- assert_array_equal(A.grad, wp.zeros_like(A))
299
-
300
-
301
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
302
- def test_operator(test, device):
303
- rng = np.random.default_rng(42)
304
- low = -4.5
305
- high = 3.5
306
- m = 8
307
- n = 16
308
- k = 32
309
- A = wp.array2d(
310
- np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
311
- )
312
- B = wp.array2d(
313
- np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
314
- )
315
- loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
316
-
317
- # test tape
318
- tape = wp.Tape()
319
- with tape:
320
- D = A @ B
321
- wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
322
-
323
- tape.backward(loss=loss)
324
-
325
- # test adjoint
326
- D.grad = wp.ones((m, n), dtype=float, device=device)
327
- B_transpose = wp.array2d(B.transpose().numpy(), dtype=float, device=device)
328
-
329
- adj_A = D.grad @ B_transpose
330
- assert_array_equal(adj_A, A.grad)
331
-
332
- # test zero
333
- tape.zero()
334
- assert_array_equal(A.grad, wp.zeros_like(A))
335
-
336
-
337
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
338
- def test_large_batch_count(test, device):
339
- rng = np.random.default_rng(42)
340
- low = -4.5
341
- high = 3.5
342
- m = 2
343
- n = 3
344
- k = 4
345
- batch_count = 65535 * 2 + int(65535 / 2)
346
- A = wp.array3d(
347
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
348
- dtype=float,
349
- device=device,
350
- requires_grad=True,
351
- )
352
- B = wp.array3d(
353
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
354
- dtype=float,
355
- device=device,
356
- requires_grad=True,
357
- )
358
- C = wp.array3d(
359
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
360
- dtype=float,
361
- device=device,
362
- requires_grad=True,
363
- )
364
- D = wp.array3d(np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True)
365
- ones = wp.zeros_like(D)
366
- ones.fill_(1.0)
367
-
368
- alpha = 1.0
369
- beta = 1.0
370
-
371
- tape = wp.Tape()
372
- with tape:
373
- wp.batched_matmul(A, B, C, D, alpha=alpha, beta=beta, allow_tf32x3_arith=False)
374
- tape.backward(grads={D: ones})
375
-
376
- D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
377
- assert_np_equal(D.numpy(), D_np)
378
-
379
- adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
380
- adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
381
- adj_C_np = beta * ones.numpy()
382
-
383
- assert_np_equal(A.grad.numpy(), adj_A_np)
384
- assert_np_equal(B.grad.numpy(), adj_B_np)
385
- assert_np_equal(C.grad.numpy(), adj_C_np)
386
-
387
-
388
- devices = get_test_devices()
389
-
390
-
391
- class TestMatmulLite(unittest.TestCase):
392
- pass
393
-
394
-
395
- add_function_test(TestMatmulLite, "test_f32", test_f32, devices=devices, check_output=False)
396
- add_function_test(TestMatmulLite, "test_tape", test_tape, devices=devices, check_output=False)
397
- add_function_test(TestMatmulLite, "test_operator", test_operator, devices=devices, check_output=False)
398
- add_function_test(TestMatmulLite, "test_large_batch_count", test_large_batch_count, devices=devices, check_output=False)
399
-
400
-
401
- if __name__ == "__main__":
402
- wp.clear_kernel_cache()
403
- unittest.main(verbosity=2, failfast=False)