warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/types.py CHANGED
@@ -1,9 +1,17 @@
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.

  from __future__ import annotations

@@ -12,7 +20,21 @@ import ctypes
  import inspect
  import struct
  import zlib
- from typing import Any, Callable, Generic, List, Literal, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union
+ from typing import (
+ Any,
+ Callable,
+ Generic,
+ List,
+ Literal,
+ NamedTuple,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+ get_args,
+ get_origin,
+ )

  import numpy as np
  import numpy.typing as npt
@@ -48,7 +70,9 @@ class Transformation(Generic[Float]):


  class Array(Generic[DType]):
- pass
+ device: Optional[warp.context.Device]
+ dtype: type
+ size: int


  int_tuple_type_hints = {
@@ -1131,7 +1155,7 @@ ARRAY_TYPE_FABRIC_INDEXED = 3
  class launch_bounds_t(ctypes.Structure):
  _fields_ = [("shape", ctypes.c_int32 * LAUNCH_MAX_DIMS), ("ndim", ctypes.c_int32), ("size", ctypes.c_size_t)]

- def __init__(self, shape):
+ def __init__(self, shape: Union[int, Sequence[int]]):
  if isinstance(shape, int):
  # 1d launch
  self.ndim = 1
@@ -1252,7 +1276,7 @@ _type_size_cache = {
  }


- def type_size_in_bytes(dtype):
+ def type_size_in_bytes(dtype: type) -> int:
  size = _type_size_cache.get(dtype)

  if size is None:
@@ -1271,7 +1295,7 @@ def type_size_in_bytes(dtype):
  return size


- def type_to_warp(dtype):
+ def type_to_warp(dtype: type) -> type:
  if dtype == float:
  return float32
  elif dtype == int:
@@ -1282,7 +1306,7 @@ def type_to_warp(dtype):
  return dtype


- def type_typestr(dtype):
+ def type_typestr(dtype: type) -> str:
  if dtype == bool:
  return "|b1"
  elif dtype == float16:
@@ -1368,29 +1392,29 @@ def type_is_transformation(t):
  return getattr(t, "_wp_generic_type_hint_", None) is Transformation


- value_types = (int, float, builtins.bool) + scalar_types
+ value_types = (int, float, builtins.bool) + scalar_and_bool_types


  # returns true for all value types (int, float, bool, scalars, vectors, matrices)
- def type_is_value(x):
+ def type_is_value(x: Any) -> builtins.bool:
  return x in value_types or hasattr(x, "_wp_scalar_type_")


  # equivalent of the above but for values
- def is_int(x):
+ def is_int(x: Any) -> builtins.bool:
  return type_is_int(type(x))


- def is_float(x):
+ def is_float(x: Any) -> builtins.bool:
  return type_is_float(type(x))


- def is_value(x):
+ def is_value(x: Any) -> builtins.bool:
  return type_is_value(type(x))


- # returns true if the passed *instance* is one of the array types
- def is_array(a):
+ def is_array(a) -> builtins.bool:
+ """Return true if the passed *instance* is one of the array types."""
  return isinstance(a, array_types)


@@ -1457,21 +1481,21 @@ def types_equal(a, b, match_generic=False):
  if a_length is None or b_length is None or a_length == b_length:
  return True

- a_origin = warp.codegen.get_type_origin(a)
- b_origin = warp.codegen.get_type_origin(b)
+ a_origin = get_origin(a)
+ b_origin = get_origin(b)
  if a_origin is tuple and b_origin is tuple:
- a_args = warp.codegen.get_type_args(a)
- b_args = warp.codegen.get_type_args(b)
+ a_args = get_args(a)
+ b_args = get_args(b)
  if len(a_args) == len(b_args) and all(
  scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b_args)
  ):
  return True
  elif a_origin is tuple and isinstance(b, Sequence):
- a_args = warp.codegen.get_type_args(a)
+ a_args = get_args(a)
  if len(a_args) == len(b) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b)):
  return True
  elif b_origin is tuple and isinstance(a, Sequence):
- b_args = warp.codegen.get_type_args(b)
+ b_args = get_args(b)
  if len(b_args) == len(a) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(b_args, a)):
  return True

@@ -1592,7 +1616,7 @@ def array_ctype_from_interface(interface: dict, dtype=None, owner=None):
  return array_ctype


- class array(Array):
+ class array(Array[DType]):
  """A fixed-size multi-dimensional array containing values of the same type.

  Attributes:
@@ -1621,21 +1645,21 @@ class array(Array):

  def __init__(
  self,
- data: Optional[Union[List, Tuple, npt.NDArray]] = None,
- dtype: Union[DType, Any] = Any,
- shape: Optional[Tuple[int, ...]] = None,
+ data: Union[List, Tuple, npt.NDArray, None] = None,
+ dtype: Any = Any,
+ shape: Union[int, Tuple[int, ...], List[int], None] = None,
  strides: Optional[Tuple[int, ...]] = None,
  length: Optional[int] = None,
  ptr: Optional[int] = None,
  capacity: Optional[int] = None,
  device=None,
- pinned: bool = False,
- copy: bool = True,
- owner: bool = False, # deprecated - pass deleter instead
+ pinned: builtins.bool = False,
+ copy: builtins.bool = True,
+ owner: builtins.bool = False, # deprecated - pass deleter instead
  deleter: Optional[Callable[[int, int], None]] = None,
  ndim: Optional[int] = None,
  grad: Optional[array] = None,
- requires_grad: bool = False,
+ requires_grad: builtins.bool = False,
  ):
  """Constructs a new Warp array object

@@ -2931,7 +2955,7 @@ def from_ipc_handle(

  # A base class for non-contiguous arrays, providing the implementation of common methods like
  # contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_().
- class noncontiguous_array_base(Generic[T]):
+ class noncontiguous_array_base(Array[T]):
  def __init__(self, array_type_id):
  self.type_id = array_type_id
  self.is_contiguous = False
@@ -3028,12 +3052,18 @@ def check_index_array(indices, expected_device):
  raise ValueError(f"Index array device ({indices.device} does not match data array device ({expected_device}))")


- class indexedarray(noncontiguous_array_base[T]):
+ class indexedarray(noncontiguous_array_base):
  # member attributes available during code-gen (e.g.: d = arr.shape[0])
  # (initialized when needed)
  _vars = None

- def __init__(self, data: array = None, indices: Union[array, List[array]] = None, dtype=None, ndim=None):
+ def __init__(
+ self,
+ data: Optional[array] = None,
+ indices: Union[array, List[array], None] = None,
+ dtype=None,
+ ndim: Optional[int] = None,
+ ):
  super().__init__(ARRAY_TYPE_INDEXED)

  # canonicalize types
@@ -3224,7 +3254,7 @@ class Tile:
  return f"wp::tile_alloc_empty<{Var.type_to_ctype(self.dtype)},wp::tile_shape_t<{','.join(map(str, self.shape))}>,{'true' if requires_grad else 'false'}>()"
  else:
  # tile will be initialized by another call, e.g.: tile_transpose()
- return "NULL"
+ return "nullptr"

  # return total tile size in bytes
  def size_in_bytes(self):
@@ -3626,7 +3656,7 @@ class Volume:
  instance.id = None
  return instance

- def __init__(self, data: array, copy: bool = True):
+ def __init__(self, data: array, copy: builtins.bool = True):
  """Class representing a sparse grid.

  Args:
@@ -4353,6 +4383,15 @@ class Volume:
  translation_buf = (ctypes.c_float * 3)(translation[0], translation[1], translation[2])
  return transform_buf, translation_buf

+ # nanovdb types for which we instantiate the grid builder
+ # Should be in sync with WP_VOLUME_BUILDER_INSTANTIATE_TYPES in volume_builder.h
+ _supported_allocation_types = [
+ "int32",
+ "float",
+ "Vec3f",
+ "Vec4f",
+ ]
+
  @classmethod
  def allocate_by_tiles(
  cls,
@@ -4380,7 +4419,8 @@
  or a floating point scalar type (2D N-by-3 array of :class:`warp.float32` or 1D array of `warp.vec3f` values), indicating world space positions.
  Repeated points per tile are allowed and will be efficiently deduplicated.
  voxel_size (float or array-like): Voxel size(s) of the new volume. Ignored if `transform` is given.
- bg_value (array-like, float, int or None): Value of unallocated voxels of the volume, also defines the volume's type. A :class:`warp.vec3` volume is created if this is `array-like`, an index volume will be created if `bg_value` is ``None``.
+ bg_value (array-like, scalar or None): Value of unallocated voxels of the volume, also defines the volume's type. An index volume will be created if `bg_value` is ``None``.
+ Other supported grid types are `int`, `float`, `vec3f`, and `vec4f`.
  translation (array-like): Translation between the index and world spaces.
  transform (array-like): Linear transform between the index and world spaces. If ``None``, deduced from `voxel_size`.
  device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
@@ -4412,35 +4452,47 @@
  translation_buf,
  in_world_space,
  )
- elif hasattr(bg_value, "__len__"):
- volume.id = volume.runtime.core.volume_v_from_tiles_device(
- volume.device.context,
- ctypes.c_void_p(tile_points.ptr),
- tile_points.shape[0],
- transform_buf,
- translation_buf,
- in_world_space,
- (ctypes.c_float * 3)(bg_value[0], bg_value[1], bg_value[2]),
- )
- elif isinstance(bg_value, int):
- volume.id = volume.runtime.core.volume_i_from_tiles_device(
- volume.device.context,
- ctypes.c_void_p(tile_points.ptr),
- tile_points.shape[0],
- transform_buf,
- translation_buf,
- in_world_space,
- bg_value,
- )
  else:
- volume.id = volume.runtime.core.volume_f_from_tiles_device(
+ # normalize background value type
+ grid_type = type_to_warp(type(bg_value))
+ if not (is_value(bg_value) or type_is_vector(grid_type)) and (
+ hasattr(bg_value, "__len__") and is_value(bg_value[0])
+ ):
+ # non-warp vectors are considered float, for backward compatibility
+ grid_type = vector(len(bg_value), dtype=float)
+
+ # look for corresponding nvdb type
+ try:
+ nvdb_type = next(
+ typ
+ for typ in Volume._supported_allocation_types
+ if types_equal(grid_type, Volume._nvdb_type_to_dtype[typ])
+ )
+ except StopIteration as err:
+ raise TypeError(
+ f"Unsupported bg_value type for volume allocation {type_repr(grid_type)}. Supported volume types are {', '.join(Volume._supported_allocation_types)}."
+ ) from err
+
+ # cast to ctype
+ # wrap scalar values in length-1 vectors to handle specific ctype conversion
+ if not type_is_vector(grid_type):
+ grid_type = vector(length=1, dtype=grid_type)
+
+ cvalue = grid_type(bg_value)
+ cvalue_ptr = ctypes.pointer(cvalue)
+ cvalue_size = ctypes.sizeof(cvalue)
+ cvalue_type = nvdb_type.encode("ascii")
+
+ volume.id = volume.runtime.core.volume_from_tiles_device(
  volume.device.context,
  ctypes.c_void_p(tile_points.ptr),
  tile_points.shape[0],
  transform_buf,
  translation_buf,
  in_world_space,
- float(bg_value),
+ cvalue_ptr,
+ cvalue_size,
+ cvalue_type,
  )

  if volume.id == 0:
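The refactor above folds the old per-type entry points (volume_f/_i/_v_from_tiles_device) into a single volume_from_tiles_device call and dispatches on the Warp type of bg_value. A rough usage sketch, assuming a CUDA device and integer tile coordinates for tile_points (keyword names follow the docstring above; the exact accepted layouts for tile_points are listed there):

    import warp as wp

    wp.init()

    # hypothetical tile coordinates in index space
    tile_points = wp.array([[0, 0, 0], [1, 0, 0]], dtype=wp.int32, device="cuda")

    # scalar background value -> float grid
    vol_f = wp.Volume.allocate_by_tiles(tile_points, voxel_size=0.1, bg_value=0.0, device="cuda")

    # vector background value -> vec3f grid (plain Python sequences are treated as float vectors)
    vol_v = wp.Volume.allocate_by_tiles(tile_points, voxel_size=0.1, bg_value=wp.vec3(0.0, 0.0, 0.0), device="cuda")

    # bg_value=None -> index grid; per the new _supported_allocation_types list,
    # int32, float, vec3f and vec4f are accepted, and unsupported dtypes raise TypeError
    vol_idx = wp.Volume.allocate_by_tiles(tile_points, voxel_size=0.1, bg_value=None, device="cuda")
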
@@ -4598,6 +4650,8 @@ def matmul(
  ):
  """Computes a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.

+ .. versionremoved:: 1.7
+
  .. deprecated:: 1.6
  Use :doc:`tile primitives </modules/tiles>` instead.

@@ -4611,80 +4665,8 @@
  allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
  while using Tensor Cores
  """
- from warp.context import runtime
-
- warp.utils.warn(
- "wp.matmul() is deprecated and will be removed in a\nfuture version. Use tile primitives instead.",
- category=DeprecationWarning,
- stacklevel=2,
- )
-
- device = a.device
-
- if b.device != device or c.device != device or d.device != device:
- raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
-
- if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
- raise RuntimeError(
- "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
- )
-
- if (
- (not a.is_contiguous and not a.is_transposed)
- or (not b.is_contiguous and not b.is_transposed)
- or (not c.is_contiguous)
- or (not d.is_contiguous)
- ):
- raise RuntimeError(
- "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
- )

- m = a.shape[0]
- n = b.shape[1]
- k = a.shape[1]
- if b.shape != (k, n) or c.shape != (m, n) or d.shape != (m, n):
- raise RuntimeError(
- "Invalid shapes for matrices: A = {} B = {} C = {} D = {}".format(a.shape, b.shape, c.shape, d.shape)
- )
-
- if runtime.tape:
- runtime.tape.record_func(
- backward=lambda: adj_matmul(a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith),
- arrays=[a, b, c, d],
- )
- if warp.config.verify_autograd_array_access:
- d.mark_write()
- a.mark_read()
- b.mark_read()
- c.mark_read()
-
- # cpu fallback if no cuda devices found
- if device == "cpu":
- np_dtype = warp_type_to_np_dtype[a.dtype]
- d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
- return
-
- cc = device.arch
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- m,
- n,
- k,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(a.ptr),
- ctypes.c_void_p(b.ptr),
- ctypes.c_void_p(c.ptr),
- ctypes.c_void_p(d.ptr),
- alpha,
- beta,
- not a.is_transposed,
- not b.is_transposed,
- allow_tf32x3_arith,
- 1,
- )
- if not ret:
- raise RuntimeError("matmul failed.")
+ raise RuntimeError("This function has been removed. Use tile primitives instead.")


  def adj_matmul(
@@ -4716,171 +4698,8 @@ def adj_matmul(
  allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
  while using Tensor Cores
  """
- from warp.context import runtime
-
- device = a.device
-
- if (
- b.device != device
- or c.device != device
- or adj_a.device != device
- or adj_b.device != device
- or adj_c.device != device
- or adj_d.device != device
- ):
- raise RuntimeError(
- "Matrices A, B, C, D, and their adjoints must all be on the same device as the runtime device."
- )
-
- if (
- a.dtype != b.dtype
- or a.dtype != c.dtype
- or a.dtype != adj_a.dtype
- or a.dtype != adj_b.dtype
- or a.dtype != adj_c.dtype
- or a.dtype != adj_d.dtype
- ):
- raise RuntimeError(
- "wp.adj_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
- )
-
- if (
- (not a.is_contiguous and not a.is_transposed)
- or (not b.is_contiguous and not b.is_transposed)
- or (not c.is_contiguous)
- or (not adj_a.is_contiguous and not adj_a.is_transposed)
- or (not adj_b.is_contiguous and not adj_b.is_transposed)
- or (not adj_c.is_contiguous)
- or (not adj_d.is_contiguous)
- ):
- raise RuntimeError(
- "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
- )

- m = a.shape[0]
- n = b.shape[1]
- k = a.shape[1]
- if (
- a.shape != (m, k)
- or b.shape != (k, n)
- or c.shape != (m, n)
- or adj_d.shape != (m, n)
- or adj_a.shape != (m, k)
- or adj_b.shape != (k, n)
- or adj_c.shape != (m, n)
- ):
- raise RuntimeError(
- "Invalid shapes for matrices: A = {} B = {} C = {} adj_D = {} adj_A = {} adj_B = {} adj_C = {}".format(
- a.shape, b.shape, c.shape, adj_d.shape, adj_a.shape, adj_b.shape, adj_c.shape
- )
- )
-
- # cpu fallback if no cuda devices found
- if device == "cpu":
- np_dtype = warp_type_to_np_dtype[a.dtype]
- adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose(), dtype=np_dtype) + adj_a.numpy())
- adj_b.assign(alpha * np.matmul(a.numpy().transpose(), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
- adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
- return
-
- cc = device.arch
-
- # adj_a
- if not a.is_transposed:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- m,
- k,
- n,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(adj_d.ptr),
- ctypes.c_void_p(b.ptr),
- ctypes.c_void_p(adj_a.ptr),
- ctypes.c_void_p(adj_a.ptr),
- alpha,
- 1.0,
- True,
- b.is_transposed,
- allow_tf32x3_arith,
- 1,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
- else:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- k,
- m,
- n,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(b.ptr),
- ctypes.c_void_p(adj_d.ptr),
- ctypes.c_void_p(adj_a.ptr),
- ctypes.c_void_p(adj_a.ptr),
- alpha,
- 1.0,
- not b.is_transposed,
- False,
- allow_tf32x3_arith,
- 1,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
-
- # adj_b
- if not b.is_transposed:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- k,
- n,
- m,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(a.ptr),
- ctypes.c_void_p(adj_d.ptr),
- ctypes.c_void_p(adj_b.ptr),
- ctypes.c_void_p(adj_b.ptr),
- alpha,
- 1.0,
- a.is_transposed,
- True,
- allow_tf32x3_arith,
- 1,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
- else:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- n,
- k,
- m,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(adj_d.ptr),
- ctypes.c_void_p(a.ptr),
- ctypes.c_void_p(adj_b.ptr),
- ctypes.c_void_p(adj_b.ptr),
- alpha,
- 1.0,
- False,
- not a.is_transposed,
- allow_tf32x3_arith,
- 1,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
-
- # adj_c
- warp.launch(
- kernel=warp.utils.add_kernel_2d,
- dim=adj_c.shape,
- inputs=[adj_c, adj_d, adj_d.dtype(beta)],
- device=device,
- record_tape=False,
- )
+ raise RuntimeError("This function has been removed. Use tile primitives instead.")


  def batched_matmul(
@@ -4894,6 +4713,8 @@ def batched_matmul(
  ):
  """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.

+ .. versionremoved:: 1.7
+
  .. deprecated:: 1.6
  Use :doc:`tile primitives </modules/tiles>` instead.

@@ -4907,107 +4728,8 @@
  allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
  while using Tensor Cores
  """
- from warp.context import runtime
-
- device = a.device
-
- if b.device != device or c.device != device or d.device != device:
- raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
-
- if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
- raise RuntimeError(
- "wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
- )
-
- if (
- (not a.is_contiguous and not a.is_transposed)
- or (not b.is_contiguous and not b.is_transposed)
- or (not c.is_contiguous)
- or (not d.is_contiguous)
- ):
- raise RuntimeError(
- "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
- )
-
- m = a.shape[1]
- n = b.shape[2]
- k = a.shape[2]
- batch_count = a.shape[0]
- if b.shape != (batch_count, k, n) or c.shape != (batch_count, m, n) or d.shape != (batch_count, m, n):
- raise RuntimeError(
- "Invalid shapes for matrices: A = {} B = {} C = {} D = {}".format(a.shape, b.shape, c.shape, d.shape)
- )

- if runtime.tape:
- runtime.tape.record_func(
- backward=lambda: adj_batched_matmul(
- a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith
- ),
- arrays=[a, b, c, d],
- )
- if warp.config.verify_autograd_array_access:
- d.mark_write()
- a.mark_read()
- b.mark_read()
- c.mark_read()
-
- # cpu fallback if no cuda devices found
- if device == "cpu":
- np_dtype = warp_type_to_np_dtype[a.dtype]
- d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
- return
-
- # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
- max_batch_count = 65535
- iters = int(batch_count / max_batch_count)
- remainder = batch_count % max_batch_count
-
- cc = device.arch
- for i in range(iters):
- idx_start = i * max_batch_count
- idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- m,
- n,
- k,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(c[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(d[idx_start:idx_end, :, :].ptr),
- alpha,
- beta,
- not a.is_transposed,
- not b.is_transposed,
- allow_tf32x3_arith,
- max_batch_count,
- )
- if not ret:
- raise RuntimeError("Batched matmul failed.")
-
- idx_start = iters * max_batch_count
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- m,
- n,
- k,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(a[idx_start:, :, :].ptr),
- ctypes.c_void_p(b[idx_start:, :, :].ptr),
- ctypes.c_void_p(c[idx_start:, :, :].ptr),
- ctypes.c_void_p(d[idx_start:, :, :].ptr),
- alpha,
- beta,
- not a.is_transposed,
- not b.is_transposed,
- allow_tf32x3_arith,
- remainder,
- )
- if not ret:
- raise RuntimeError("Batched matmul failed.")
+ raise RuntimeError("This function has been removed. Use tile primitives instead.")


  def adj_batched_matmul(
@@ -5037,270 +4759,8 @@ def adj_batched_matmul(
  allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
  while using Tensor Cores
  """
- from warp.context import runtime

- device = a.device
-
- if (
- b.device != device
- or c.device != device
- or adj_a.device != device
- or adj_b.device != device
- or adj_c.device != device
- or adj_d.device != device
- ):
- raise RuntimeError(
- "Matrices A, B, C, D, and their adjoints must all be on the same device as the runtime device."
- )
-
- if (
- a.dtype != b.dtype
- or a.dtype != c.dtype
- or a.dtype != adj_a.dtype
- or a.dtype != adj_b.dtype
- or a.dtype != adj_c.dtype
- or a.dtype != adj_d.dtype
- ):
- raise RuntimeError(
- "wp.adj_batched_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
- )
-
- m = a.shape[1]
- n = b.shape[2]
- k = a.shape[2]
- batch_count = a.shape[0]
- if (
- b.shape != (batch_count, k, n)
- or c.shape != (batch_count, m, n)
- or adj_d.shape != (batch_count, m, n)
- or adj_a.shape != (batch_count, m, k)
- or adj_b.shape != (batch_count, k, n)
- or adj_c.shape != (batch_count, m, n)
- ):
- raise RuntimeError(
- "Invalid shapes for matrices: A = {} B = {} C = {} adj_D = {} adj_A = {} adj_B = {} adj_C = {}".format(
- a.shape, b.shape, c.shape, adj_d.shape, adj_a.shape, adj_b.shape, adj_c.shape
- )
- )
-
- if (
- (not a.is_contiguous and not a.is_transposed)
- or (not b.is_contiguous and not b.is_transposed)
- or (not c.is_contiguous)
- or (not adj_a.is_contiguous and not adj_a.is_transposed)
- or (not adj_b.is_contiguous and not adj_b.is_transposed)
- or (not adj_c.is_contiguous)
- or (not adj_d.is_contiguous)
- ):
- raise RuntimeError(
- "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
- )
-
- # cpu fallback if no cuda devices found
- if device == "cpu":
- np_dtype = warp_type_to_np_dtype[a.dtype]
- adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1)), dtype=np_dtype) + adj_a.numpy())
- adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
- adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
- return
-
- # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
- max_batch_count = 65535
- iters = int(batch_count / max_batch_count)
- remainder = batch_count % max_batch_count
-
- cc = device.arch
-
- for i in range(iters):
- idx_start = i * max_batch_count
- idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
-
- # adj_a
- if not a.is_transposed:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- m,
- k,
- n,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
- alpha,
- 1.0,
- True,
- b.is_transposed,
- allow_tf32x3_arith,
- max_batch_count,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
- else:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- k,
- m,
- n,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
- alpha,
- 1.0,
- not b.is_transposed,
- False,
- allow_tf32x3_arith,
- max_batch_count,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
-
- # adj_b
- if not b.is_transposed:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- k,
- n,
- m,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
- alpha,
- 1.0,
- a.is_transposed,
- True,
- allow_tf32x3_arith,
- max_batch_count,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
- else:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- n,
- k,
- m,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
- alpha,
- 1.0,
- False,
- not a.is_transposed,
- allow_tf32x3_arith,
- max_batch_count,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
-
- idx_start = iters * max_batch_count
-
- # adj_a
- if not a.is_transposed:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- m,
- k,
- n,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
- ctypes.c_void_p(b[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
- alpha,
- 1.0,
- True,
- b.is_transposed,
- allow_tf32x3_arith,
- remainder,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
- else:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- k,
- m,
- n,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(b[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
- alpha,
- 1.0,
- not b.is_transposed,
- False,
- allow_tf32x3_arith,
- remainder,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
-
- # adj_b
- if not b.is_transposed:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- k,
- n,
- m,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(a[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
- alpha,
- 1.0,
- a.is_transposed,
- True,
- allow_tf32x3_arith,
- remainder,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
- else:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- n,
- k,
- m,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
- ctypes.c_void_p(a[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
- alpha,
- 1.0,
- False,
- not a.is_transposed,
- allow_tf32x3_arith,
- remainder,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
-
- # adj_c
- warp.launch(
- kernel=warp.utils.add_kernel_3d,
- dim=adj_c.shape,
- inputs=[adj_c, adj_d, adj_d.dtype(beta)],
- device=device,
- record_tape=False,
- )
+ raise RuntimeError("This function has been removed. Use tile primitives instead.")


  class HashGrid:
@@ -5683,7 +5143,7 @@ simple_type_codes = {
  }


- def get_type_code(arg_type):
+ def get_type_code(arg_type: type) -> str:
  if arg_type == Any:
  # special case for generics
  # note: since Python 3.11 Any is a type, so we check for it first
@@ -5747,8 +5207,8 @@ def get_type_code(arg_type):
  raise TypeError(f"Unrecognized type '{arg_type}'")


- def get_signature(arg_types, func_name=None, arg_names=None):
- type_codes = []
+ def get_signature(arg_types: List[type], func_name: Optional[str] = None, arg_names: Optional[List[str]] = None) -> str:
+ type_codes: List[str] = []
  for i, arg_type in enumerate(arg_types):
  try:
  type_codes.append(get_type_code(arg_type))