warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/native/sparse.cu ADDED
@@ -0,0 +1,524 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include "cuda_util.h"
19
+ #include "warp.h"
20
+
21
+ #define THRUST_IGNORE_CUB_VERSION_CHECK
22
+
23
+ #include <cub/device/device_radix_sort.cuh>
24
+ #include <cub/device/device_run_length_encode.cuh>
25
+ #include <cub/device/device_scan.cuh>
26
+
27
+ namespace
28
+ {
29
+
30
// A block's (row, column) coordinates packed into a single 64-bit integer
// (row in the high 32 bits, column in the low 32 bits, see bsr_combine_row_col),
// so that a single CUB radix sort orders blocks by row, then by column.
typedef uint64_t BsrRowCol;

// Sentinel key for discarded blocks. As the largest representable key, it sorts
// after every valid (row, column) key in an ascending sort.
static constexpr BsrRowCol PRUNED_ROWCOL = ~static_cast<BsrRowCol>(0);
34
+
35
// Packs a (row, col) pair into a BsrRowCol key: row goes in the upper 32 bits,
// col in the lower 32 bits.
CUDA_CALLABLE BsrRowCol bsr_combine_row_col(uint32_t row, uint32_t col)
{
    const BsrRowCol high_bits = static_cast<BsrRowCol>(row) << 32;
    const BsrRowCol low_bits = static_cast<BsrRowCol>(col);
    return high_bits | low_bits;
}
39
+
40
// Returns the row index stored in the upper 32 bits of a packed key.
CUDA_CALLABLE uint32_t bsr_get_row(const BsrRowCol& row_col)
{
    return static_cast<uint32_t>(row_col >> 32);
}
41
+
42
// Returns the column index from the low bits of a packed key. Only the low 31
// bits are kept (mask INT_MAX), so the result always fits in a non-negative int;
// for PRUNED_ROWCOL it evaluates to INT_MAX.
CUDA_CALLABLE uint32_t bsr_get_col(const BsrRowCol& row_col)
{
    const BsrRowCol col_mask = static_cast<BsrRowCol>(INT_MAX);
    return static_cast<uint32_t>(row_col & col_mask);
}
43
+
44
// Predicate telling whether triplet block `i` holds at least one non-zero
// coefficient. A null `values` pointer disables the check, so every block is
// reported as non-zero.
template <typename T> struct BsrBlockIsNotZero
{
    int block_size;  // number of scalar coefficients per block
    const T* values; // triplet block values, or nullptr to disable pruning

    // Returns true if any coefficient of block `i` differs from T(0).
    CUDA_CALLABLE_DEVICE bool operator()(int i) const
    {
        if (!values)
            return true;

        // size_t offset: the int product i * block_size can overflow for large inputs.
        const T* block_values = values + static_cast<size_t>(i) * static_cast<size_t>(block_size);

        // Distinct loop index (the original shadowed the parameter `i`).
        for (int k = 0; k < block_size; ++k)
        {
            if (block_values[k] != T(0))
                return true;
        }
        return false;
    }
};
63
+
64
// Predicate restricting blocks to an existing BSR sparsity pattern: (row, col)
// passes when `col` is one of the column indices stored for `row`. A null
// `bsr_offsets` disables the mask and accepts every block.
// The lookup is a binary search, so it presumes the column indices of each row
// are sorted in ascending order.
struct BsrBlockInMask
{
    const int* bsr_offsets; // row offsets of the mask pattern, or nullptr
    const int* bsr_columns; // column indices of the mask pattern

    CUDA_CALLABLE_DEVICE bool operator()(int row, int col) const
    {
        if (bsr_offsets == nullptr)
            return true;

        const int row_begin = bsr_offsets[row];
        const int row_end = bsr_offsets[row + 1];

        // Empty row: nothing can match.
        if (row_begin >= row_end)
            return false;

        // Lower-bound search for `col` in bsr_columns[row_begin, row_end).
        int lo = row_begin;
        int hi = row_end - 1;
        while (lo < hi)
        {
            const int mid = lo + (hi - lo) / 2;
            if (bsr_columns[mid] >= col)
                hi = mid;
            else
                lo = mid + 1;
        }

        return bsr_columns[lo] == col;
    }
};
94
+
95
// Computes one sort key per triplet. tpl_row_col[block] receives the packed
// (row, col) key, or PRUNED_ROWCOL when the triplet is discarded. A triplet is
// discarded when:
//  - its row lies outside [0, nrow);
//  - its column is negative;
//  - its block is all zeros (only when zero pruning is enabled in `nonZero`);
//  - it is absent from `mask`.
// block_indices[block] receives `block` itself, so the sort can carry the
// original triplet positions along.
template <typename T>
__global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const int* tpl_rows, const int* tpl_columns,
                                            const BsrBlockIsNotZero<T> nonZero, const BsrBlockInMask mask,
                                            uint32_t* block_indices, BsrRowCol* tpl_row_col)
{
    const int block = blockIdx.x * blockDim.x + threadIdx.x;
    if (block >= nnz)
        return;

    const int row = tpl_rows[block];
    const int col = tpl_columns[block];

    // A negative column would wrap to 0xFFFFFFFF in the packed key. bsr_get_col
    // would then return INT_MAX, an invalid output column, so such triplets are
    // pruned just like out-of-range rows. The row check also short-circuits
    // before `mask` reads bsr_offsets[row].
    const bool is_valid = row >= 0 && row < nrow && col >= 0;

    const BsrRowCol row_col =
        is_valid && nonZero(block) && mask(row, col) ? bsr_combine_row_col(row, col) : PRUNED_ROWCOL;
    tpl_row_col[block] = row_col;
    block_indices[block] = block;
}
113
+
114
// Derives BSR row offsets from the sorted, de-duplicated block keys.
// It runs one thread per entry of row_offsets (indices 0..row_count inclusive).
// row_offsets[row] is set to the index of the first key whose row component is
// >= `row`: a lower-bound search. `*d_nnz` is the number of unique keys.
template <typename T>
__global__ void bsr_find_row_offsets(uint32_t row_count, const T* d_nnz, const BsrRowCol* unique_row_col,
                                     int* row_offsets)
{
    const uint32_t row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row > row_count)
        return;

    const uint32_t key_count = *d_nnz;

    uint32_t first_key;
    if (row == 0 || key_count == 0)
    {
        first_key = 0;
    }
    else if (bsr_get_row(unique_row_col[key_count - 1]) < row)
    {
        // All keys belong to earlier rows.
        first_key = key_count;
    }
    else
    {
        // Lower-bound binary search on the row component of the keys.
        uint32_t lo = 0;
        uint32_t hi = key_count - 1;
        while (lo < hi)
        {
            const uint32_t mid = lo + (hi - lo) / 2;
            if (bsr_get_row(unique_row_col[mid]) >= row)
                hi = mid;
            else
                lo = mid + 1;
        }
        first_key = lo;
    }

    row_offsets[row] = first_key;
}
155
+
156
// Writes the final column index of each unique block and sums duplicate values.
// block_offsets[i] is the inclusive prefix sum of duplicate-run lengths, so the
// triplets merged into block i are sorted_block_indices[beg, end), where
// beg = block_offsets[i - 1] (0 for i == 0) and end = block_offsets[i].
// Value accumulation is skipped when bsr_values is null.
template <typename T>
__global__ void bsr_merge_blocks(const int* d_nnz, int block_size, const uint32_t* block_offsets,
                                 const uint32_t* sorted_block_indices, const BsrRowCol* unique_row_cols,
                                 const T* tpl_values, int* bsr_cols, T* bsr_values)

