warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -1,15 +1,24 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
7
16
  import builtins
8
17
  import functools
9
- import tempfile
10
- from pathlib import Path
11
18
  from typing import Any, Callable, Mapping, Sequence
12
19
 
20
+ import warp.build
21
+ import warp.context
13
22
  from warp.codegen import Reference, Var, strip_reference
14
23
  from warp.types import *
15
24
 
@@ -32,7 +41,7 @@ def sametypes(arg_types: Mapping[str, Any]):
32
41
  return all(types_equal(arg_type_0, t) for t in arg_types_iter)
33
42
 
34
43
 
35
- def sametypes_create_value_func(default):
44
+ def sametypes_create_value_func(default: TypeVar):
36
45
  def fn(arg_types, arg_values):
37
46
  if arg_types is None:
38
47
  return default
@@ -390,7 +399,7 @@ add_builtin(
390
399
  )
391
400
 
392
401
 
393
- def scalar_infer_type(arg_types: Mapping[str, type]):
402
+ def scalar_infer_type(arg_types: Union[Mapping[str, type], Tuple[type, ...], None]):
394
403
  if arg_types is None:
395
404
  return Scalar
396
405
 
@@ -941,6 +950,12 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
941
950
  raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
942
951
 
943
952
  if all(type_is_vector(x) for x in variadic_arg_types):
953
+ warp.utils.warn(
954
+ "the built-in `wp.matrix()` won't support taking column vectors as input "
955
+ "in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
956
+ DeprecationWarning,
957
+ )
958
+
944
959
  if shape[1] != variadic_arg_count:
