warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/native/sort.cu ADDED
@@ -0,0 +1,286 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include "warp.h"
19
+ #include "cuda_util.h"
20
+ #include "sort.h"
21
+
22
+ #define THRUST_IGNORE_CUB_VERSION_CHECK
23
+
24
+ #include <cub/cub.cuh>
25
+
26
+ #include <unordered_map>
27
+
// Scratch allocation used by the sort routines in this file: a raw pointer
// together with the byte capacity it was last sized to. Both start out empty.
struct RadixSortTemp
{
    void* mem = nullptr;  // start of the scratch buffer (none allocated yet)
    size_t size = 0;      // capacity of `mem` in bytes
};
34
+
// use unique temp buffers per CUDA stream to avoid race conditions
// Keyed by the stream handle. Entries are created on first use by the reserve
// functions in this file (via operator[]) and removed by radix_sort_release();
// the segmented sort path shares the same map.
// NOTE(review): no lock guards this map anywhere in this file - presumably
// callers serialize access; confirm.
static std::unordered_map<void*, RadixSortTemp> g_radix_sort_temp_map;
37
+
38
+
39
+ template <typename KeyType>
40
+ void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* size_out)
41
+ {
42
+ ContextGuard guard(context);
43
+
44
+ cub::DoubleBuffer<KeyType> d_keys;
45
+ cub::DoubleBuffer<int> d_values;
46
+
47
+ CUstream stream = static_cast<CUstream>(wp_cuda_stream_get_current());
48
+
49
+ // compute temporary memory required
50
+ size_t sort_temp_size;
51
+ check_cuda(cub::DeviceRadixSort::SortPairs(
52
+ NULL,
53
+ sort_temp_size,
54
+ d_keys,
55
+ d_values,
56
+ n, 0, sizeof(KeyType)*8,
57
+ stream));
58
+
59
+ RadixSortTemp& temp = g_radix_sort_temp_map[stream];
60
+
61
+ if (sort_temp_size > temp.size)
62
+ {
63
+ wp_free_device(WP_CURRENT_CONTEXT, temp.mem);
64
+ temp.mem = wp_alloc_device(WP_CURRENT_CONTEXT, sort_temp_size);
65
+ temp.size = sort_temp_size;
66
+ }
67
+
68
+ if (mem_out)
69
+ *mem_out = temp.mem;
70
+ if (size_out)
71
+ *size_out = temp.size;
72
+ }
73
+
// Entry point declared in sort.h: reserves scratch space on the current stream
// for sorting `n` pairs with int keys. mem_out and size_out are optional
// outputs (sort.h defaults both to NULL).
void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
{
    radix_sort_reserve_internal<int>(context, n, mem_out, size_out);
}
78
+
79
+ void radix_sort_release(void* context, void* stream)
80
+ {
81
+ // release temporary buffer for the given stream, if it exists
82
+ auto it = g_radix_sort_temp_map.find(stream);
83
+ if (it != g_radix_sort_temp_map.end())
84
+ {
85
+ wp_free_device(context, it->second.mem);
86
+ g_radix_sort_temp_map.erase(it);
87
+ }
88
+ }
89
+
90
+ template <typename KeyType>
91
+ void radix_sort_pairs_device(void* context, KeyType* keys, int* values, int n)
92
+ {
93
+ ContextGuard guard(context);
94
+
95
+ cub::DoubleBuffer<KeyType> d_keys(keys, keys + n);
96
+ cub::DoubleBuffer<int> d_values(values, values + n);
97
+
98
+ RadixSortTemp temp;
99
+ radix_sort_reserve_internal<KeyType>(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);
100
+
101
+ // sort
102
+ check_cuda(cub::DeviceRadixSort::SortPairs(
103
+ temp.mem,
104
+ temp.size,
105
+ d_keys,
106
+ d_values,
107
+ n, 0, sizeof(KeyType)*8,
108
+ (cudaStream_t)wp_cuda_stream_get_current()));
109
+
110
+ if (d_keys.Current() != keys)
111
+ wp_memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(KeyType)*n);
112
+
113
+ if (d_values.Current() != values)
114
+ wp_memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
115
+ }
116
+
// int-key overload declared in sort.h. The explicit <int> selects the
// templated implementation; without it the call would resolve to this
// non-template overload again.
void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
{
    radix_sort_pairs_device<int>(context, keys, values, n);
}
121
+
// float-key overload declared in sort.h. The explicit <float> selects the
// templated implementation; without it the call would resolve to this
// non-template overload again.
void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
{
    radix_sort_pairs_device<float>(context, keys, values, n);
}
126
+
// int64_t-key overload declared in sort.h. The explicit <int64_t> selects the
// templated implementation; without it the call would resolve to this
// non-template overload again.
void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n)
{
    radix_sort_pairs_device<int64_t>(context, keys, values, n);
}
131
+
132
+ void wp_radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
133
+ {
134
+ radix_sort_pairs_device(
135
+ WP_CURRENT_CONTEXT,
136
+ reinterpret_cast<int *>(keys),
137
+ reinterpret_cast<int *>(values), n);
138
+ }
139
+
140
+ void wp_radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n)
141
+ {
142
+ radix_sort_pairs_device(
143
+ WP_CURRENT_CONTEXT,
144
+ reinterpret_cast<float *>(keys),
145
+ reinterpret_cast<int *>(values), n);
146
+ }
147
+
148
+ void wp_radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n)
149
+ {
150
+ radix_sort_pairs_device(
151
+ WP_CURRENT_CONTEXT,
152
+ reinterpret_cast<int64_t *>(keys),
153
+ reinterpret_cast<int *>(values), n);
154
+ }
155
+
156
+ void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_out, size_t* size_out)
157
+ {
158
+ ContextGuard guard(context);
159
+
160
+ cub::DoubleBuffer<int> d_keys;
161
+ cub::DoubleBuffer<int> d_values;
162
+
163
+ int* start_indices = NULL;
164
+ int* end_indices = NULL;
165
+
166
+ CUstream stream = static_cast<CUstream>(wp_cuda_stream_get_current());
167
+
168
+ // compute temporary memory required
169
+ size_t sort_temp_size;
170
+ check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
171
+ NULL,
172
+ sort_temp_size,
173
+ d_keys,
174
+ d_values,
175
+ n,
176
+ num_segments,
177
+ start_indices,
178
+ end_indices,
179
+ 0,
180
+ 32,
181
+ stream));
182
+
183
+ RadixSortTemp& temp = g_radix_sort_temp_map[stream];
184
+
185
+ if (sort_temp_size > temp.size)
186
+ {
187
+ wp_free_device(WP_CURRENT_CONTEXT, temp.mem);
188
+ temp.mem = wp_alloc_device(WP_CURRENT_CONTEXT, sort_temp_size);
189
+ temp.size = sort_temp_size;
190
+ }
191
+
192
+ if (mem_out)
193
+ *mem_out = temp.mem;
194
+ if (size_out)
195
+ *size_out = temp.size;
196
+ }
197
+
198
+ // segment_start_indices and segment_end_indices are arrays of length num_segments, where segment_start_indices[i] is the index of the first element
199
+ // in the i-th segment and segment_end_indices[i] is the index after the last element in the i-th segment
200
+ // https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedRadixSort.html
201
+ void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
202
+ {
203
+ ContextGuard guard(context);
204
+
205
+ cub::DoubleBuffer<float> d_keys(keys, keys + n);
206
+ cub::DoubleBuffer<int> d_values(values, values + n);
207
+
208
+ RadixSortTemp temp;
209
+ segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
210
+
211
+ // sort
212
+ check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
213
+ temp.mem,
214
+ temp.size,
215
+ d_keys,
216
+ d_values,
217
+ n,
218
+ num_segments,
219
+ segment_start_indices,
220
+ segment_end_indices,
221
+ 0,
222
+ 32,
223
+ (cudaStream_t)wp_cuda_stream_get_current()));
224
+
225
+ if (d_keys.Current() != keys)
226
+ wp_memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);
227
+
228
+ if (d_values.Current() != values)
229
+ wp_memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
230
+ }
231
+
232
+ void wp_segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
233
+ {
234
+ segmented_sort_pairs_device(
235
+ WP_CURRENT_CONTEXT,
236
+ reinterpret_cast<float *>(keys),
237
+ reinterpret_cast<int *>(values), n,
238
+ reinterpret_cast<int *>(segment_start_indices),
239
+ reinterpret_cast<int *>(segment_end_indices),
240
+ num_segments);
241
+ }
242
+
243
+ // segment_indices is an array of length num_segments + 1, where segment_indices[i] is the index of the first element in the i-th segment
244
+ // The end of a segment is given by segment_indices[i+1]
245
+ // https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedSort.html#a-simple-example
246
+ void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
247
+ {
248
+ ContextGuard guard(context);
249
+
250
+ cub::DoubleBuffer<int> d_keys(keys, keys + n);
251
+ cub::DoubleBuffer<int> d_values(values, values + n);
252
+
253
+ RadixSortTemp temp;
254
+ segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
255
+
256
+ // sort
257
+ check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
258
+ temp.mem,
259
+ temp.size,
260
+ d_keys,
261
+ d_values,
262
+ n,
263
+ num_segments,
264
+ segment_start_indices,
265
+ segment_end_indices,
266
+ 0,
267
+ 32,
268
+ (cudaStream_t)wp_cuda_stream_get_current()));
269
+
270
+ if (d_keys.Current() != keys)
271
+ wp_memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);
272
+
273
+ if (d_values.Current() != values)
274
+ wp_memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
275
+ }
276
+
277
+ void wp_segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
278
+ {
279
+ segmented_sort_pairs_device(
280
+ WP_CURRENT_CONTEXT,
281
+ reinterpret_cast<int *>(keys),
282
+ reinterpret_cast<int *>(values), n,
283
+ reinterpret_cast<int *>(segment_start_indices),
284
+ reinterpret_cast<int *>(segment_end_indices),
285
+ num_segments);
286
+ }
warp/native/sort.h ADDED
@@ -0,0 +1,35 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #pragma once
19
+
20
#include <stddef.h>
#include <stdint.h> // int64_t, needed by the 64-bit key overloads declared below