{
    const uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;

    if (i >= static_cast<uint32_t>(*d_nnz))
        return;

    const BsrRowCol row_col = unique_row_cols[i];
    bsr_cols[i] = bsr_get_col(row_col);

    if (row_col == PRUNED_ROWCOL || bsr_values == nullptr)
        return;

    const uint32_t beg = i ? block_offsets[i - 1] : 0;
    const uint32_t end = block_offsets[i];

    // size_t offsets: the 32-bit products (index * block_size) wrap around for
    // matrices with more than 2^32 scalar coefficients.
    const size_t stride = static_cast<size_t>(block_size);
    T* bsr_val = bsr_values + static_cast<size_t>(i) * stride;

    // Initialize from the first duplicate...
    const T* first_val = tpl_values + static_cast<size_t>(sorted_block_indices[beg]) * stride;
    for (int k = 0; k < block_size; ++k)
    {
        bsr_val[k] = first_val[k];
    }

    // ...then accumulate the remaining ones.
    for (uint32_t cur = beg + 1; cur < end; ++cur)
    {
        const T* tpl_val = tpl_values + static_cast<size_t>(sorted_block_indices[cur]) * stride;
        for (int k = 0; k < block_size; ++k)
        {
            bsr_val[k] += tpl_val[k];
        }
    }
}
194
+
195
// Builds a BSR matrix (row offsets, column indices, block values) from `nnz`
// COO triplets, on the current CUDA context and stream.
//  - Triplets whose row lies outside [0, row_count) are dropped.
//  - All-zero blocks are dropped when prune_numerical_zeros is set.
//  - When `masked` is set, blocks absent from the pattern passed in through
//    bsr_offsets/bsr_columns are dropped. Both arrays are overwritten afterwards.
//  - Duplicate (row, col) triplets are merged by summing their values.
//  - bsr_offsets (row_count + 1 entries) is overwritten; bsr_offsets[row_count]
//    ends up holding the resulting number of blocks.
//  - bsr_values may be null, in which case only the sparsity pattern is built.
//  - If bsr_nnz is non-null, the block count is copied into it on `stream`.
//    bsr_nnz_event, if non-null, is recorded on `stream` after that copy.
// NOTE(review): bsr_columns / bsr_values presumably need room for up to `nnz`
// blocks -- confirm against the caller.
+ template <typename T>
196
+ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_per_block, const int row_count,
197
+ const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
198
+ const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
199
+ int* bsr_columns, T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
200
+ {
201
+ const int block_size = rows_per_block * cols_per_block;
202
+
203
+ void* context = cuda_context_get_current();
204
+ ContextGuard guard(context);
205
+
// NOTE(review): the per-context cache below is commented out; the temporaries
// used here are ScopedTemporary objects constructed on every call instead.
206
+ // Per-context cached temporary buffers
207
+ // BsrFromTripletsTemp& bsr_temp = g_bsr_from_triplets_temp_map[context];
208
+
209
+ cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
210
+
// block_indices holds two nnz-sized halves for the double buffer, plus one
// trailing slot used as the device-side unique-run counter (unique_triplet_count).
211
+ ScopedTemporary<uint32_t> block_indices(context, 2 * nnz + 1);
212
+ ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * nnz);
213
+
// Naming note: in the radix sort below, d_values holds the sort KEYS (packed
// row/col) and d_keys holds the sort VALUES (original triplet indices).
214
+ cub::DoubleBuffer<uint32_t> d_keys(block_indices.buffer(), block_indices.buffer() + nnz);
215
+ cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(), combined_row_col.buffer() + nnz);
216
+
217
+ uint32_t* unique_triplet_count = block_indices.buffer() + 2 * nnz;
218
+
219
+ // Combine rows and columns so we can sort on them both
220
+ BsrBlockIsNotZero<T> isNotZero{block_size, prune_numerical_zeros ? tpl_values : nullptr};
221
+ BsrBlockInMask mask{masked ? bsr_offsets : nullptr, bsr_columns};
222
+ wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_triplet_key_values, nnz,
223
+ (nnz, row_count, tpl_rows, tpl_columns, isNotZero, mask, d_keys.Current(), d_values.Current()));
224
+
// Ascending radix sort over all 64 key bits (begin_bit 0, end_bit 64). Blocks end
// up ordered by row, then column, with pruned triplets last. The first call only
// queries the temporary storage size.
225
+ // Sort
226
+ {
227
+ size_t buff_size = 0;
228
+ check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values, d_keys, nnz, 0, 64, stream));
229
+ ScopedTemporary<> temp(context, buff_size);
230
+ check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
231
+ }
232
+
// Outputs: unique packed keys go to d_values.Alternate(), run lengths (the
// duplicate count per key) to d_keys.Alternate(), and the run count to
// *unique_triplet_count. All pruned triplets share PRUNED_ROWCOL and so form
// at most one final run.
233
+ // Runlength encode row-col sequences
234
+ {
235
+ size_t buff_size = 0;
236
+ check_cuda(cub::DeviceRunLengthEncode::Encode(nullptr, buff_size, d_values.Current(), d_values.Alternate(),
237
+ d_keys.Alternate(), unique_triplet_count, nnz, stream));
238
+ ScopedTemporary<> temp(context, buff_size);
239
+ check_cuda(cub::DeviceRunLengthEncode::Encode(temp.buffer(), buff_size, d_values.Current(),
240
+ d_values.Alternate(), d_keys.Alternate(), unique_triplet_count,
241
+ nnz, stream));
242
+ }
243
+
244
+ // Compute row offsets from sorted unique blocks
245
+ wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, row_count + 1,
246
+ (row_count, unique_triplet_count, d_values.Alternate(), bsr_offsets));
247
+
// From here on, bsr_offsets[row_count] holds the number of valid (non-pruned) blocks.
248
+ if (bsr_nnz)
249
+ {
250
+ // Copy nnz to host, and record an event for the completed transfer if desired
251
+
252
+ memcpy_d2h(WP_CURRENT_CONTEXT, bsr_nnz, bsr_offsets + row_count, sizeof(int), stream);
253
+
254
+ if (bsr_nnz_event)
255
+ {
256
+ cuda_event_record(bsr_nnz_event, stream);
257
+ }
258
+ }
259
+
// The inclusive sum turns run lengths into end offsets into the sorted triplet
// order; bsr_merge_blocks consumes them. It scans all nnz slots, but only the
// first *unique_triplet_count are meaningful, and later entries are never read.
260
+ // Scan repeated block counts
261
+ {
262
+ size_t buff_size = 0;
263
+ check_cuda(
264
+ cub::DeviceScan::InclusiveSum(nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(), nnz, stream));
265
+ ScopedTemporary<> temp(context, buff_size);
266
+ check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size, d_keys.Alternate(), d_keys.Alternate(), nnz,
267
+ stream));
268
+ }
269
+
// Launched over the nnz upper bound; the kernel exits for indices at or beyond
// bsr_offsets[row_count], the actual block count.
270
+ // Accumulate repeated blocks and set column indices
271
+ wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, nnz,
272
+ (bsr_offsets + row_count, block_size, d_keys.Alternate(), d_keys.Current(), d_values.Alternate(),
273
+ tpl_values, bsr_columns, bsr_values));
274
+ }
275
+
276
// Builds the sort keys used to transpose a BSR matrix.
// For each stored block i (i < bsr_offsets[row_count]) it:
//  - binary-searches the row offsets for the block's row;
//  - writes the swapped key (column, row) to transposed_row_col[i];
//  - writes i to block_indices[i].
// Slots in [actual nnz, nnz_upper_bound) receive PRUNED_ROWCOL, so they sort last.
// Their block_indices entries are left untouched.
__global__ void bsr_transpose_fill_row_col(const int nnz_upper_bound, const int row_count, const int* bsr_offsets,
                                           const int* bsr_columns, int* block_indices, BsrRowCol* transposed_row_col)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;

    // Threads past the allocated range have nothing to write.
    if (i >= nnz_upper_bound)
        return;

    // Between the real block count and the upper bound: mark the slot invalid.
    const int block_count = bsr_offsets[row_count];
    if (i >= block_count)
    {
        transposed_row_col[i] = PRUNED_ROWCOL;
        return;
    }

    block_indices[i] = i;

    // Find the row r such that bsr_offsets[r] <= i < bsr_offsets[r + 1].
    int lo = 0;
    int hi = row_count - 1;
    while (lo < hi)
    {
        const int mid = lo + (hi - lo) / 2;
        if (i < bsr_offsets[mid + 1])
            hi = mid;
        else
            lo = mid + 1;
    }

    // Swapping row and column makes a sort on these keys order blocks by
    // transposed row, then transposed column.
    transposed_row_col[i] = bsr_combine_row_col(bsr_columns[i], lo);
}
319
+
320
// Transposes a dense Rows x Cols block, stored row-major in `src`, into the
// Cols x Rows row-major block `dest`. Block dimensions are compile-time constants.
template <int Rows, int Cols, typename T> struct BsrBlockTransposer
{
    void CUDA_CALLABLE_DEVICE operator()(const T* src, T* dest) const
    {
        // Walk the source linearly; recover (r, c) from the flat index.
        for (int flat = 0; flat < Rows * Cols; ++flat)
        {
            const int r = flat / Cols;
            const int c = flat % Cols;
            dest[c * Rows + r] = src[flat];
        }
    }
};
333
+
334
// Runtime-sized variant of BsrBlockTransposer, selected with Rows = Cols = -1.
// It transposes a row_count x col_count row-major block in `src` into the
// col_count x row_count row-major block `dest`.
template <typename T> struct BsrBlockTransposer<-1, -1, T>
{
    int row_count; // rows of the source block
    int col_count; // columns of the source block

    void CUDA_CALLABLE_DEVICE operator()(const T* src, T* dest) const
    {
        for (int r = 0; r < row_count; ++r)
        {
            const T* src_row = src + r * col_count;
            T* dest_col = dest + r;
            // Consecutive source columns land row_count elements apart in `dest`.
            for (int c = 0; c < col_count; ++c)
            {
                dest_col[c * row_count] = src_row[c];
            }
        }
    }
};
351
+
352
// Writes the blocks of the transposed matrix. Output block i is the transpose
// (via `transposer`) of source block block_indices[i]. Its column index is the
// column part of transposed_indices[i]. Only the first *nnz blocks are written.
template <int Rows, int Cols, typename T>
__global__ void bsr_transpose_blocks(const int* nnz, const int block_size, BsrBlockTransposer<Rows, Cols, T> transposer,
                                     const int* block_indices, const BsrRowCol* transposed_indices, const T* bsr_values,
                                     int* transposed_bsr_columns, T* transposed_bsr_values)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= *nnz)
        return;

    const int src_idx = block_indices[i];

    // size_t offsets: the int products (index * block_size) overflow, which is
    // undefined behaviour, once a matrix holds more than INT_MAX scalars.
    const size_t stride = static_cast<size_t>(block_size);
    const T* src_block = bsr_values + static_cast<size_t>(src_idx) * stride;
    T* dest_block = transposed_bsr_values + static_cast<size_t>(i) * stride;

    transposer(src_block, dest_block);

    transposed_bsr_columns[i] = bsr_get_col(transposed_indices[i]);
}
367
+
368
// Launches bsr_transpose_blocks with a transposer specialized for the block shape.
//
// Block shapes with 1..3 rows and 1..3 columns use a compile-time sized transposer.
// Any other shape falls back to the run-time sized BsrBlockTransposer<-1, -1, T>.
//
// nnz:   host-side count used to size the launch.
// d_nnz: device pointer to the block count actually processed by the kernel.
template <typename T>
void launch_bsr_transpose_blocks(int nnz, const int* d_nnz, const int block_size, const int rows_per_block,
                                 const int cols_per_block, const int* block_indices,
                                 const BsrRowCol* transposed_indices, const T* bsr_values, int* transposed_bsr_columns,
                                 T* transposed_bsr_values)
{
    // Each outer case ends with an explicit `break`. An unmatched column count then goes
    // straight to the generic fallback below, instead of implicitly falling through
    // into the next outer case.
    switch (rows_per_block)
    {
    case 1:
        switch (cols_per_block)
        {
        case 1:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (d_nnz, block_size, BsrBlockTransposer<1, 1, T>{}, block_indices, transposed_indices,
                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
            return;
        case 2:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (d_nnz, block_size, BsrBlockTransposer<1, 2, T>{}, block_indices, transposed_indices,
                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
            return;
        case 3:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (d_nnz, block_size, BsrBlockTransposer<1, 3, T>{}, block_indices, transposed_indices,
                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
            return;
        default:
            break;
        }
        break;
    case 2:
        switch (cols_per_block)
        {
        case 1:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (d_nnz, block_size, BsrBlockTransposer<2, 1, T>{}, block_indices, transposed_indices,
                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
            return;
        case 2:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (d_nnz, block_size, BsrBlockTransposer<2, 2, T>{}, block_indices, transposed_indices,
                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
            return;
        case 3:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (d_nnz, block_size, BsrBlockTransposer<2, 3, T>{}, block_indices, transposed_indices,
                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
            return;
        default:
            break;
        }
        break;
    case 3:
        switch (cols_per_block)
        {
        case 1:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (d_nnz, block_size, BsrBlockTransposer<3, 1, T>{}, block_indices, transposed_indices,
                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
            return;
        case 2:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (d_nnz, block_size, BsrBlockTransposer<3, 2, T>{}, block_indices, transposed_indices,
                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
            return;
        case 3:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (d_nnz, block_size, BsrBlockTransposer<3, 3, T>{}, block_indices, transposed_indices,
                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
            return;
        default:
            break;
        }
        break;
    default:
        break;
    }

    // Generic fallback: block dimensions passed at run time
    wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                     (d_nnz, block_size, BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block}, block_indices,
                      transposed_indices, bsr_values, transposed_bsr_columns, transposed_bsr_values));
}
440
+
441
// Writes the destination column index of every transposed block, without moving any
// block values. bsr_transpose_device launches it when no destination value buffer is
// provided, so the transposed column indices are still produced in that case.
__global__ void bsr_transpose_columns_only(const int* nnz, const BsrRowCol* transposed_indices,
                                           int* transposed_bsr_columns)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= *nnz)
        return;

    transposed_bsr_columns[i] = bsr_get_col(transposed_indices[i]);
}

// Computes the transpose of a BSR matrix on the current CUDA device and stream.
//
// rows_per_block, cols_per_block: dimensions of each dense block of the source matrix.
// row_count, col_count:           number of block rows / block columns of the source matrix.
// nnz: host-side block count, used to size temporaries and kernel launches. Kernels read
//      the block count from bsr_offsets[row_count] (device memory).
//      NOTE(review): presumably nnz >= bsr_offsets[row_count] — confirm with callers.
// bsr_offsets, bsr_columns, bsr_values: source matrix arrays.
// transposed_bsr_offsets: receives the col_count + 1 row offsets of the transpose.
// transposed_bsr_columns: receives the column index of each transposed block.
// transposed_bsr_values:  receives the transposed blocks. May be nullptr, in which case
//                         only offsets and column indices are written.
template <typename T>
void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
                          const int* bsr_offsets, const int* bsr_columns, const T* bsr_values,
                          int* transposed_bsr_offsets, int* transposed_bsr_columns, T* transposed_bsr_values)
{
    const int block_size = rows_per_block * cols_per_block;

    void* context = cuda_context_get_current();
    ContextGuard guard(context);

    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());

    // Double-buffered scratch storage for the radix sort:
    //  - block indices: the sort payload (which source block each entry refers to);
    //  - combined row/col: the sort keys (position of the block in the transpose).
    ScopedTemporary<int> block_indices(context, 2 * nnz);
    ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * nnz);

    cub::DoubleBuffer<int> d_block_indices(block_indices.buffer(), block_indices.buffer() + nnz);
    cub::DoubleBuffer<BsrRowCol> d_row_col(combined_row_col.buffer(), combined_row_col.buffer() + nnz);

    // Key each source block by bsr_combine_row_col(col, row), i.e. by its (row, col)
    // position in the transposed matrix, alongside the block-index payload
    wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
                     (nnz, row_count, bsr_offsets, bsr_columns, d_block_indices.Current(), d_row_col.Current()));

    // Sort blocks by their transposed position, carrying the source block indices along
    {
        size_t buff_size = 0;
        check_cuda(
            cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_row_col, d_block_indices, nnz, 0, 64, stream));
        ScopedTemporary<> temp(context, buff_size);
        check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_row_col, d_block_indices, nnz, 0, 64,
                                                   stream));
    }

    // Compute the transposed row offsets (one per source column, plus the end offset)
    // from the sorted keys
    wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, col_count + 1,
                     (col_count, bsr_offsets + row_count, d_row_col.Current(), transposed_bsr_offsets));

    if (transposed_bsr_values != nullptr)
    {
        // Move and transpose individual blocks; this also writes the transposed column indices
        launch_bsr_transpose_blocks(nnz, bsr_offsets + row_count, block_size, rows_per_block, cols_per_block,
                                    d_block_indices.Current(), d_row_col.Current(), bsr_values,
                                    transposed_bsr_columns, transposed_bsr_values);
    }
    else
    {
        // No value buffer: the block kernel cannot run, but column indices must still be written
        wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_columns_only, nnz,
                         (bsr_offsets + row_count, d_row_col.Current(), transposed_bsr_columns));
    }
}
483
+
484
+ } // namespace
485
+
486
// Single-precision entry point: reinterprets the untyped triplet and BSR value buffers
// as float and forwards every argument to the typed implementation.
void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                           int* tpl_rows, int* tpl_columns, void* tpl_values,
                                           bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
                                           void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
{
    const float* typed_tpl_values = static_cast<const float*>(tpl_values);
    float* typed_bsr_values = static_cast<float*>(bsr_values);

    bsr_matrix_from_triplets_device<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
                                           typed_tpl_values, prune_numerical_zeros, masked, bsr_offsets,
                                           bsr_columns, typed_bsr_values, bsr_nnz, bsr_nnz_event);
}
496
+
497
// Double-precision entry point: reinterprets the untyped triplet and BSR value buffers
// as double and forwards every argument to the typed implementation.
void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                            int* tpl_rows, int* tpl_columns, void* tpl_values,
                                            bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
                                            void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
{
    const double* typed_tpl_values = static_cast<const double*>(tpl_values);
    double* typed_bsr_values = static_cast<double*>(bsr_values);

    bsr_matrix_from_triplets_device<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
                                            typed_tpl_values, prune_numerical_zeros, masked, bsr_offsets,
                                            bsr_columns, typed_bsr_values, bsr_nnz, bsr_nnz_event);
}
507
+
508
// Single-precision entry point: reinterprets the untyped source and destination value
// buffers as float and forwards to bsr_transpose_device.
void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
                                int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
                                int* transposed_bsr_columns, void* transposed_bsr_values)
{
    const float* src_values = static_cast<const float*>(bsr_values);
    float* dst_values = static_cast<float*>(transposed_bsr_values);

    bsr_transpose_device<float>(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
                                src_values, transposed_bsr_offsets, transposed_bsr_columns, dst_values);
}
516
+
517
// Double-precision entry point: reinterprets the untyped source and destination value
// buffers as double and forwards to bsr_transpose_device.
void bsr_transpose_double_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
                                 int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
                                 int* transposed_bsr_columns, void* transposed_bsr_values)
{
    const double* src_values = static_cast<const double*>(bsr_values);
    double* dst_values = static_cast<double*>(transposed_bsr_values);

    bsr_transpose_device<double>(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
                                 src_values, transposed_bsr_offsets, transposed_bsr_columns, dst_values);
}