945
960
  raise RuntimeError(
946
961
  f"incompatible number of column vectors given ({variadic_arg_count}) "
@@ -1021,6 +1036,86 @@ add_builtin(
1021
1036
  )
1022
1037
 
1023
1038
 
1039
+ def matrix_from_vecs_create_value_func(cols: bool):
1040
+ def fn(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1041
+ if arg_types is None:
1042
+ return matrix(shape=(Any, Any), dtype=Scalar)
1043
+
1044
+ variadic_arg_types = arg_types.get("args", ())
1045
+ variadic_arg_count = len(variadic_arg_types)
1046
+
1047
+ if not all(type_is_vector(x) for x in variadic_arg_types):
1048
+ raise RuntimeError("all arguments are expected to be vectors")
1049
+
1050
+ length = variadic_arg_types[0]._length_
1051
+ if any(x._length_ != length for x in variadic_arg_types):
1052
+ raise RuntimeError("all vectors are expected to have the same length")
1053
+
1054
+ dtype = variadic_arg_types[0]._wp_scalar_type_
1055
+ if any(x._wp_scalar_type_ != dtype for x in variadic_arg_types):
1056
+ raise RuntimeError("all vectors are expected to have the same dtype")
1057
+
1058
+ shape = (length, variadic_arg_count) if cols else (variadic_arg_count, length)
1059
+ return matrix(shape=shape, dtype=dtype)
1060
+
1061
+ return fn
1062
+
1063
+
1064
+ def matrix_from_vecs_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1065
+ # We're in the codegen stage where we emit the code calling the built-in.
1066
+ # Further validate the given argument values if needed and map them
1067
+ # to the underlying C++ function's runtime and template params.
1068
+
1069
+ shape = return_type._shape_
1070
+ dtype = return_type._wp_scalar_type_
1071
+
1072
+ variadic_args = args.get("args", ())
1073
+
1074
+ func_args = variadic_args
1075
+
1076
+ if shape in ((2, 2), (3, 3), (4, 4)):
1077
+ # Template specializations exist for these shapes, don't pass them
1078
+ # as template parameters.
1079
+ template_args = (dtype,)
1080
+ else:
1081
+ template_args = (*shape, dtype)
1082
+
1083
+ return (func_args, template_args)
1084
+
1085
+
1086
+ def matrix_from_vecs_initializer_list_func(args, return_type):
1087
+ shape = return_type._shape_
1088
+
1089
+ return shape[0] != shape[1] or shape[0] > 4
1090
+
1091
+
1092
+ add_builtin(
1093
+ "matrix_from_cols",
1094
+ input_types={"*args": vector(length=Any, dtype=Scalar)},
1095
+ variadic=True,
1096
+ value_func=matrix_from_vecs_create_value_func(cols=True),
1097
+ dispatch_func=matrix_from_vecs_dispatch_func,
1098
+ initializer_list_func=matrix_from_vecs_initializer_list_func,
1099
+ native_func="matrix_from_cols",
1100
+ doc="Construct a matrix from column vectors.",
1101
+ group="Vector Math",
1102
+ export=False,
1103
+ )
1104
+
1105
+ add_builtin(
1106
+ "matrix_from_rows",
1107
+ input_types={"*args": vector(length=Any, dtype=Scalar)},
1108
+ variadic=True,
1109
+ value_func=matrix_from_vecs_create_value_func(cols=False),
1110
+ dispatch_func=matrix_from_vecs_dispatch_func,
1111
+ initializer_list_func=matrix_from_vecs_initializer_list_func,
1112
+ native_func="matrix_from_rows",
1113
+ doc="Construct a matrix from row vectors.",
1114
+ group="Vector Math",
1115
+ export=False,
1116
+ )
1117
+
1118
+
1024
1119
  def identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1025
1120
  if arg_types is None:
1026
1121
  return matrix(shape=(Any, Any), dtype=Scalar)
@@ -1132,6 +1227,21 @@ add_builtin(
1132
1227
  while the left and right basis vectors are returned in ``U`` and ``V``.""",
1133
1228
  )
1134
1229
 
1230
+ add_builtin(
1231
+ "svd2",
1232
+ input_types={
1233
+ "A": matrix(shape=(2, 2), dtype=Float),
1234
+ "U": matrix(shape=(2, 2), dtype=Float),
1235
+ "sigma": vector(length=2, dtype=Float),
1236
+ "V": matrix(shape=(2, 2), dtype=Scalar),
1237
+ },
1238
+ value_type=None,
1239
+ group="Vector Math",
1240
+ export=False,
1241
+ doc="""Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
1242
+ while the left and right basis vectors are returned in ``U`` and ``V``.""",
1243
+ )
1244
+
1135
1245
  add_builtin(
1136
1246
  "qr3",
1137
1247
  input_types={
@@ -1323,7 +1433,18 @@ add_builtin(
1323
1433
  input_types={"mat": matrix(shape=(3, 3), dtype=Float)},
1324
1434
  value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
1325
1435
  group="Quaternion Math",
1326
- doc="Construct a quaternion from a 3x3 matrix.",
1436
+ doc="""Construct a quaternion from a 3x3 matrix.
1437
+
1438
+ If the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.""",
1439
+ )
1440
+ add_builtin(
1441
+ "quat_from_matrix",
1442
+ input_types={"mat": matrix(shape=(4, 4), dtype=Float)},
1443
+ value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
1444
+ group="Quaternion Math",
1445
+ doc="""Construct a quaternion from a 4x4 matrix.
1446
+
1447
+ If the top-left 3x3 block of the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.""",
1327
1448
  )
1328
1449
  add_builtin(
1329
1450
  "quat_rpy",
@@ -2366,7 +2487,7 @@ add_builtin(
2366
2487
 
2367
2488
  This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
2368
2489
 
2369
- * If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
2490
+ * If the input value is a scalar, then the resulting tile has ``shape=(block_dim,)``
2370
2491
  * If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
2371
2492
 
2372
2493
  :param x: A per-thread local value, e.g. scalar, vector, or matrix.
@@ -2660,11 +2781,9 @@ def tile_broadcast_value_func(arg_types, arg_values):
2660
2781
  def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2661
2782
  tile = arg_values["a"]
2662
2783
 
2663
- template_args = []
2664
- template_args.append(return_type.shape[0])
2665
- template_args.append(return_type.shape[1])
2666
- template_args.append(return_type.strides[0])
2667
- template_args.append(return_type.strides[1])
2784
+ assert len(return_type.shape) == len(return_type.strides)
2785
+ assert 1 <= len(return_type.shape) <= 4
2786
+ template_args = [*return_type.shape, *return_type.strides]
2668
2787
 
2669
2788
  return ((tile,), template_args)
2670
2789
 
@@ -2677,56 +2796,17 @@ add_builtin(
2677
2796
  variadic=False,
2678
2797
  doc="""Broadcast a tile.
2679
2798
 
2680
- This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n).
2681
-
2799
+ Broadcasts the input tile ``a`` to the destination shape.
2682
2800
  Broadcasting follows NumPy broadcast rules.
2683
2801
 
2684
2802
  :param a: Tile to broadcast
2685
2803
  :param shape: The shape to broadcast to
2686
- :returns: Tile with broadcast ``shape=(m, n)``""",
2804
+ :returns: Tile with broadcast shape""",
2687
2805
  group="Tile Primitives",
2688
2806
  export=False,
2689
2807
  )
2690
2808
 
2691
2809
 
2692
- def tile_matmul_value_func(arg_types, arg_values):
2693
- # return generic type (for doc builds)
2694
- if arg_types is None:
2695
- return Tile(dtype=Any, shape=Any)
2696
-
2697
- if len(arg_types) != 3:
2698
- raise TypeError(f"tile_matmul() takes exactly 3 positional arguments but {len(arg_types)} were given")
2699
-
2700
- return None
2701
-
2702
-
2703
- def tile_matmul_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2704
- a = arg_values["a"]
2705
- b = arg_values["b"]
2706
- out = arg_values["out"]
2707
-
2708
- # force the storage type of the input variables to shared memory
2709
- a.type.storage = "shared"
2710
- b.type.storage = "shared"
2711
- out.type.storage = "shared"
2712
-
2713
- template_args = []
2714
- return ((a, b, out), template_args)
2715
-
2716
-
2717
- add_builtin(
2718
- "tile_matmul_scalar",
2719
- input_types={"a": Tile, "b": Tile, "out": Tile},
2720
- value_func=tile_matmul_value_func,
2721
- dispatch_func=tile_matmul_dispatch_func,
2722
- variadic=True,
2723
- doc="Compute matrix product and accumulate out += a*b.",
2724
- group="Tile Primitives",
2725
- hidden=True,
2726
- export=False,
2727
- )
2728
-
2729
-
2730
2810
  def tile_sum_value_func(arg_types, arg_values):
2731
2811
  # return generic type (for doc builds)
2732
2812
  if arg_types is None:
@@ -3021,7 +3101,7 @@ def tile_binary_map_value_func(arg_types, arg_values):
3021
3101
 
3022
3102
  for i in range(len(a.shape)):
3023
3103
  if a.shape[i] != b.shape[i]:
3024
- raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape[i]} and {b.shape[i]}")
3104
+ raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")
3025
3105
 
3026
3106
  return TileBinaryMap(a, b)
3027
3107
 
@@ -3798,6 +3878,18 @@ _volume_supported_value_types = {
3798
3878
  }
3799
3879
 
3800
3880
 
