warp_lang-1.7.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/native/sort.cu ADDED
@@ -0,0 +1,277 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include "warp.h"
19
+ #include "cuda_util.h"
20
+ #include "sort.h"
21
+
22
+ #define THRUST_IGNORE_CUB_VERSION_CHECK
23
+
24
+ #include <cub/cub.cuh>
25
+
26
+ #include <map>
27
+
28
+ // temporary buffer for radix sort
29
+ struct RadixSortTemp
30
+ {
31
+ void* mem = NULL;
32
+ size_t size = 0;
33
+ };
34
+
35
+ // map temp buffers to CUDA contexts
36
+ static std::map<void*, RadixSortTemp> g_radix_sort_temp_map;
37
+
38
+
39
+ template <typename KeyType>
40
+ void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* size_out)
41
+ {
42
+ ContextGuard guard(context);
43
+
44
+ cub::DoubleBuffer<KeyType> d_keys;
45
+ cub::DoubleBuffer<int> d_values;
46
+
47
+ // compute temporary memory required
48
+ size_t sort_temp_size;
49
+ check_cuda(cub::DeviceRadixSort::SortPairs(
50
+ NULL,
51
+ sort_temp_size,
52
+ d_keys,
53
+ d_values,
54
+ n, 0, sizeof(KeyType)*8,
55
+ (cudaStream_t)cuda_stream_get_current()));
56
+
57
+ if (!context)
58
+ context = cuda_context_get_current();
59
+
60
+ RadixSortTemp& temp = g_radix_sort_temp_map[context];
61
+
62
+ if (sort_temp_size > temp.size)
63
+ {
64
+ free_device(WP_CURRENT_CONTEXT, temp.mem);
65
+ temp.mem = alloc_device(WP_CURRENT_CONTEXT, sort_temp_size);
66
+ temp.size = sort_temp_size;
67
+ }
68
+
69
+ if (mem_out)
70
+ *mem_out = temp.mem;
71
+ if (size_out)
72
+ *size_out = temp.size;
73
+ }
74
+
75
+ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
76
+ {
77
+ radix_sort_reserve_internal<int>(context, n, mem_out, size_out);
78
+ }
79
+
80
+ template <typename KeyType>
81
+ void radix_sort_pairs_device(void* context, KeyType* keys, int* values, int n)
82
+ {
83
+ ContextGuard guard(context);
84
+
85
+ cub::DoubleBuffer<KeyType> d_keys(keys, keys + n);
86
+ cub::DoubleBuffer<int> d_values(values, values + n);
87
+
88
+ RadixSortTemp temp;
89
+ radix_sort_reserve_internal<KeyType>(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);
90
+
91
+ // sort
92
+ check_cuda(cub::DeviceRadixSort::SortPairs(
93
+ temp.mem,
94
+ temp.size,
95
+ d_keys,
96
+ d_values,
97
+ n, 0, sizeof(KeyType)*8,
98
+ (cudaStream_t)cuda_stream_get_current()));
99
+
100
+ if (d_keys.Current() != keys)
101
+ memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(KeyType)*n);
102
+
103
+ if (d_values.Current() != values)
104
+ memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
105
+ }
106
+
107
+ void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
108
+ {
109
+ radix_sort_pairs_device<int>(context, keys, values, n);
110
+ }
111
+
112
+ void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
113
+ {
114
+ radix_sort_pairs_device<float>(context, keys, values, n);
115
+ }
116
+
117
+ void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n)
118
+ {
119
+ radix_sort_pairs_device<int64_t>(context, keys, values, n);
120
+ }
121
+
122
+ void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
123
+ {
124
+ radix_sort_pairs_device(
125
+ WP_CURRENT_CONTEXT,
126
+ reinterpret_cast<int *>(keys),
127
+ reinterpret_cast<int *>(values), n);
128
+ }
129
+
130
+ void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n)
131
+ {
132
+ radix_sort_pairs_device(
133
+ WP_CURRENT_CONTEXT,
134
+ reinterpret_cast<float *>(keys),
135
+ reinterpret_cast<int *>(values), n);
136
+ }
137
+
138
+ void radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n)
139
+ {
140
+ radix_sort_pairs_device(
141
+ WP_CURRENT_CONTEXT,
142
+ reinterpret_cast<int64_t *>(keys),
143
+ reinterpret_cast<int *>(values), n);
144
+ }
145
+
146
+ void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_out, size_t* size_out)
147
+ {
148
+ ContextGuard guard(context);
149
+
150
+ cub::DoubleBuffer<int> d_keys;
151
+ cub::DoubleBuffer<int> d_values;
152
+
153
+ int* start_indices = NULL;
154
+ int* end_indices = NULL;
155
+
156
+ // compute temporary memory required
157
+ size_t sort_temp_size;
158
+ check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
159
+ NULL,
160
+ sort_temp_size,
161
+ d_keys,
162
+ d_values,
163
+ n,
164
+ num_segments,
165
+ start_indices,
166
+ end_indices,
167
+ 0,
168
+ 32,
169
+ (cudaStream_t)cuda_stream_get_current()));
170
+
171
+ if (!context)
172
+ context = cuda_context_get_current();
173
+
174
+ RadixSortTemp& temp = g_radix_sort_temp_map[context];
175
+
176
+ if (sort_temp_size > temp.size)
177
+ {
178
+ free_device(WP_CURRENT_CONTEXT, temp.mem);
179
+ temp.mem = alloc_device(WP_CURRENT_CONTEXT, sort_temp_size);
180
+ temp.size = sort_temp_size;
181
+ }
182
+
183
+ if (mem_out)
184
+ *mem_out = temp.mem;
185
+ if (size_out)
186
+ *size_out = temp.size;
187
+ }
188
+
189
+ // segment_start_indices and segment_end_indices are arrays of length num_segments, where segment_start_indices[i] is the index of the first element
190
+ // in the i-th segment and segment_end_indices[i] is the index after the last element in the i-th segment
191
+ // https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedRadixSort.html
192
+ void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
193
+ {
194
+ ContextGuard guard(context);
195
+
196
+ cub::DoubleBuffer<float> d_keys(keys, keys + n);
197
+ cub::DoubleBuffer<int> d_values(values, values + n);
198
+
199
+ RadixSortTemp temp;
200
+ segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
201
+
202
+ // sort
203
+ check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
204
+ temp.mem,
205
+ temp.size,
206
+ d_keys,
207
+ d_values,
208
+ n,
209
+ num_segments,
210
+ segment_start_indices,
211
+ segment_end_indices,
212
+ 0,
213
+ 32,
214
+ (cudaStream_t)cuda_stream_get_current()));
215
+
216
+ if (d_keys.Current() != keys)
217
+ memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);
218
+
219
+ if (d_values.Current() != values)
220
+ memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
221
+ }
222
+
223
+ void segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
224
+ {
225
+ segmented_sort_pairs_device(
226
+ WP_CURRENT_CONTEXT,
227
+ reinterpret_cast<float *>(keys),
228
+ reinterpret_cast<int *>(values), n,
229
+ reinterpret_cast<int *>(segment_start_indices),
230
+ reinterpret_cast<int *>(segment_end_indices),
231
+ num_segments);
232
+ }
233
+
234
+ // segment_indices is an array of length num_segments + 1, where segment_indices[i] is the index of the first element in the i-th segment
235
+ // The end of a segment is given by segment_indices[i+1]
236
+ // https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedSort.html#a-simple-example
237
+ void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
238
+ {
239
+ ContextGuard guard(context);
240
+
241
+ cub::DoubleBuffer<int> d_keys(keys, keys + n);
242
+ cub::DoubleBuffer<int> d_values(values, values + n);
243
+
244
+ RadixSortTemp temp;
245
+ segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
246
+
247
+ // sort
248
+ check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
249
+ temp.mem,
250
+ temp.size,
251
+ d_keys,
252
+ d_values,
253
+ n,
254
+ num_segments,
255
+ segment_start_indices,
256
+ segment_end_indices,
257
+ 0,
258
+ 32,
259
+ (cudaStream_t)cuda_stream_get_current()));
260
+
261
+ if (d_keys.Current() != keys)
262
+ memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);
263
+
264
+ if (d_values.Current() != values)
265
+ memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
266
+ }
267
+
268
+ void segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
269
+ {
270
+ segmented_sort_pairs_device(
271
+ WP_CURRENT_CONTEXT,
272
+ reinterpret_cast<int *>(keys),
273
+ reinterpret_cast<int *>(values), n,
274
+ reinterpret_cast<int *>(segment_start_indices),
275
+ reinterpret_cast<int *>(segment_end_indices),
276
+ num_segments);
277
+ }
warp/native/sort.h ADDED
@@ -0,0 +1,33 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #pragma once
19
+
20
+ #include <stddef.h>
21
+
22
+ void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL);
23
+ void radix_sort_pairs_host(int* keys, int* values, int n);
24
+ void radix_sort_pairs_host(float* keys, int* values, int n);
25
+ void radix_sort_pairs_host(int64_t* keys, int* values, int n);
26
+ void radix_sort_pairs_device(void* context, int* keys, int* values, int n);
27
+ void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
28
+ void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n);
29
+
30
+ void segmented_sort_pairs_host(float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
31
+ void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
32
+ void segmented_sort_pairs_host(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
33
+ void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
warp/native/sparse.cpp ADDED
@@ -0,0 +1,378 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include "warp.h"
19
+
20
+ #include <algorithm>
21
+ #include <numeric>
22
+ #include <vector>
23
+
24
+ namespace
25
+ {
26
+
27
+ // Specialized is_zero and accumulation function for common block sizes
28
+ // Rely on compiler to unroll loops when block size is known
29
+
30
+ template <int N, typename T> bool bsr_fixed_block_is_zero(const T* val, int value_size)
31
+ {
32
+ return std::all_of(val, val + N, [](float v) { return v == T(0); });
33
+ }
34
+
35
// Returns true iff all `value_size` entries of the dense block at `val` are
// exactly zero. Runtime-sized fallback used for uncommon block sizes.
template <typename T> bool bsr_dyn_block_is_zero(const T* val, int value_size)
{
    // Compare in T: the previous float-typed lambda parameter narrowed every
    // value to float, misclassifying tiny non-zero doubles as zero.
    return std::all_of(val, val + value_size, [](T v) { return v == T(0); });
}
39
+
40
// Element-wise add the N entries of block `val` into `sum`.
// `value_size` is unused; it only mirrors the dynamic variant's signature so
// both can be stored in the same function pointer.
template <int N, typename T> void bsr_fixed_block_accumulate(const T* val, T* sum, int value_size)
{
    for (int k = 0; k < N; ++k)
    {
        sum[k] += val[k];
    }
}
47
+
48
// Element-wise add the `value_size` entries of block `val` into `sum`.
// Runtime-sized fallback for uncommon block sizes.
template <typename T> void bsr_dyn_block_accumulate(const T* val, T* sum, int value_size)
{
    for (int k = 0; k < value_size; ++k)
    {
        sum[k] += val[k];
    }
}
55
+
56
// Transpose a dense Rows x Cols block from `src` into `dest` (Cols x Rows).
// Dimensions are compile-time so the compiler can fully unroll the loops;
// the trailing runtime arguments are unused and exist only so this matches
// the dynamic variant's signature for function-pointer dispatch.
template <int Rows, int Cols, typename T>
void bsr_fixed_block_transpose(const T* src, T* dest, int row_count, int col_count)
{
    // Walk the destination linearly: dest entry (c, r) reads src entry (r, c).
    T* out = dest;
    for (int c = 0; c < Cols; ++c)
    {
        for (int r = 0; r < Rows; ++r)
        {
            *out++ = src[r * Cols + c];
        }
    }
}
67
+
68
// Transpose a dense row_count x col_count block from `src` into `dest`
// (col_count x row_count). Runtime-sized fallback for uncommon block shapes.
template <typename T> void bsr_dyn_block_transpose(const T* src, T* dest, int row_count, int col_count)
{
    // Walk the destination linearly: dest entry (c, r) reads src entry (r, c).
    T* out = dest;
    for (int c = 0; c < col_count; ++c)
    {
        for (int r = 0; r < row_count; ++r)
        {
            *out++ = src[r * col_count + c];
        }
    }
}
78
+
79
+ } // namespace
80
+
81
// Builds a BSR (block sparse row) matrix from COO triplets on the host.
//
// Inputs: `nnz` triplets (tpl_rows[i], tpl_columns[i]) with optional dense
// block values in `tpl_values` (block_size = rows_per_block * cols_per_block
// scalars per triplet). Triplets whose row falls outside [0, row_count) are
// dropped; duplicate (row, col) blocks are summed together.
//
// Outputs: bsr_offsets (row_count + 1 entries, prefix-summed row counts),
// bsr_columns, and — when bsr_values is non-null — the accumulated blocks.
//
// When `masked` is true, the incoming contents of bsr_offsets/bsr_columns are
// read as an existing (per-row sorted) sparsity pattern and only triplets that
// land on an existing block are kept. NOTE(review): those same arrays are then
// rewritten as output, so the mask is consumed in place.
//
// When `prune_numerical_zeros` is true and tpl_values is provided, triplet
// blocks whose entries are all zero are discarded up front.
//
// Returns the resulting number of non-zero blocks, i.e. bsr_offsets[row_count].
template <typename T>
int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_block, const int row_count,
                                  const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
                                  const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
                                  int* bsr_columns, T* bsr_values)
{

    // get specialized accumulator for common block sizes (1,1), (1,2), (1,3),
    // (2,2), (2,3), (3,3)
    const int block_size = rows_per_block * cols_per_block;
    void (*block_accumulate_func)(const T*, T*, int);
    bool (*block_is_zero_func)(const T*, int);
    switch (block_size)
    {
    case 1:
        block_accumulate_func = bsr_fixed_block_accumulate<1, T>;
        block_is_zero_func = bsr_fixed_block_is_zero<1, T>;
        break;
    case 2:
        block_accumulate_func = bsr_fixed_block_accumulate<2, T>;
        block_is_zero_func = bsr_fixed_block_is_zero<2, T>;
        break;
    case 3:
        block_accumulate_func = bsr_fixed_block_accumulate<3, T>;
        block_is_zero_func = bsr_fixed_block_is_zero<3, T>;
        break;
    case 4:
        block_accumulate_func = bsr_fixed_block_accumulate<4, T>;
        block_is_zero_func = bsr_fixed_block_is_zero<4, T>;
        break;
    case 6:
        block_accumulate_func = bsr_fixed_block_accumulate<6, T>;
        block_is_zero_func = bsr_fixed_block_is_zero<6, T>;
        break;
    case 9:
        block_accumulate_func = bsr_fixed_block_accumulate<9, T>;
        block_is_zero_func = bsr_fixed_block_is_zero<9, T>;
        break;
    default:
        block_accumulate_func = bsr_dyn_block_accumulate<T>;
        block_is_zero_func = bsr_dyn_block_is_zero<T>;
    }

    // Work on a permutation of triplet indices rather than moving the
    // triplets themselves, so the input arrays are left untouched.
    std::vector<int> block_indices(nnz);
    std::iota(block_indices.begin(), block_indices.end(), 0);

    // remove zero blocks and invalid row indices

    auto discard_block = [&](int i)
    {
        const int row = tpl_rows[i];
        if (row < 0 || row >= row_count)
        {
            return true;
        }

        if (prune_numerical_zeros && tpl_values && block_is_zero_func(tpl_values + i * block_size, block_size))
        {
            return true;
        }

        if (!masked)
        {
            return false;
        }

        // Masked mode: keep the triplet only if its column already exists in
        // the current row's column list (binary search assumes sorted columns).
        const int* beg = bsr_columns + bsr_offsets[row];
        const int* end = bsr_columns + bsr_offsets[row + 1];
        const int col = tpl_columns[i];
        const int* block = std::lower_bound(beg, end, col);
        return block == end || *block != col;
    };

    block_indices.erase(std::remove_if(block_indices.begin(), block_indices.end(), discard_block), block_indices.end());

    // sort block indices according to lexico order
    std::sort(block_indices.begin(), block_indices.end(), [tpl_rows, tpl_columns](int i, int j) -> bool
              { return tpl_rows[i] < tpl_rows[j] || (tpl_rows[i] == tpl_rows[j] && tpl_columns[i] < tpl_columns[j]); });

    // accumulate blocks at same locations, count blocks per row
    std::fill_n(bsr_offsets, row_count + 1, 0);

    int current_row = -1;
    int current_col = -1;

    // so that we get back to the start for the first block
    // (bsr_values is advanced by block_size before each *new* block is
    // written, so pre-decrement once to make the first advance land at 0)
    if (bsr_values)
    {
        bsr_values -= block_size;
    }

    for (int i = 0; i < block_indices.size(); ++i)
    {
        int idx = block_indices[i];
        int row = tpl_rows[idx];
        int col = tpl_columns[idx];
        // NOTE(review): assumes tpl_values is non-null whenever bsr_values is
        // — confirm against callers.
        const T* val = tpl_values + idx * block_size;

        if (row == current_row && col == current_col)
        {
            // Duplicate of the previous (row, col): sum into the current block.
            if (bsr_values)
            {
                block_accumulate_func(val, bsr_values, block_size);
            }
        }
        else
        {
            // New block: emit its column, copy its values, bump its row count.
            *(bsr_columns++) = col;

            if (bsr_values)
            {
                bsr_values += block_size;
                std::copy_n(val, block_size, bsr_values);
            }

            bsr_offsets[row + 1]++;

            current_row = row;
            current_col = col;
        }
    }

    // build postfix sum of row counts
    std::partial_sum(bsr_offsets, bsr_offsets + row_count + 1, bsr_offsets);

    return bsr_offsets[row_count];
}
208
+
209
// Transposes a BSR matrix on the host: writes transposed offsets (per-column
// prefix sums), transposed column indices (original rows), and — when value
// arrays are provided — the transposed dense blocks.
//
// - nnz_up: caller-supplied upper bound on the block count; unused here, the
//   actual count is read from bsr_offsets[row_count].
// - bsr_values / transposed_bsr_values may be null to transpose topology only.
template <typename T>
void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz_up,
                        const int* bsr_offsets, const int* bsr_columns, const T* bsr_values,
                        int* transposed_bsr_offsets, int* transposed_bsr_columns, T* transposed_bsr_values)
{
    const int nnz = bsr_offsets[row_count];
    const int block_size = rows_per_block * cols_per_block;

    // Pick a compile-time-unrolled block transpose for shapes up to 3x3;
    // fall back to the runtime-sized version otherwise.
    void (*block_transpose_func)(const T*, T*, int, int) = bsr_dyn_block_transpose<T>;
    switch (rows_per_block)
    {
    case 1:
        switch (cols_per_block)
        {
        case 1:
            block_transpose_func = bsr_fixed_block_transpose<1, 1, T>;
            break;
        case 2:
            block_transpose_func = bsr_fixed_block_transpose<1, 2, T>;
            break;
        case 3:
            block_transpose_func = bsr_fixed_block_transpose<1, 3, T>;
            break;
        }
        break;
    case 2:
        switch (cols_per_block)
        {
        case 1:
            block_transpose_func = bsr_fixed_block_transpose<2, 1, T>;
            break;
        case 2:
            block_transpose_func = bsr_fixed_block_transpose<2, 2, T>;
            break;
        case 3:
            block_transpose_func = bsr_fixed_block_transpose<2, 3, T>;
            break;
        }
        break;
    case 3:
        switch (cols_per_block)
        {
        case 1:
            block_transpose_func = bsr_fixed_block_transpose<3, 1, T>;
            break;
        case 2:
            block_transpose_func = bsr_fixed_block_transpose<3, 2, T>;
            break;
        case 3:
            block_transpose_func = bsr_fixed_block_transpose<3, 3, T>;
            break;
        }
        break;
    }

    std::vector<int> block_indices(nnz), bsr_rows(nnz);
    std::iota(block_indices.begin(), block_indices.end(), 0);

    // Fill row indices from offsets
    // (expands the CSR offsets into one explicit row index per block)
    for (int row = 0; row < row_count; ++row)
    {
        std::fill(bsr_rows.begin() + bsr_offsets[row], bsr_rows.begin() + bsr_offsets[row + 1], row);
    }

    // sort block indices according to (transposed) lexico order
    std::sort(
        block_indices.begin(), block_indices.end(), [&bsr_rows, bsr_columns](int i, int j) -> bool
        { return bsr_columns[i] < bsr_columns[j] || (bsr_columns[i] == bsr_columns[j] && bsr_rows[i] < bsr_rows[j]); });

    // Count blocks per column and transpose blocks
    std::fill_n(transposed_bsr_offsets, col_count + 1, 0);

    for (int i = 0; i < nnz; ++i)
    {
        int idx = block_indices[i];
        int row = bsr_rows[idx];
        int col = bsr_columns[idx];

        // Offsets are accumulated shifted by one so the later prefix sum
        // turns them into start offsets.
        ++transposed_bsr_offsets[col + 1];
        transposed_bsr_columns[i] = row;

        if (transposed_bsr_values != nullptr)
        {
            const T* src_block = bsr_values + idx * block_size;
            T* dst_block = transposed_bsr_values + i * block_size;
            block_transpose_func(src_block, dst_block, rows_per_block, cols_per_block);
        }
    }

    // build postfix sum of column counts
    std::partial_sum(transposed_bsr_offsets, transposed_bsr_offsets + col_count + 1, transposed_bsr_offsets);
}
301
+
302
+ WP_API void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
303
+ int* tpl_rows, int* tpl_columns, void* tpl_values,
304
+ bool prune_numerical_zeros, bool masked, int* bsr_offsets,
305
+ int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
306
+ {
307
+ bsr_matrix_from_triplets_host<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
308
+ static_cast<const float*>(tpl_values), prune_numerical_zeros, masked,
309
+ bsr_offsets, bsr_columns, static_cast<float*>(bsr_values));
310
+ if (bsr_nnz)
311
+ {
312
+ *bsr_nnz = bsr_offsets[row_count];
313
+ }
314
+ }
315
+
316
+ WP_API void bsr_matrix_from_triplets_double_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
317
+ int* tpl_rows, int* tpl_columns, void* tpl_values,
318
+ bool prune_numerical_zeros, bool masked, int* bsr_offsets,
319
+ int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
320
+ {
321
+ bsr_matrix_from_triplets_host<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
322
+ static_cast<const double*>(tpl_values), prune_numerical_zeros, masked,
323
+ bsr_offsets, bsr_columns, static_cast<double*>(bsr_values));
324
+ if (bsr_nnz)
325
+ {
326
+ *bsr_nnz = bsr_offsets[row_count];
327
+ }
328
+ }
329
+
330
+ WP_API void bsr_transpose_float_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
331
+ int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
332
+ int* transposed_bsr_columns, void* transposed_bsr_values)
333
+ {
334
+ bsr_transpose_host(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
335
+ static_cast<const float*>(bsr_values), transposed_bsr_offsets, transposed_bsr_columns,
336
+ static_cast<float*>(transposed_bsr_values));
337
+ }
338
+
339
+ WP_API void bsr_transpose_double_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
340
+ int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
341
+ int* transposed_bsr_columns, void* transposed_bsr_values)
342
+ {
343
+ bsr_transpose_host(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
344
+ static_cast<const double*>(bsr_values), transposed_bsr_offsets, transposed_bsr_columns,
345
+ static_cast<double*>(transposed_bsr_values));
346
+ }
347
+
348
#if !WP_ENABLE_CUDA
// Builds without CUDA support still export the *_device entry points so the
// shared-library API surface stays complete; each is an intentional no-op.

// No-op stub: CUDA disabled at build time.
WP_API void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
                                                  bool prune_numerical_zeros, bool masked, int* bsr_offsets,
                                                  int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
{
}

// No-op stub: CUDA disabled at build time.
WP_API void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                   int* tpl_rows, int* tpl_columns, void* tpl_values,
                                                   bool prune_numerical_zeros, bool masked, int* bsr_offsets,
                                                   int* bsr_columns, void* bsr_values, int* bsr_nnz,
                                                   void* bsr_nnz_event)
{
}

// No-op stub: CUDA disabled at build time.
WP_API void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
                                       int* bsr_offsets, int* bsr_columns, void* bsr_values,
                                       int* transposed_bsr_offsets, int* transposed_bsr_columns,
                                       void* transposed_bsr_values)
{
}

// No-op stub: CUDA disabled at build time.
WP_API void bsr_transpose_double_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
                                        int* bsr_offsets, int* bsr_columns, void* bsr_values,
                                        int* transposed_bsr_offsets, int* transposed_bsr_columns,
                                        void* transposed_bsr_values)
{
}

#endif