warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401) hide show
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/tests/test_matmul.py DELETED
@@ -1,503 +0,0 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
-
8
- import itertools
9
- import unittest
10
- from typing import Any
11
-
12
- import numpy as np
13
-
14
- import warp as wp
15
- from warp.tests.unittest_utils import *
16
-
17
- wp.init() # For wp.context.runtime.core.is_cutlass_enabled()
18
-
19
-
20
- class gemm_test_bed_runner:
21
- def __init__(self, dtype, device):
22
- self.dtype = dtype
23
- self.device = device
24
-
25
- def alloc(self, m, n, k, batch_count):
26
- rng = np.random.default_rng(42)
27
- low = -4.5
28
- high = 3.5
29
- if batch_count == 1:
30
- A = wp.array2d(
31
- np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
32
- dtype=self.dtype,
33
- device=self.device,
34
- requires_grad=True,
35
- )
36
- B = wp.array2d(
37
- np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
38
- dtype=self.dtype,
39
- device=self.device,
40
- requires_grad=True,
41
- )
42
- C = wp.array2d(
43
- np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
44
- dtype=self.dtype,
45
- device=self.device,
46
- requires_grad=True,
47
- )
48
- D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
49
- else:
50
- A = wp.array3d(
51
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
52
- dtype=self.dtype,
53
- device=self.device,
54
- requires_grad=True,
55
- )
56
- B = wp.array3d(
57
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
58
- dtype=self.dtype,
59
- device=self.device,
60
- requires_grad=True,
61
- )
62
- C = wp.array3d(
63
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
64
- dtype=self.dtype,
65
- device=self.device,
66
- requires_grad=True,
67
- )
68
- D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
69
- return A, B, C, D
70
-
71
- def run_and_verify(self, m, n, k, batch_count, alpha, beta):
72
- A, B, C, D = self.alloc(m, n, k, batch_count)
73
- ones = wp.zeros_like(D)
74
- ones.fill_(1.0)
75
-
76
- np_dtype = wp.types.warp_type_to_np_dtype[self.dtype]
77
-
78
- if batch_count == 1:
79
- tape = wp.Tape()
80
- with tape:
81
- wp.matmul(A, B, C, D, alpha, beta, False)
82
- tape.backward(grads={D: ones})
83
-
84
- D_np = alpha * np.matmul(A.numpy(), B.numpy(), dtype=np_dtype) + beta * C.numpy()
85
- assert_np_equal(D.numpy(), D_np)
86
-
87
- adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose(), dtype=np_dtype)
88
- adj_B_np = alpha * np.matmul(A.numpy().transpose(), ones.numpy(), dtype=np_dtype)
89
- adj_C_np = beta * ones.numpy()
90
-
91
- else:
92
- tape = wp.Tape()
93
- with tape:
94
- wp.batched_matmul(A, B, C, D, alpha, beta, False)
95
- tape.backward(grads={D: ones})
96
-
97
- D_np = alpha * np.matmul(A.numpy(), B.numpy(), dtype=np_dtype) + beta * C.numpy()
98
- assert_np_equal(D.numpy(), D_np)
99
-
100
- adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)), dtype=np_dtype)
101
- adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy(), dtype=np_dtype)
102
- adj_C_np = beta * ones.numpy()
103
-
104
- assert_np_equal(A.grad.numpy(), adj_A_np)
105
- assert_np_equal(B.grad.numpy(), adj_B_np)
106
- assert_np_equal(C.grad.numpy(), adj_C_np)
107
-
108
- def run(self):
109
- Ms = [16, 32, 64]
110
- Ns = [16, 32, 64]
111
- Ks = [16, 32, 64]
112
- batch_counts = [1, 4]
113
- betas = [0.0, 1.0]
114
- alpha = 1.0
115
-
116
- for batch_count, m, n, k, beta in itertools.product(batch_counts, Ms, Ns, Ks, betas):
117
- self.run_and_verify(m, n, k, batch_count, alpha, beta)
118
-
119
-
120
- class gemm_test_bed_runner_transpose:
121
- def __init__(self, dtype, device):
122
- self.dtype = dtype
123
- self.device = device
124
-
125
- def alloc(self, m, n, k, batch_count):
126
- rng = np.random.default_rng(42)
127
- low = -4.5
128
- high = 3.5
129
- if batch_count == 1:
130
- A = wp.array2d(
131
- np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
132
- dtype=self.dtype,
133
- device=self.device,
134
- requires_grad=True,
135
- )
136
- B = wp.array2d(
137
- np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
138
- dtype=self.dtype,
139
- device=self.device,
140
- requires_grad=True,
141
- )
142
- C = wp.array2d(
143
- np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
144
- dtype=self.dtype,
145
- device=self.device,
146
- requires_grad=True,
147
- )
148
- D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
149
- AT = wp.array2d(A.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
150
- BT = wp.array2d(B.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
151
- else:
152
- A = wp.array3d(
153
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
154
- dtype=self.dtype,
155
- device=self.device,
156
- requires_grad=True,
157
- )
158
- B = wp.array3d(
159
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
160
- dtype=self.dtype,
161
- device=self.device,
162
- requires_grad=True,
163
- )
164
- C = wp.array3d(
165
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
166
- dtype=self.dtype,
167
- device=self.device,
168
- requires_grad=True,
169
- )
170
- D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
171
- AT = wp.array3d(A.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
172
- BT = wp.array3d(B.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
173
- return A, B, C, D, AT, BT
174
-
175
- def run_and_verify(self, m, n, k, batch_count, alpha, beta):
176
- A, B, C1, D1, AT1, BT1 = self.alloc(m, n, k, batch_count)
177
- C2 = wp.clone(C1)
178
- C3 = wp.clone(C1)
179
- D2 = wp.clone(D1)
180
- D3 = wp.clone(D1)
181
- AT2 = wp.clone(AT1)
182
- BT2 = wp.clone(BT1)
183
- ones1 = wp.zeros_like(D1)
184
- ones1.fill_(1.0)
185
- ones2 = wp.zeros_like(D2)
186
- ones2.fill_(1.0)
187
- ones3 = wp.zeros_like(D3)
188
- ones3.fill_(1.0)
189
-
190
- np_dtype = wp.types.warp_type_to_np_dtype[self.dtype]
191
-
192
- if batch_count == 1:
193
- ATT1 = AT1.transpose([1, 0])
194
- BTT1 = BT1.transpose([1, 0])
195
- ATT2 = AT2.transpose([1, 0])
196
- BTT2 = BT2.transpose([1, 0])
197
- tape = wp.Tape()
198
- with tape:
199
- wp.matmul(A, BTT1, C1, D1, alpha, beta, False)
200
- wp.matmul(ATT1, B, C2, D2, alpha, beta, False)
201
- wp.matmul(ATT2, BTT2, C3, D3, alpha, beta, False)
202
- tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
203
-
204
- D_np = alpha * np.matmul(A.numpy(), B.numpy(), dtype=np_dtype) + beta * C1.numpy()
205
- assert_np_equal(D1.numpy(), D_np)
206
- assert_np_equal(D2.numpy(), D_np)
207
- assert_np_equal(D3.numpy(), D_np)
208
-
209
- adj_A_np = alpha * np.matmul(ones1.numpy(), B.numpy().transpose(), dtype=np_dtype)
210
- adj_B_np = alpha * np.matmul(A.numpy().transpose(), ones1.numpy(), dtype=np_dtype)
211
- adj_C_np = beta * ones1.numpy()
212
-
213
- else:
214
- ATT1 = AT1.transpose([0, 2, 1])
215
- BTT1 = BT1.transpose([0, 2, 1])
216
- ATT2 = AT2.transpose([0, 2, 1])
217
- BTT2 = BT2.transpose([0, 2, 1])
218
- tape = wp.Tape()
219
- with tape:
220
- wp.batched_matmul(A, BTT1, C1, D1, alpha, beta, False)
221
- wp.batched_matmul(ATT1, B, C2, D2, alpha, beta, False)
222
- wp.batched_matmul(ATT2, BTT2, C3, D3, alpha, beta, False)
223
- tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
224
-
225
- D_np = alpha * np.matmul(A.numpy(), B.numpy(), dtype=np_dtype) + beta * C1.numpy()
226
- assert_np_equal(D1.numpy(), D_np)
227
- assert_np_equal(D2.numpy(), D_np)
228
- assert_np_equal(D3.numpy(), D_np)
229
-
230
- adj_A_np = alpha * np.matmul(ones1.numpy(), B.numpy().transpose((0, 2, 1)), dtype=np_dtype)
231
- adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones1.numpy(), dtype=np_dtype)
232
- adj_C_np = beta * ones1.numpy()
233
-
234
- assert_np_equal(A.grad.numpy(), adj_A_np)
235
- assert_np_equal(ATT1.grad.numpy(), adj_A_np)
236
- assert_np_equal(ATT2.grad.numpy(), adj_A_np)
237
- assert_np_equal(B.grad.numpy(), adj_B_np)
238
- assert_np_equal(BTT1.grad.numpy(), adj_B_np)
239
- assert_np_equal(BTT2.grad.numpy(), adj_B_np)
240
- assert_np_equal(C1.grad.numpy(), adj_C_np)
241
- assert_np_equal(C2.grad.numpy(), adj_C_np)
242
- assert_np_equal(C3.grad.numpy(), adj_C_np)
243
-
244
- def run(self):
245
- m = 16
246
- n = 32
247
- k = 64
248
- batch_counts = [1, 4]
249
- beta = 1.0
250
- alpha = 1.0
251
-
252
- for batch_count in batch_counts:
253
- self.run_and_verify(m, n, k, batch_count, alpha, beta)
254
-
255
-
256
- # NOTE: F16 tests are slow due to the performance of the reference numpy F16 matmuls performed on CPU.
257
- def test_f16(test, device):
258
- gemm_test_bed_runner(wp.float16, device).run()
259
- gemm_test_bed_runner_transpose(wp.float16, device).run()
260
-
261
-
262
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
263
- def test_f32(test, device):
264
- gemm_test_bed_runner(wp.float32, device).run()
265
- gemm_test_bed_runner_transpose(wp.float32, device).run()
266
-
267
-
268
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
269
- def test_f64(test, device):
270
- gemm_test_bed_runner(wp.float64, device).run()
271
- gemm_test_bed_runner_transpose(wp.float64, device).run()
272
-
273
-
274
- @wp.kernel
275
- def matrix_sum_kernel(arr: wp.array2d(dtype=float), loss: wp.array(dtype=float)):
276
- i, j = wp.tid()
277
- wp.atomic_add(loss, 0, arr[i, j])
278
-
279
-
280
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
281
- def test_tape(test, device):
282
- rng = np.random.default_rng(42)
283
- low = -4.5
284
- high = 3.5
285
- m = 64
286
- n = 128
287
- k = 256
288
- A = wp.array2d(
289
- np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
290
- )
291
- B = wp.array2d(
292
- np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
293
- )
294
- C = wp.array2d(
295
- np.ceil(rng.uniform(low=low, high=high, size=(m, n))), dtype=float, device=device, requires_grad=True
296
- )
297
- D = wp.array2d(np.zeros((m, n)), dtype=float, device=device, requires_grad=True)
298
- loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
299
-
300
- # test tape
301
- tape = wp.Tape()
302
- with tape:
303
- wp.matmul(A, B, C, D)
304
- wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
305
-
306
- tape.backward(loss=loss)
307
- A_grad = A.grad.numpy()
308
- tape.reset()
309
-
310
- # test adjoint
311
- D.grad = wp.ones((m, n), dtype=float, device=device)
312
- wp.adj_matmul(A, B, C, A.grad, B.grad, C.grad, D.grad)
313
- assert_np_equal(A_grad, A.grad.numpy())
314
-
315
- # test zero
316
- tape.zero()
317
- assert_array_equal(A.grad, wp.zeros_like(A))
318
-
319
-
320
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
321
- def test_operator(test, device):
322
- rng = np.random.default_rng(42)
323
- low = -4.5
324
- high = 3.5
325
- m = 64
326
- n = 128
327
- k = 256
328
- A = wp.array2d(
329
- np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
330
- )
331
- B = wp.array2d(
332
- np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
333
- )
334
- loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
335
-
336
- # test tape
337
- tape = wp.Tape()
338
- with tape:
339
- D = A @ B
340
- wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
341
-
342
- tape.backward(loss=loss)
343
-
344
- # test adjoint
345
- D.grad = wp.ones((m, n), dtype=float, device=device)
346
- B_transpose = wp.array2d(B.transpose().numpy(), dtype=float, device=device)
347
-
348
- adj_A = D.grad @ B_transpose
349
- assert_array_equal(adj_A, A.grad)
350
-
351
- # test zero
352
- tape.zero()
353
- assert_array_equal(A.grad, wp.zeros_like(A))
354
-
355
-
356
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
357
- def test_large_batch_count(test, device):
358
- rng = np.random.default_rng(42)
359
- low = -4.5
360
- high = 3.5
361
- m = 2
362
- n = 3
363
- k = 4
364
- batch_count = 65535 * 2 + int(65535 / 2)
365
- A = wp.array3d(
366
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
367
- dtype=float,
368
- device=device,
369
- requires_grad=True,
370
- )
371
- B = wp.array3d(
372
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
373
- dtype=float,
374
- device=device,
375
- requires_grad=True,
376
- )
377
- C = wp.array3d(
378
- np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
379
- dtype=float,
380
- device=device,
381
- requires_grad=True,
382
- )
383
- D = wp.array3d(np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True)
384
- ones = wp.zeros_like(D)
385
- ones.fill_(1.0)
386
-
387
- alpha = 1.0
388
- beta = 1.0
389
-
390
- tape = wp.Tape()
391
- with tape:
392
- wp.batched_matmul(A, B, C, D, alpha=alpha, beta=beta, allow_tf32x3_arith=False)
393
- tape.backward(grads={D: ones})
394
-
395
- D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
396
- assert_np_equal(D.numpy(), D_np)
397
-
398
- adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
399
- adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
400
- adj_C_np = beta * ones.numpy()
401
-
402
- assert_np_equal(A.grad.numpy(), adj_A_np)
403
- assert_np_equal(B.grad.numpy(), adj_B_np)
404
- assert_np_equal(C.grad.numpy(), adj_C_np)
405
-
406
-
407
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
408
- def test_adjoint_accumulation(test, device):
409
- a_np = np.ones(shape=(2, 3))
410
- b_np = np.ones(shape=(3, 2))
411
- c_np = np.zeros(shape=(2, 2))
412
- d_np = np.zeros(shape=(2, 2))
413
-
414
- a_wp = wp.from_numpy(a_np, dtype=float, requires_grad=True, device=device)
415
- b_wp = wp.from_numpy(b_np, dtype=float, requires_grad=True, device=device)
416
- c_wp = wp.from_numpy(c_np, dtype=float, requires_grad=True, device=device)
417
- d1_wp = wp.from_numpy(d_np, dtype=float, requires_grad=True, device=device)
418
- d2_wp = wp.from_numpy(d_np, dtype=float, requires_grad=True, device=device)
419
-
420
- tape = wp.Tape()
421
-
422
- with tape:
423
- wp.matmul(a_wp, b_wp, c_wp, d1_wp, alpha=1.0, beta=1.0)
424
- wp.matmul(a_wp, b_wp, d1_wp, d2_wp, alpha=1.0, beta=1.0)
425
-
426
- d_grad = wp.zeros_like(d2_wp, device=device)
427
- d_grad.fill_(1.0)
428
- grads = {d2_wp: d_grad}
429
- tape.backward(grads=grads)
430
-
431
- assert_np_equal(a_wp.grad.numpy(), 4.0 * np.ones(shape=(2, 3)))
432
- assert_np_equal(b_wp.grad.numpy(), 4.0 * np.ones(shape=(3, 2)))
433
- assert_np_equal(c_wp.grad.numpy(), np.ones(shape=(2, 2)))
434
-
435
-
436
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
437
- def test_cuda_graph_capture(test, device):
438
- @wp.kernel
439
- def mat_sum(mat: wp.array2d(dtype=Any), loss: wp.array(dtype=Any)):
440
- i, j = wp.tid()
441
- e = mat[i, j]
442
- wp.atomic_add(loss, 0, e)
443
-
444
- for T in [wp.float16, wp.float32, wp.float64]:
445
- wp.overload(mat_sum, [wp.array2d(dtype=T), wp.array(dtype=T)])
446
-
447
- wp.load_module(device=device)
448
- wp.load_module(module="warp.utils", device=device)
449
-
450
- for dtype in [wp.float16, wp.float32, wp.float64]:
451
- m = 8
452
- n = 8
453
- k = 8
454
-
455
- A = wp.ones((m, n), dtype=dtype, device=device, requires_grad=True)
456
- B = wp.ones((n, k), dtype=dtype, device=device, requires_grad=True)
457
- C = wp.zeros((m, k), dtype=dtype, device=device, requires_grad=True)
458
- D = wp.zeros((m, k), dtype=dtype, device=device, requires_grad=True)
459
-
460
- loss = wp.zeros(1, dtype=dtype, device=device, requires_grad=True)
461
-
462
- wp.capture_begin(device, force_module_load=False)
463
- try:
464
- tape = wp.Tape()
465
-
466
- with tape:
467
- wp.matmul(A, B, C, D)
468
- wp.launch(mat_sum, dim=(m, k), inputs=[D, loss], device=device)
469
-
470
- tape.backward(loss=loss)
471
- finally:
472
- graph = wp.capture_end(device)
473
-
474
- wp.capture_launch(graph)
475
-
476
- assert_np_equal(A.grad.numpy(), 8.0 * np.ones((m, n), dtype=wp.types.warp_type_to_np_dtype[dtype]))
477
-
478
-
479
- devices = get_test_devices()
480
- cuda_devices = get_selected_cuda_test_devices()
481
-
482
-
483
- class TestMatmul(unittest.TestCase):
484
- pass
485
-
486
-
487
- # add_function_test(TestMatmul, "test_f16", test_f16, devices=devices)
488
- add_function_test(TestMatmul, "test_f32", test_f32, devices=devices, check_output=False)
489
- add_function_test(TestMatmul, "test_f64", test_f64, devices=devices, check_output=False)
490
- add_function_test(TestMatmul, "test_tape", test_tape, devices=devices, check_output=False)
491
- add_function_test(TestMatmul, "test_operator", test_operator, devices=devices, check_output=False)
492
- add_function_test(TestMatmul, "test_large_batch_count", test_large_batch_count, devices=devices, check_output=False)
493
- add_function_test(
494
- TestMatmul, "test_adjoint_accumulation", test_adjoint_accumulation, devices=devices, check_output=False
495
- )
496
- add_function_test(
497
- TestMatmul, "test_cuda_graph_capture", test_cuda_graph_capture, devices=cuda_devices, check_output=False
498
- )
499
-
500
-
501
- if __name__ == "__main__":
502
- wp.clear_kernel_cache()
503
- unittest.main(verbosity=2, failfast=False)