3881
+ def _is_volume_type_supported(dtype):
3882
+ for typ in _volume_supported_value_types:
3883
+ if types_equal(typ, dtype):
3884
+ return True
3885
+ return False
3886
+
3887
+
3888
+ def _check_volume_type_is_supported(dtype):
3889
+ if not _is_volume_type_supported(dtype):
3890
+ raise RuntimeError(f"unsupported volume type `{type_repr(dtype)}`")
3891
+
3892
+
3801
3893
  def check_volume_value_grad_compatibility(dtype, grad_dtype):
3802
3894
  if type_is_vector(dtype):
3803
3895
  expected = matrix(shape=(type_length(dtype), 3), dtype=type_scalar_type(dtype))
@@ -3813,9 +3905,7 @@ def volume_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
3813
3905
  return Any
3814
3906
 
3815
3907
  dtype = arg_values["dtype"]
3816
-
3817
- if dtype not in _volume_supported_value_types:
3818
- raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
3908
+ _check_volume_type_is_supported(dtype)
3819
3909
 
3820
3910
  return dtype
3821
3911
 
@@ -3851,9 +3941,7 @@ def volume_sample_grad_value_func(arg_types: Mapping[str, type], arg_values: Map
3851
3941
  return Any
3852
3942
 
3853
3943
  dtype = arg_values["dtype"]
3854
-
3855
- if dtype not in _volume_supported_value_types:
3856
- raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
3944
+ _check_volume_type_is_supported(dtype)
3857
3945
 
3858
3946
  check_volume_value_grad_compatibility(dtype, arg_types["grad"])
3859
3947
 
@@ -3891,9 +3979,7 @@ def volume_lookup_value_func(arg_types: Mapping[str, type], arg_values: Mapping[
3891
3979
  return Any
3892
3980
 
3893
3981
  dtype = arg_values["dtype"]
3894
-
3895
- if dtype not in _volume_supported_value_types:
3896
- raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
3982
+ _check_volume_type_is_supported(dtype)
3897
3983
 
3898
3984
  return dtype
3899
3985
 
@@ -3930,9 +4016,7 @@ def volume_store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[s
3930
4016
  return None
3931
4017
 
3932
4018
  dtype = arg_types["value"]
3933
-
3934
- if dtype not in _volume_supported_value_types:
3935
- raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
4019
+ _check_volume_type_is_supported(dtype)
3936
4020
 
3937
4021
  return None
3938
4022
 
@@ -4182,6 +4266,20 @@ add_builtin(
4182
4266
  group="Random",
4183
4267
  doc="Return a random integer between [low, high).",
4184
4268
  )
4269
+ add_builtin(
4270
+ "randu",
4271
+ input_types={"state": uint32},
4272
+ value_type=uint32,
4273
+ group="Random",
4274
+ doc="Return a random unsigned integer in the range [0, 2^32).",
4275
+ )
4276
+ add_builtin(
4277
+ "randu",
4278
+ input_types={"state": uint32, "low": uint32, "high": uint32},
4279
+ value_type=uint32,
4280
+ group="Random",
4281
+ doc="Return a random unsigned integer between [low, high).",
4282
+ )
4185
4283
  add_builtin(
4186
4284
  "randf",
4187
4285
  input_types={"state": uint32},
@@ -4490,11 +4588,31 @@ add_builtin(
4490
4588
  export=False,
4491
4589
  group="Utility",
4492
4590
  )
4591
+
4592
+
4593
+ def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
4594
+ warp.utils.warn(
4595
+ "wp.select() is deprecated and will be removed in a future\n"
4596
+ "version. Use wp.where(cond, value_if_true, value_if_false) instead.",
4597
+ category=DeprecationWarning,
4598
+ )
4599
+
4600
+ func_args = tuple(args.values())
4601
+ template_args = ()
4602
+
4603
+ return (func_args, template_args)
4604
+
4605
+
4493
4606
  add_builtin(
4494
4607
  "select",
4495
4608
  input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
4496
4609
  value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
4497
- doc="Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``",
4610
+ dispatch_func=select_dispatch_func,
4611
+ doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
4612
+
4613
+ .. deprecated:: 1.7
4614
+ Use :func:`where` instead, which has the more intuitive argument order:
4615
+ ``where(cond, value_if_true, value_if_false)``.""",
4498
4616
  group="Utility",
4499
4617
  )
4500
4618
  for t in int_types:
@@ -4502,14 +4620,47 @@ for t in int_types:
4502
4620
  "select",
4503
4621
  input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
4504
4622
  value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
4505
- doc="Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``",
4623
+ dispatch_func=select_dispatch_func,
4624
+ doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
4625
+
4626
+ .. deprecated:: 1.7
4627
+ Use :func:`where` instead, which has the more intuitive argument order:
4628
+ ``where(cond, value_if_true, value_if_false)``.""",
4506
4629
  group="Utility",
4507
4630
  )
4508
4631
  add_builtin(
4509
4632
  "select",
4510
4633
  input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
4511
4634
  value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
4512
- doc="Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``",
4635
+ dispatch_func=select_dispatch_func,
4636
+ doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
4637
+
4638
+ .. deprecated:: 1.7
4639
+ Use :func:`where` instead, which has the more intuitive argument order:
4640
+ ``where(arr, value_if_true, value_if_false)``.""",
4641
+ group="Utility",
4642
+ )
4643
+
4644
+ add_builtin(
4645
+ "where",
4646
+ input_types={"cond": builtins.bool, "value_if_true": Any, "value_if_false": Any},
4647
+ value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
4648
+ doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
4649
+ group="Utility",
4650
+ )
4651
+ for t in int_types:
4652
+ add_builtin(
4653
+ "where",
4654
+ input_types={"cond": t, "value_if_true": Any, "value_if_false": Any},
4655
+ value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
4656
+ doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
4657
+ group="Utility",
4658
+ )
4659
+ add_builtin(
4660
+ "where",
4661
+ input_types={"arr": array(dtype=Any), "value_if_true": Any, "value_if_false": Any},
4662
+ value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
4663
+ doc="Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``.",
4513
4664
  group="Utility",
4514
4665
  )
4515
4666
 
@@ -5103,33 +5254,51 @@ add_builtin(
5103
5254
  )
5104
5255
 
5105
5256
 
5257
+ # implements vector[index] = value
5258
+ add_builtin(
5259
+ "assign_inplace",
5260
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5261
+ value_type=None,
5262
+ hidden=True,
5263
+ group="Utility",
5264
+ )
5265
+
5266
+ # implements quaternion[index] = value
5267
+ add_builtin(
5268
+ "assign_inplace",
5269
+ input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5270
+ value_type=None,
5271
+ hidden=True,
5272
+ group="Utility",
5273
+ )
5274
+
5275
+
5106
5276
  def vector_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
5107
5277
  vec_type = arg_types["a"]
5108
5278
  return vec_type
5109
5279
 
5110
5280
 
5111
- # implements vector[index] = value
5281
+ # implements vector[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
5112
5282
  add_builtin(
5113
- "assign",
5283
+ "assign_copy",
5114
5284
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5115
5285
  value_func=vector_assign_value_func,
5116
5286
  hidden=True,
5117
5287
  group="Utility",
5118
5288
  )
5119
5289
 
5120
- # implements quaternion[index] = value
5290
+ # implements quaternion[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
5121
5291
  add_builtin(
5122
- "assign",
5292
+ "assign_copy",
5123
5293
  input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5124
5294
  value_func=vector_assign_value_func,
5125
5295
  hidden=True,
5126
5296
  group="Utility",
5127
5297
  )
5128
5298
 
5129
-
5130
5299
  # implements vector[idx] += scalar
5131
5300
  add_builtin(
5132
- "augassign_add",
5301
+ "add_inplace",
5133
5302
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5134
5303
  value_type=None,
5135
5304
  hidden=True,
@@ -5138,7 +5307,7 @@ add_builtin(
5138
5307
 
5139
5308
  # implements quaternion[idx] += scalar
5140
5309
  add_builtin(
5141
- "augassign_add",
5310
+ "add_inplace",
5142
5311
  input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5143
5312
  value_type=None,
5144
5313
  hidden=True,
@@ -5147,7 +5316,7 @@ add_builtin(
5147
5316
 
5148
5317
  # implements vector[idx] -= scalar
5149
5318
  add_builtin(
5150
- "augassign_sub",
5319
+ "sub_inplace",
5151
5320
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5152
5321
  value_type=None,
5153
5322
  hidden=True,
@@ -5156,7 +5325,7 @@ add_builtin(
5156
5325
 
5157
5326
  # implements quaternion[idx] -= scalar
5158
5327
  add_builtin(
5159
- "augassign_sub",
5328
+ "sub_inplace",
5160
5329
  input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5161
5330
  value_type=None,
5162
5331
  hidden=True,
@@ -5200,11 +5369,6 @@ add_builtin(
5200
5369
  )
5201
5370
 
5202
5371
 
5203
- def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
5204
- mat_type = arg_types["a"]
5205
- return mat_type
5206
-
5207
-
5208
5372
  def matrix_vector_sametype(arg_types: Mapping[str, Any]):
5209
5373
  mat_size = arg_types["a"]._shape_[0]
5210
5374
  vec_size = arg_types["value"]._length_
@@ -5215,7 +5379,33 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
5215
5379
 
5216
5380
  # implements matrix[i,j] = scalar
5217
5381
  add_builtin(
5218
- "assign",
5382
+ "assign_inplace",
5383
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5384
+ value_type=None,
5385
+ hidden=True,
5386
+ group="Utility",
5387
+ )
5388
+
5389
+
5390
+ # implements matrix[i] = vector
5391
+ add_builtin(
5392
+ "assign_inplace",
5393
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
5394
+ constraint=matrix_vector_sametype,
5395
+ value_type=None,
5396
+ hidden=True,
5397
+ group="Utility",
5398
+ )
5399
+
5400
+
5401
+ def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
5402
+ mat_type = arg_types["a"]
5403
+ return mat_type
5404
+
5405
+
5406
+ # implements matrix[i,j] = scalar
5407
+ add_builtin(
5408
+ "assign_copy",
5219
5409
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5220
5410
  value_func=matrix_assign_value_func,
5221
5411
  hidden=True,
@@ -5225,7 +5415,7 @@ add_builtin(
5225
5415
 
5226
5416
  # implements matrix[i] = vector
5227
5417
  add_builtin(
5228
- "assign",
5418
+ "assign_copy",
5229
5419
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
5230
5420
  constraint=matrix_vector_sametype,
5231
5421
  value_func=matrix_assign_value_func,
@@ -5236,7 +5426,7 @@ add_builtin(
5236
5426
 
5237
5427
  # implements matrix[i,j] += scalar
5238
5428
  add_builtin(
5239
- "augassign_add",
5429
+ "add_inplace",
5240
5430
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5241
5431
  value_type=None,
5242
5432
  hidden=True,
@@ -5244,9 +5434,20 @@ add_builtin(
5244
5434
  )
5245
5435
 
5246
5436
 
5437
+ # implements matrix[i] += vector
5438
+ add_builtin(
5439
+ "add_inplace",
5440
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
5441
+ constraint=matrix_vector_sametype,
5442
+ value_type=None,
5443
+ hidden=True,
5444
+ group="Utility",
5445
+ )
5446
+
5447
+
5247
5448
  # implements matrix[i,j] -= scalar
5248
5449
  add_builtin(
5249
- "augassign_sub",
5450
+ "sub_inplace",
5250
5451
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5251
5452
  value_type=None,
5252
5453
  hidden=True,
@@ -5254,6 +5455,16 @@ add_builtin(
5254
5455
  )
5255
5456
 
5256
5457
 
5458
+ # implements matrix[i] -= vector
5459
+ add_builtin(
5460
+ "sub_inplace",
5461
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
5462
+ value_type=None,
5463
+ hidden=True,
5464
+ group="Utility",
5465
+ )
5466
+
5467
+
5257
5468
  for t in scalar_types + vector_types + (bool,):
5258
5469
  if "vec" in t.__name__ or "mat" in t.__name__:
5259
5470
  continue
@@ -5401,7 +5612,27 @@ add_builtin(
5401
5612
  )
5402
5613
  add_builtin(
5403
5614
  "expect_near",
5404
- input_types={"a": vec3, "b": vec3, "tolerance": float},
5615
+ input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "tolerance": Float},
5616
+ defaults={"tolerance": 1.0e-6},
5617
+ value_type=None,
5618
+ doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
5619
+ group="Utility",
5620
+ )
5621
+ add_builtin(
5622
+ "expect_near",
5623
+ input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "tolerance": Float},
5624
+ defaults={"tolerance": 1.0e-6},
5625
+ value_type=None,
5626
+ doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
5627
+ group="Utility",
5628
+ )
5629
+ add_builtin(
5630
+ "expect_near",
5631
+ input_types={
5632
+ "a": matrix(shape=(Any, Any), dtype=Float),
5633
+ "b": matrix(shape=(Any, Any), dtype=Float),
5634
+ "tolerance": Float,
5635
+ },
5405
5636
  defaults={"tolerance": 1.0e-6},
5406
5637
  value_type=None,
5407
5638
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
@@ -5980,7 +6211,7 @@ add_builtin(
5980
6211
  ##
5981
6212
  ## Matmul
5982
6213
  ##
5983
- def tile_matmul_generic_value_func(arg_types, arg_values):
6214
+ def tile_matmul_value_func(arg_types, arg_values):
5984
6215
  # return generic type (for doc builds)
5985
6216
  if arg_types is None:
5986
6217
  return Tile(dtype=Any, shape=Any)
@@ -6006,7 +6237,7 @@ def tile_matmul_generic_value_func(arg_types, arg_values):
6006
6237
  return None
6007
6238
 
6008
6239
 
6009
- def tile_matmul_generic_lto_dispatch_func(
6240
+ def tile_matmul_lto_dispatch_func(
6010
6241
  arg_types: Mapping[str, type],
6011
6242
  return_type: Any,
6012
6243
  return_values: List[Var],
@@ -6045,142 +6276,82 @@ def tile_matmul_generic_lto_dispatch_func(
6045
6276
  out.type.storage = "shared"
6046
6277
  template_args = [accumulate]
6047
6278
 
6048
- # Maps Python/Warp types to C++ types and enums
6049
- def cublasdx_type_map(dtype):
6050
- if dtype == float16:
6051
- return ("wp::float16", 3, 0)
6052
- if dtype == float32:
6053
- return ("wp::float32", 5, 0)
6054
- if dtype == float64:
6055
- return ("wp::float64", 6, 0)
6056
- if dtype == vec2h:
6057
- return ("wp::vec2h", 3, 1)
6058
- if dtype == vec2f:
6059
- return ("wp::vec2f", 5, 1)
6060
- if dtype == vec2d:
6061
- return ("wp::vec2d", 6, 1)
6062
- raise TypeError("Unsupported input type in tile_matmul")
6063
-
6064
- def cublasdx_arrangement_map(layout):
6065
- if layout == "colmajor":
6066
- return 0 # CUBLASDX_ARRANGEMENT_COL_MAJOR
6067
- if layout == "rowmajor":
6068
- return 1 # CUBLASDX_ARRANGEMENT_ROW_MAJOR
6069
- raise ValueError("Unsupported layout in tile_matmul")
6070
-
6071
- # generate the LTO
6072
6279
  M, K = a.type.shape[0], a.type.shape[1]
6073
6280
  _, N = b.type.shape[0], b.type.shape[1]
6074
6281
  num_threads = options["block_dim"]
6075
6282
  arch = options["output_arch"]
6076
6283
 
6077
- def make_function(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout):
6078
- (a_dtype, a_prec, a_type) = cublasdx_type_map(adtype)
6079
- (b_dtype, b_prec, b_type) = cublasdx_type_map(bdtype)
6080
- (c_dtype, c_prec, c_type) = cublasdx_type_map(cdtype)
6081
- a_arrangement = cublasdx_arrangement_map(alayout)
6082
- b_arrangement = cublasdx_arrangement_map(blayout)
6083
- c_arrangement = cublasdx_arrangement_map(clayout)
6084
-
6085
- if a_type != b_type or a_type != c_type:
6086
- raise TypeError("time_matmul(A, B, C) requires all inputs to be real or complex")
6087
-
6088
- element_type = a_type
6089
-
6090
- lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
6284
+ if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
6285
+ # CPU/no-MathDx dispatch
6286
+ return ((0, 0, 0, a, b, out), template_args, [], 0)
6287
+ else:
6091
6288
 
6092
- # early out if LTO for this combination already exists for this module
6093
- if lto_symbol in builder.ltoirs:
6094
- return lto_symbol, builder.ltoirs[lto_symbol]
6289
+ def tile_flip_layout(layout):
6290
+ if layout == "rowmajor":
6291
+ return "colmajor"
6292
+ elif layout == "colmajor":
6293
+ return "rowmajor"
6095
6294
 
6096
- # otherwise compile LTO
6097
- lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6098
- result = warp.context.runtime.core.cuda_compile_dot(
6099
- lto_code.name.encode("utf-8"),
6100
- lto_symbol.encode("utf-8"),
6101
- 0,
6102
- None,
6103
- None,
6295
+ # generate the LTOs
6296
+ # C += A * B
6297
+ (fun_forward, lto_forward) = warp.build.build_lto_dot(
6298
+ M,
6299
+ N,
6300
+ K,
6301
+ a.type.dtype,
6302
+ b.type.dtype,
6303
+ out.type.dtype,
6304
+ a.type.layout,
6305
+ b.type.layout,
6306
+ out.type.layout,
6104
6307
  arch,
6308
+ num_threads,
6309
+ builder,
6310
+ )
6311
+ # adjA += adjC * B^T - Transpose ~= flipped layout
6312
+ (fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
6105
6313
  M,
6314
+ K,
6106
6315
  N,
6316
+ out.type.dtype,
6317
+ b.type.dtype,
6318
+ a.type.dtype,
6319
+ out.type.layout,
6320
+ tile_flip_layout(b.type.layout),
6321
+ a.type.layout,
6322
+ arch,
6323
+ num_threads,
6324
+ builder,
6325
+ )
6326
+ # adjB += A^T * adjC - Transpose ~= flipped layout
6327
+ (fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
6107
6328
  K,
6108
- a_prec,
6109
- b_prec,
6110
- c_prec,
6111
- element_type,
6112
- a_arrangement,
6113
- b_arrangement,
6114
- c_arrangement,
6329
+ N,
6330
+ M,
6331
+ a.type.dtype,
6332
+ out.type.dtype,
6333
+ b.type.dtype,
6334
+ tile_flip_layout(a.type.layout),
6335
+ out.type.layout,
6336
+ b.type.layout,
6337
+ arch,
6115
6338
  num_threads,
6339
+ builder,
6116
6340
  )
6117
- lto_code_path = Path(lto_code.name)
6118
- if not result:
6119
- lto_code.close()
6120
- if lto_code_path.exists():
6121
- lto_code_path.unlink()
6122
- raise RuntimeError("Failed to compile tile_matmul")
6123
- else:
6124
- with open(lto_code.name, "rb") as f:
6125
- lto_code_data = f.read()
6126
- lto_code.close()
6127
- lto_code_path.unlink()
6128
-
6129
- builder.ltoirs[lto_symbol] = lto_code_data
6130
- builder.ltoirs_decl[lto_symbol] = (
6131
- f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
6132
- )
6133
-
6134
- return lto_symbol, lto_code_data
6135
6341
 
6136
- def tile_flip_layout(layout):
6137
- if layout == "rowmajor":
6138
- return "colmajor"
6139
- elif layout == "colmajor":
6140
- return "rowmajor"
6141
-
6142
- # C += A * B
6143
- (fun_forward, lto_forward) = make_function(
6144
- M, N, K, a.type.dtype, b.type.dtype, out.type.dtype, a.type.layout, b.type.layout, out.type.layout
6145
- )
6146
- # adjA += adjC * B^T - Transpose ~= flipped layout
6147
- (fun_backward_A, lto_backward_A) = make_function(
6148
- M,
6149
- K,
6150
- N,
6151
- out.type.dtype,
6152
- b.type.dtype,
6153
- a.type.dtype,
6154
- out.type.layout,
6155
- tile_flip_layout(b.type.layout),
6156
- a.type.layout,
6157
- )
6158
- # adjB += A^T * adjC - Transpose ~= flipped layout
6159
- (fun_backward_B, lto_backward_B) = make_function(
6160
- K,
6161
- N,
6162
- M,
6163
- a.type.dtype,
6164
- out.type.dtype,
6165
- b.type.dtype,
6166
- tile_flip_layout(a.type.layout),
6167
- out.type.layout,
6168
- b.type.layout,
6169
- )
6170
-
6171
- return (
6172
- (
6173
- Var(fun_forward, str, False, True, False),
6174
- Var(fun_backward_A, str, False, True, False),
6175
- Var(fun_backward_B, str, False, True, False),
6176
- a,
6177
- b,
6178
- out,
6179
- ),
6180
- template_args,
6181
- [lto_forward, lto_backward_A, lto_backward_B],
6182
- 0,
6183
- )
6342
+ return (
6343
+ (
6344
+ Var(fun_forward, str, False, True, False),
6345
+ Var(fun_backward_A, str, False, True, False),
6346
+ Var(fun_backward_B, str, False, True, False),
6347
+ a,
6348
+ b,
6349
+ out,
6350
+ ),
6351
+ template_args,
6352
+ [lto_forward, lto_backward_A, lto_backward_B],
6353
+ 0,
6354
+ )
6184
6355
 
6185
6356
 
6186
6357
  add_builtin(
@@ -6190,8 +6361,8 @@ add_builtin(
6190
6361
  "b": Tile(dtype=Any, shape=Any),
6191
6362
  "out": Tile(dtype=Any, shape=Any),
6192
6363
  },
6193
- value_func=tile_matmul_generic_value_func,
6194
- lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
6364
+ value_func=tile_matmul_value_func,
6365
+ lto_dispatch_func=tile_matmul_lto_dispatch_func,
6195
6366
  variadic=False,
6196
6367
  doc="""Computes the matrix product and accumulates ``out += a*b``.
6197
6368
 
@@ -6199,7 +6370,7 @@ add_builtin(
6199
6370
  * fp16, fp32, fp64 (real)
6200
6371
  * vec2h, vec2f, vec2d (complex)
6201
6372
 
6202
- All input and output tiles must have the same datatype. Tile data will be automatically be migrated
6373
+ All input and output tiles must have the same datatype. Tile data will automatically be migrated
6203
6374
  to shared memory if necessary and will use TensorCore operations when available.
6204
6375
 
6205
6376
  :param a: A tile with ``shape=(M, K)``
@@ -6213,8 +6384,8 @@ add_builtin(
6213
6384
  add_builtin(
6214
6385
  "tile_matmul",
6215
6386
  input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
6216
- value_func=tile_matmul_generic_value_func,
6217
- lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
6387
+ value_func=tile_matmul_value_func,
6388
+ lto_dispatch_func=tile_matmul_lto_dispatch_func,
6218
6389
  variadic=False,
6219
6390
  doc="""Computes the matrix product ``out = a*b``.
6220
6391
 
@@ -6222,7 +6393,7 @@ add_builtin(
6222
6393
  * fp16, fp32, fp64 (real)
6223
6394
  * vec2h, vec2f, vec2d (complex)
6224
6395
 
6225
- Both input tiles must have the same datatype. Tile data will be automatically be migrated
6396
+ Both input tiles must have the same datatype. Tile data will automatically be migrated
6226
6397
  to shared memory if necessary and will use TensorCore operations when available.
6227
6398
 
6228
6399
  :param a: A tile with ``shape=(M, K)``
@@ -6294,59 +6465,29 @@ def tile_fft_generic_lto_dispatch_func(
6294
6465
  num_threads = options["block_dim"]
6295
6466
  arch = options["output_arch"]
6296
6467
  ept = size // num_threads
6297
- lto_symbol = f"fft_{size}_{ept}_{arch}_{direction}_{precision}"
6298
-
6299
- # early out if LTO for this combination already exists for this module
6300
- if lto_symbol in builder.ltoirs:
6301
- return lto_symbol, builder.ltoirs[lto_symbol]
6302
-
6303
- # otherwise compile LTO
6304
- lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6305
- shared_memory_size = ctypes.c_int(0)
6306
-
6307
- result = warp.context.runtime.core.cuda_compile_fft(
6308
- lto_code.name.encode("utf-8"),
6309
- lto_symbol.encode("utf-8"),
6310
- 0,
6311
- None,
6312
- None,
6313
- arch,
6314
- size,
6315
- ept,
6316
- dir,
6317
- precision,
6318
- ctypes.byref(shared_memory_size),
6319
- )
6320
- lto_code_path = Path(lto_code.name)
6321
- if not result:
6322
- lto_code.close()
6323
- if lto_code_path.exists():
6324
- lto_code_path.unlink()
6325
- raise RuntimeError("Failed to compile tile_fft")
6326
-
6327
- with open(lto_code.name, "rb") as f:
6328
- lto_code_data = f.read()
6329
-
6330
- lto_code.close()
6331
- lto_code_path.unlink()
6332
-
6333
- builder.ltoirs[lto_symbol] = lto_code_data
6334
-
6335
- shared_memory_bytes = Tile.round_up(shared_memory_size.value)
6336
-
6337
- return (
6338
- (
6339
- Var(lto_symbol, str, False, True, False),
6340
- Var(dtype, str, False, True, False),
6341
- Var(str(shared_memory_bytes), str, False, True, False),
6342
- Var(str(batch), str, False, True, False),
6343
- Var(str(ept), str, False, True, False),
6344
- inout,
6345
- ),
6346
- [],
6347
- [lto_code_data],
6348
- shared_memory_bytes,
6349
- )
6468
+
6469
+ if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
6470
+ # CPU/no-MathDx dispatch
6471
+ return ([], [], [], 0)
6472
+ else:
6473
+ # generate the LTO
6474
+ lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
6475
+ arch, size, ept, direction, dir, precision, builder
6476
+ )
6477
+
6478
+ return (
6479
+ (
6480
+ Var(lto_symbol, str, False, True, False),
6481
+ Var(dtype, str, False, True, False),
6482
+ Var(str(shared_memory_bytes), str, False, True, False),
6483
+ Var(str(batch), str, False, True, False),
6484
+ Var(str(ept), str, False, True, False),
6485
+ inout,
6486
+ ),
6487
+ [],
6488
+ [lto_code_data],
6489
+ shared_memory_bytes,
6490
+ )
6350
6491
 
6351
6492
 
6352
6493
  add_builtin(
@@ -6408,7 +6549,7 @@ def tile_cholesky_generic_value_func(arg_types, arg_values):
6408
6549
  raise TypeError(f"tile_cholesky() argument must be a tile, got {a!r}")
6409
6550
 
6410
6551
  if len(a.shape) != 2:
6411
- raise ValueError("tile_cholesky() argumust must be a 2D tile")
6552
+ raise ValueError("tile_cholesky() argument must be a 2D tile")
6412
6553
 
6413
6554
  if a.shape[0] != a.shape[1]:
6414
6555
  raise ValueError("tile_cholesky() argument must be square")
@@ -6449,57 +6590,36 @@ def tile_cholesky_generic_lto_dispatch_func(
6449
6590
  if out.type.shape[0] != M or out.type.shape[1] != M:
6450
6591
  raise ValueError("tile_cholesky() output tile must be square")
6451
6592
 
6452
- num_threads = options["block_dim"]
6453
- arch = options["output_arch"]
6454
- lto_symbol = f"potrf_{M}_{N}_{arch}_{precision_enum}"
6455
-
6456
- # early out if LTO for this combination already exists for this module
6457
- if lto_symbol in builder.ltoirs:
6458
- return lto_symbol, builder.ltoirs[lto_symbol]
6459
-
6460
- # otherwise compile LTO
6461
- lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6462
- universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6593
+ solver = "potrf"
6594
+ solver_enum = cusolver_function_map[solver]
6463
6595
 
6464
- # cuSOLVERDx only support col-major input/outputs,
6596
+ # cuSOLVERDx only supports col-major input/outputs,
6465
6597
  # so we use upper to mimic a row-major input
6466
- result = warp.context.runtime.core.cuda_compile_solver(
6467
- universal_fatbin_code.name.encode("utf-8"),
6468
- lto_code.name.encode("utf-8"),
6469
- lto_symbol.encode("utf-8"),
6470
- 0,
6471
- None,
6472
- None,
6473
- arch,
6474
- M,
6475
- N,
6476
- cusolver_function_map["potrf"],
6477
- precision_enum,
6478
- cusolver_fill_mode_map["upper"],
6479
- num_threads,
6480
- )
6598
+ fill_mode = cusolver_fill_mode_map["upper"]
6481
6599
 
6482
- if not result:
6483
- for f in [lto_code, universal_fatbin_code]:
6484
- f.close()
6485
- if Path(f.name).exists():
6486
- Path(f.name).unlink()
6487
- raise RuntimeError("Failed to compile tile_cholesky")
6600
+ arch = options["output_arch"]
6601
+ num_threads = options["block_dim"]
6602
+ parameter_list = f"({dtype}*, unsigned)"
6488
6603
 
6604
+ if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
6605
+ # CPU/no-MathDx dispatch
6606
+ return ((0, a, out), [], [], 0)
6489
6607
  else:
6490
- with open(lto_code.name, "rb") as f:
6491
- lto_code_data = f.read()
6492
- with open(universal_fatbin_code.name, "rb") as f:
6493
- universal_fatbin_code_data = f.read()
6494
- for f in [lto_code, universal_fatbin_code]:
6495
- f.close()
6496
- Path(f.name).unlink()
6497
-
6498
- builder.ltoirs[lto_symbol] = lto_code_data
6499
- builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, unsigned);"
6500
- builder.fatbins["cholesky"] = universal_fatbin_code_data
6608
+ # generate the LTO
6609
+ lto_symbol, lto_code_data = warp.build.build_lto_solver(
6610
+ M,
6611
+ N,
6612
+ solver,
6613
+ solver_enum,
6614
+ fill_mode,
6615
+ arch,
6616
+ precision_enum,
6617
+ num_threads,
6618
+ parameter_list,
6619
+ builder,
6620
+ )
6501
6621
 
6502
- return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
6622
+ return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
6503
6623
 
6504
6624
 
6505
6625
  add_builtin(
@@ -6593,57 +6713,36 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
6593
6713
  f"got {y.type.shape[0]} elements in output and {M} rows in 'L'"
6594
6714
  )
6595
6715
 
6596
- num_threads = options["block_dim"]
6597
- arch = options["output_arch"]
6598
- lto_symbol = f"potrs_{M}_{N}_{arch}_{precision_enum}"
6599
-
6600
- # early out if LTO for this combination already exists for this module
6601
- if lto_symbol in builder.ltoirs:
6602
- return lto_symbol, builder.ltoirs[lto_symbol]
6603
-
6604
- # otherwise compile LTO
6605
- lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6606
- universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6716
+ solver = "potrs"
6717
+ solver_enum = cusolver_function_map[solver]
6607
6718
 
6608
- # cuSOLVERDx only support col-major input/outputs,
6719
+ # cuSOLVERDx only supports col-major input/outputs,
6609
6720
  # so we use upper to mimic a row-major input
6610
- result = warp.context.runtime.core.cuda_compile_solver(
6611
- universal_fatbin_code.name.encode("utf-8"),
6612
- lto_code.name.encode("utf-8"),
6613
- lto_symbol.encode("utf-8"),
6614
- 0,
6615
- None,
6616
- None,
6617
- arch,
6618
- M,
6619
- N,
6620
- cusolver_function_map["potrs"],
6621
- precision_enum,
6622
- cusolver_fill_mode_map["upper"],
6623
- num_threads,
6624
- )
6721
+ fill_mode = cusolver_fill_mode_map["upper"]
6625
6722
 
6626
- if not result:
6627
- for f in [lto_code, universal_fatbin_code]:
6628
- f.close()
6629
- if Path(f.name).exists():
6630
- Path(f.name).unlink()
6631
- raise RuntimeError("Failed to compile tile_cholesky_solve")
6723
+ arch = options["output_arch"]
6724
+ num_threads = options["block_dim"]
6725
+ parameter_list = f"({dtype}*, {dtype}*)"
6632
6726
 
6727
+ if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
6728
+ # CPU/no-MathDx dispatch
6729
+ return ((0, L, x, y), [], [], 0)
6633
6730
  else:
6634
- with open(lto_code.name, "rb") as f:
6635
- lto_code_data = f.read()
6636
- with open(universal_fatbin_code.name, "rb") as f:
6637
- universal_fatbin_code_data = f.read()
6638
- for f in [lto_code, universal_fatbin_code]:
6639
- f.close()
6640
- Path(f.name).unlink()
6641
-
6642
- builder.ltoirs[lto_symbol] = lto_code_data
6643
- builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, {dtype}*);"
6644
- builder.fatbins["cholesky"] = universal_fatbin_code_data
6645
-
6646
- return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)
6731
+ # generate the LTO
6732
+ lto_symbol, lto_code_data = warp.build.build_lto_solver(
6733
+ M,
6734
+ N,
6735
+ solver,
6736
+ solver_enum,
6737
+ fill_mode,
6738
+ arch,
6739
+ precision_enum,
6740
+ num_threads,
6741
+ parameter_list,
6742
+ builder,
6743
+ )
6744
+
6745
+ return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)
6647
6746
 
6648
6747
 
6649
6748
  add_builtin(