// Temporary-storage management for the device radix sort.
// NOTE(review): the definitions are not in this header; presumably reserve sizes
// scratch memory for n items and optionally reports it through mem_out / size_out,
// and release frees it for the given stream — confirm against the implementation.
void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL);
void radix_sort_release(void* context, void* stream);

// Key/value pair sorts, overloaded on the key type (int, float, int64_t); values are always int.
// The device variants additionally take the context to run on.
void radix_sort_pairs_host(int* keys, int* values, int n);
void radix_sort_pairs_host(float* keys, int* values, int n);
void radix_sort_pairs_host(int64_t* keys, int* values, int n);
void radix_sort_pairs_device(void* context, int* keys, int* values, int n);
void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n);

// Segmented key/value pair sorts: segment i is delimited by segment_start_indices[i]
// and segment_end_indices[i]; there are num_segments segments.
void segmented_sort_pairs_host(float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
// NOTE(review): unlike the float-key host overload above, this host declaration takes a
// `void* context` argument — verify that it matches the definition in the implementation file.
void segmented_sort_pairs_host(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
warp/native/sparse.cpp ADDED
@@ -0,0 +1,241 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include "warp.h"
19
+
20
+ #include <algorithm>
21
+ #include <cstddef>
22
+ #include <numeric>
23
+ #include <vector>
24
+
25
namespace
{

// Returns true when every scalar of block `block_idx` is zero once masked with
// `scalar_zero_mask`.
//
// `values` is reinterpreted as an array of the integer type T in which block
// `block_idx` occupies the elements
// [block_idx * block_size, (block_idx + 1) * block_size).
// The 64-bit mask is truncated to the width of T before it is applied.
// An empty block (block_size == 0) is reported as zero.
template <typename T> bool bsr_block_is_zero(int block_idx, int block_size, const void* values, const uint64_t scalar_zero_mask)
{
    // Compute the element offset in size_t: the int product block_idx * block_size
    // can exceed INT_MAX for large triplet counts.
    const size_t block_begin = static_cast<size_t>(block_idx) * static_cast<size_t>(block_size);
    const T* block_values = static_cast<const T*>(values) + block_begin;
    const T zero_mask = static_cast<T>(scalar_zero_mask);

    return std::all_of(block_values, block_values + block_size, [zero_mask](T v) { return (v & zero_mask) == T(0); });
}

} // namespace
37
+
38
+
39
+ WP_API void wp_bsr_matrix_from_triplets_host(
40
+ int block_size,
41
+ int scalar_size_in_bytes,
42
+ int row_count,
43
+ int col_count,
44
+ int nnz,
45
+ const int* tpl_nnz,
46
+ const int* tpl_rows,
47
+ const int* tpl_columns,
48
+ const void* tpl_values,
49
+ const uint64_t scalar_zero_mask,
50
+ bool masked_topology,
51
+ int* tpl_block_offsets,
52
+ int* tpl_block_indices,
53
+ int* bsr_offsets,
54
+ int* bsr_columns,
55
+ int* bsr_nnz,
56
+ void* bsr_nnz_event)
57
+ {
58
+ if (tpl_nnz != nullptr)
59
+ {
60
+ nnz = *tpl_nnz;
61
+ }
62
+
63
+ // allocate temporary buffers if not provided
64
+ bool return_summed_blocks = tpl_block_offsets != nullptr && tpl_block_indices != nullptr;
65
+ if (!return_summed_blocks)
66
+ {
67
+ tpl_block_offsets = static_cast<int*>(wp_alloc_host(size_t(nnz) * sizeof(int)));
68
+ tpl_block_indices = static_cast<int*>(wp_alloc_host(size_t(nnz) * sizeof(int)));
69
+ }
70
+
71
+ std::iota(tpl_block_indices, tpl_block_indices + nnz, 0);
72
+
73
+ // remove invalid indices / indices not in mask
74
+ auto discard_invalid_block = [&](int i) -> bool
75
+ {
76
+ const int row = tpl_rows[i];
77
+ const int col = tpl_columns[i];
78
+ if (row < 0 || row >= row_count || col < 0 || col >= col_count)
79
+ {
80
+ return true;
81
+ }
82
+
83
+ if (!masked_topology)
84
+ {
85
+ return false;
86
+ }
87
+
88
+ const int* beg = bsr_columns + bsr_offsets[row];
89
+ const int* end = bsr_columns + bsr_offsets[row + 1];
90
+ const int* block = std::lower_bound(beg, end, col);
91
+ return block == end || *block != col;
92
+ };
93
+
94
+ int* valid_indices_end = std::remove_if(tpl_block_indices, tpl_block_indices + nnz, discard_invalid_block);
95
+
96
+ // remove zero blocks
97
+ if (tpl_values != nullptr && scalar_zero_mask != 0)
98
+ {
99
+ switch (scalar_size_in_bytes)
100
+ {
101
+ case sizeof(uint8_t):
102
+ valid_indices_end = std::remove_if(tpl_block_indices, valid_indices_end, [block_size, tpl_values, scalar_zero_mask](uint32_t i) { return bsr_block_is_zero<uint8_t>(i, block_size, tpl_values, scalar_zero_mask); });
103
+ break;
104
+ case sizeof(uint16_t):
105
+ valid_indices_end = std::remove_if(tpl_block_indices, valid_indices_end, [block_size, tpl_values, scalar_zero_mask](uint32_t i) { return bsr_block_is_zero<uint16_t>(i, block_size, tpl_values, scalar_zero_mask); });
106
+ break;
107
+ case sizeof(uint32_t):
108
+ valid_indices_end = std::remove_if(tpl_block_indices, valid_indices_end, [block_size, tpl_values, scalar_zero_mask](uint32_t i) { return bsr_block_is_zero<uint32_t>(i, block_size, tpl_values, scalar_zero_mask); });
109
+ break;
110
+ case sizeof(uint64_t):
111
+ valid_indices_end = std::remove_if(tpl_block_indices, valid_indices_end, [block_size, tpl_values, scalar_zero_mask](uint32_t i) { return bsr_block_is_zero<uint64_t>(i, block_size, tpl_values, scalar_zero_mask); });
112
+ break;
113
+ }
114
+ }
115
+
116
+ // sort block indices according to lexico order
117
+ std::sort(tpl_block_indices, valid_indices_end, [tpl_rows, tpl_columns](int i, int j) -> bool
118
+ { return tpl_rows[i] < tpl_rows[j] || (tpl_rows[i] == tpl_rows[j] && tpl_columns[i] < tpl_columns[j]); });
119
+
120
+ // accumulate blocks at same locations, count blocks per row
121
+ std::fill_n(bsr_offsets, row_count + 1, 0);
122
+
123
+ int current_row = -1;
124
+ int current_col = -1;
125
+ int current_block_idx = -1;
126
+
127
+ for (int *block = tpl_block_indices, *block_offset = tpl_block_offsets ; block != valid_indices_end ; ++ block)
128
+ {
129
+ int32_t idx = *block;
130
+ int row = tpl_rows[idx];
131
+ int col = tpl_columns[idx];
132
+
133
+ if (row != current_row || col != current_col)
134
+ {
135
+ *(bsr_columns++) = col;
136
+
137
+ ++bsr_offsets[row + 1];
138
+
139
+ if(current_row == -1) {
140
+ *block_offset = 0;
141
+ } else {
142
+ *(block_offset+1) = *block_offset;
143
+ ++block_offset;
144
+ }
145
+
146
+ current_row = row;
147
+ current_col = col;
148
+ }
149
+
150
+ ++(*block_offset);
151
+ }
152
+
153
+ // build postfix sum of row counts
154
+ std::partial_sum(bsr_offsets, bsr_offsets + row_count + 1, bsr_offsets);
155
+
156
+ if(!return_summed_blocks)
157
+ {
158
+ // free our temporary buffers
159
+ wp_free_host(tpl_block_offsets);
160
+ wp_free_host(tpl_block_indices);
161
+ }
162
+
163
+ if (bsr_nnz != nullptr)
164
+ {
165
+ *bsr_nnz = bsr_offsets[row_count];
166
+ }
167
+ }
168
+
169
// Computes the topology of the transpose of a BSR matrix held in host memory.
//
// Outputs:
//   transposed_bsr_offsets  col_count + 1 entries: start offset of each row of the transpose
//   transposed_bsr_columns  for each transposed block, the row of the source block
//   block_indices           for each transposed block, the index of its source block
// The incoming `nnz` is ignored; the block count is read from bsr_offsets[row_count].
WP_API void wp_bsr_transpose_host(
    int row_count, int col_count, int nnz,
    const int* bsr_offsets, const int* bsr_columns,
    int* transposed_bsr_offsets,
    int* transposed_bsr_columns,
    int* block_indices
)
{
    nnz = bsr_offsets[row_count];

    // expand the compressed row offsets into one explicit row index per block
    std::vector<int> row_of_block(nnz);
    for (int r = 0; r < row_count; ++r)
    {
        for (int b = bsr_offsets[r]; b < bsr_offsets[r + 1]; ++b)
        {
            row_of_block[b] = r;
        }
    }

    // identity permutation of the blocks, then order it by (column, row)
    int* const indices_end = block_indices + nnz;
    std::iota(block_indices, indices_end, 0);

    const auto transposed_less = [&row_of_block, bsr_columns](int lhs, int rhs) -> bool
    {
        if (bsr_columns[lhs] != bsr_columns[rhs])
        {
            return bsr_columns[lhs] < bsr_columns[rhs];
        }
        return row_of_block[lhs] < row_of_block[rhs];
    };
    std::sort(block_indices, indices_end, transposed_less);

    // tally blocks per source column and record each transposed block's column
    std::fill_n(transposed_bsr_offsets, col_count + 1, 0);

    for (int k = 0; k < nnz; ++k)
    {
        const int src = block_indices[k];
        const int src_row = row_of_block[src];
        const int src_col = bsr_columns[src];

        ++transposed_bsr_offsets[src_col + 1];
        transposed_bsr_columns[k] = src_row;
    }

    // running sum converts the per-column counts into offsets
    std::partial_sum(transposed_bsr_offsets, transposed_bsr_offsets + col_count + 1, transposed_bsr_offsets);
}
210
+
211
#if !WP_ENABLE_CUDA

// Placeholder definitions of the device entry points, compiled only when
// WP_ENABLE_CUDA is off. Both functions have empty bodies: they read none of
// their arguments and write none of their outputs.

WP_API void wp_bsr_matrix_from_triplets_device(
    int block_size,
    int scalar_size_in_bytes,
    int row_count,
    int col_count,
    int tpl_nnz_upper_bound,
    const int* tpl_nnz,
    const int* tpl_rows,
    const int* tpl_columns,
    const void* tpl_values,
    const uint64_t scalar_zero_mask,
    bool masked_topology,
    int* summed_block_offsets,
    int* summed_block_indices,
    int* bsr_offsets,
    int* bsr_columns,
    int* bsr_nnz,
    void* bsr_nnz_event)
{
    // intentionally empty
}

WP_API void wp_bsr_transpose_device(
    int row_count,
    int col_count,
    int nnz,
    const int* bsr_offsets,
    const int* bsr_columns,
    int* transposed_bsr_offsets,
    int* transposed_bsr_columns,
    int* src_block_indices)
{
    // intentionally empty
}

#endif