warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/native/reduce.cu ADDED
@@ -0,0 +1,363 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include "cuda_util.h"
19
+ #include "warp.h"
20
+
21
+ #include "temp_buffer.h"
22
+
23
+ #define THRUST_IGNORE_CUB_VERSION_CHECK
24
+ #include <cub/device/device_reduce.cuh>
25
+
26
+ namespace
27
+ {
28
+
29
+ template <typename T>
30
+ __global__ void cwise_mult_kernel(int len, int stride_a, int stride_b, const T *a, const T *b, T *out)
31
+ {
32
+ int i = blockIdx.x * blockDim.x + threadIdx.x;
33
+ if (i >= len)
34
+ return;
35
+ out[i] = a[i * stride_a] * b[i * stride_b];
36
+ }
37
+
38
+ /// Custom iterator for allowing strided access with CUB
39
+ template <typename T> struct cub_strided_iterator
40
+ {
41
+ typedef cub_strided_iterator<T> self_type;
42
+ typedef std::ptrdiff_t difference_type;
43
+ typedef T value_type;
44
+ typedef T *pointer;
45
+ typedef T &reference;
46
+
47
+ typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
48
+
49
+ T *ptr = nullptr;
50
+ int stride = 1;
51
+
52
+ CUDA_CALLABLE self_type operator++(int)
53
+ {
54
+ return ++(self_type(*this));
55
+ }
56
+
57
+ CUDA_CALLABLE self_type &operator++()
58
+ {
59
+ ptr += stride;
60
+ return *this;
61
+ }
62
+
63
+ __host__ __device__ __forceinline__ reference operator*() const
64
+ {
65
+ return *ptr;
66
+ }
67
+
68
+ CUDA_CALLABLE self_type operator+(difference_type n) const
69
+ {
70
+ return self_type(*this) += n;
71
+ }
72
+
73
+ CUDA_CALLABLE self_type &operator+=(difference_type n)
74
+ {
75
+ ptr += n * stride;
76
+ return *this;
77
+ }
78
+
79
+ CUDA_CALLABLE self_type operator-(difference_type n) const
80
+ {
81
+ return self_type(*this) -= n;
82
+ }
83
+
84
+ CUDA_CALLABLE self_type &operator-=(difference_type n)
85
+ {
86
+ ptr -= n * stride;
87
+ return *this;
88
+ }
89
+
90
+ CUDA_CALLABLE difference_type operator-(const self_type &other) const
91
+ {
92
+ return (ptr - other.ptr) / stride;
93
+ }
94
+
95
+ CUDA_CALLABLE reference operator[](difference_type n) const
96
+ {
97
+ return *(ptr + n * stride);
98
+ }
99
+
100
+ CUDA_CALLABLE pointer operator->() const
101
+ {
102
+ return ptr;
103
+ }
104
+
105
+ CUDA_CALLABLE bool operator==(const self_type &rhs) const
106
+ {
107
+ return (ptr == rhs.ptr);
108
+ }
109
+
110
+ CUDA_CALLABLE bool operator!=(const self_type &rhs) const
111
+ {
112
+ return (ptr != rhs.ptr);
113
+ }
114
+ };
115
+
116
+ template <typename T> void array_sum_device(const T *ptr_a, T *ptr_out, int count, int byte_stride, int type_length)
117
+ {
118
+ assert((byte_stride % sizeof(T)) == 0);
119
+ const int stride = byte_stride / sizeof(T);
120
+
121
+ ContextGuard guard(wp_cuda_context_get_current());
122
+ cudaStream_t stream = static_cast<cudaStream_t>(wp_cuda_stream_get_current());
123
+
124
+ cub_strided_iterator<const T> ptr_strided{ptr_a, stride};
125
+
126
+ size_t buff_size = 0;
127
+ check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, ptr_strided, ptr_out, count, stream));
128
+ void* temp_buffer = wp_alloc_device(WP_CURRENT_CONTEXT, buff_size);
129
+
130
+ for (int k = 0; k < type_length; ++k)
131
+ {
132
+ cub_strided_iterator<const T> ptr_strided{ptr_a + k, stride};
133
+ check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, ptr_strided, ptr_out + k, count, stream));
134
+ }
135
+
136
+ wp_free_device(WP_CURRENT_CONTEXT, temp_buffer);
137
+ }
138
+
139
+ template <typename T>
140
+ void array_sum_device_dispatch(const T *ptr_a, T *ptr_out, int count, int byte_stride, int type_length)
141
+ {
142
+ using vec2 = wp::vec_t<2, T>;
143
+ using vec3 = wp::vec_t<3, T>;
144
+ using vec4 = wp::vec_t<4, T>;
145
+
146
+ // specialized calls for common vector types
147
+
148
+ if ((type_length % 4) == 0 && (byte_stride % sizeof(vec4)) == 0)
149
+ {
150
+ return array_sum_device(reinterpret_cast<const vec4 *>(ptr_a), reinterpret_cast<vec4 *>(ptr_out), count,
151
+ byte_stride, type_length / 4);
152
+ }
153
+
154
+ if ((type_length % 3) == 0 && (byte_stride % sizeof(vec3)) == 0)
155
+ {
156
+ return array_sum_device(reinterpret_cast<const vec3 *>(ptr_a), reinterpret_cast<vec3 *>(ptr_out), count,
157
+ byte_stride, type_length / 3);
158
+ }
159
+
160
+ if ((type_length % 2) == 0 && (byte_stride % sizeof(vec2)) == 0)
161
+ {
162
+ return array_sum_device(reinterpret_cast<const vec2 *>(ptr_a), reinterpret_cast<vec2 *>(ptr_out), count,
163
+ byte_stride, type_length / 2);
164
+ }
165
+
166
+ return array_sum_device(ptr_a, ptr_out, count, byte_stride, type_length);
167
+ }
168
+
169
+ template <typename T> CUDA_CALLABLE T element_inner_product(const T &a, const T &b)
170
+ {
171
+ return a * b;
172
+ }
173
+
174
+ template <unsigned Length, typename T>
175
+ CUDA_CALLABLE T element_inner_product(const wp::vec_t<Length, T> &a, const wp::vec_t<Length, T> &b)
176
+ {
177
+ return wp::dot(a, b);
178
+ }
179
+
180
+ /// Custom iterator for allowing strided access with CUB
181
+ template <typename ElemT, typename ScalarT> struct cub_inner_product_iterator
182
+ {
183
+ typedef cub_inner_product_iterator<ElemT, ScalarT> self_type;
184
+ typedef std::ptrdiff_t difference_type;
185
+ typedef ScalarT value_type;
186
+ typedef ScalarT *pointer;
187
+ typedef ScalarT reference;
188
+
189
+ typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
190
+
191
+ const ElemT *ptr_a = nullptr;
192
+ const ElemT *ptr_b = nullptr;
193
+
194
+ int stride_a = 1;
195
+ int stride_b = 1;
196
+ int type_length = 1;
197
+
198
+ CUDA_CALLABLE self_type operator++(int)
199
+ {
200
+ return ++(self_type(*this));
201
+ }
202
+
203
+ CUDA_CALLABLE self_type &operator++()
204
+ {
205
+ ptr_a += stride_a;
206
+ ptr_b += stride_b;
207
+ return *this;
208
+ }
209
+
210
+ __host__ __device__ __forceinline__ reference operator*() const
211
+ {
212
+ return compute_value(0);
213
+ }
214
+
215
+ CUDA_CALLABLE self_type operator+(difference_type n) const
216
+ {
217
+ return self_type(*this) += n;
218
+ }
219
+
220
+ CUDA_CALLABLE self_type &operator+=(difference_type n)
221
+ {
222
+ ptr_a += n * stride_a;
223
+ ptr_b += n * stride_b;
224
+ return *this;
225
+ }
226
+
227
+ CUDA_CALLABLE self_type operator-(difference_type n) const
228
+ {
229
+ return self_type(*this) -= n;
230
+ }
231
+
232
+ CUDA_CALLABLE self_type &operator-=(difference_type n)
233
+ {
234
+ ptr_a -= n * stride_a;
235
+ ptr_b -= n * stride_b;
236
+ return *this;
237
+ }
238
+
239
+ CUDA_CALLABLE difference_type operator-(const self_type &other) const
240
+ {
241
+ return (ptr_a - other.ptr_a) / stride_a;
242
+ }
243
+
244
+ CUDA_CALLABLE reference operator[](difference_type n) const
245
+ {
246
+ return compute_value(n);
247
+ }
248
+
249
+ CUDA_CALLABLE bool operator==(const self_type &rhs) const
250
+ {
251
+ return (ptr_a == rhs.ptr_a);
252
+ }
253
+
254
+ CUDA_CALLABLE bool operator!=(const self_type &rhs) const
255
+ {
256
+ return (ptr_a != rhs.ptr_a);
257
+ }
258
+
259
+ private:
260
+ CUDA_CALLABLE ScalarT compute_value(difference_type n) const
261
+ {
262
+ ScalarT val(0);
263
+ const ElemT *a = ptr_a + n * stride_a;
264
+ const ElemT *b = ptr_b + n * stride_b;
265
+ for (int k = 0; k < type_length; ++k)
266
+ {
267
+ val += element_inner_product(a[k], b[k]);
268
+ }
269
+ return val;
270
+ }
271
+ };
272
+
273
+ template <typename ElemT, typename ScalarT>
274
+ void array_inner_device(const ElemT *ptr_a, const ElemT *ptr_b, ScalarT *ptr_out, int count, int byte_stride_a,
275
+ int byte_stride_b, int type_length)
276
+ {
277
+ assert((byte_stride_a % sizeof(ElemT)) == 0);
278
+ assert((byte_stride_b % sizeof(ElemT)) == 0);
279
+ const int stride_a = byte_stride_a / sizeof(ElemT);
280
+ const int stride_b = byte_stride_b / sizeof(ElemT);
281
+
282
+ ContextGuard guard(wp_cuda_context_get_current());
283
+ cudaStream_t stream = static_cast<cudaStream_t>(wp_cuda_stream_get_current());
284
+
285
+ cub_inner_product_iterator<ElemT, ScalarT> inner_iterator{ptr_a, ptr_b, stride_a, stride_b, type_length};
286
+
287
+ size_t buff_size = 0;
288
+ check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, inner_iterator, ptr_out, count, stream));
289
+ void* temp_buffer = wp_alloc_device(WP_CURRENT_CONTEXT, buff_size);
290
+
291
+ check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, inner_iterator, ptr_out, count, stream));
292
+
293
+ wp_free_device(WP_CURRENT_CONTEXT, temp_buffer);
294
+ }
295
+
296
+ template <typename T>
297
+ void array_inner_device_dispatch(const T *ptr_a, const T *ptr_b, T *ptr_out, int count, int byte_stride_a,
298
+ int byte_stride_b, int type_length)
299
+ {
300
+ using vec2 = wp::vec_t<2, T>;
301
+ using vec3 = wp::vec_t<3, T>;
302
+ using vec4 = wp::vec_t<4, T>;
303
+
304
+ // specialized calls for common vector types
305
+
306
+ if ((type_length % 4) == 0 && (byte_stride_a % sizeof(vec4)) == 0 && (byte_stride_b % sizeof(vec4)) == 0)
307
+ {
308
+ return array_inner_device(reinterpret_cast<const vec4 *>(ptr_a), reinterpret_cast<const vec4 *>(ptr_b), ptr_out,
309
+ count, byte_stride_a, byte_stride_b, type_length / 4);
310
+ }
311
+
312
+ if ((type_length % 3) == 0 && (byte_stride_a % sizeof(vec3)) == 0 && (byte_stride_b % sizeof(vec3)) == 0)
313
+ {
314
+ return array_inner_device(reinterpret_cast<const vec3 *>(ptr_a), reinterpret_cast<const vec3 *>(ptr_b), ptr_out,
315
+ count, byte_stride_a, byte_stride_b, type_length / 3);
316
+ }
317
+
318
+ if ((type_length % 2) == 0 && (byte_stride_a % sizeof(vec2)) == 0 && (byte_stride_b % sizeof(vec2)) == 0)
319
+ {
320
+ return array_inner_device(reinterpret_cast<const vec2 *>(ptr_a), reinterpret_cast<const vec2 *>(ptr_b), ptr_out,
321
+ count, byte_stride_a, byte_stride_b, type_length / 2);
322
+ }
323
+
324
+ return array_inner_device(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_length);
325
+ }
326
+
327
+ } // anonymous namespace
328
+
329
+ void wp_array_inner_float_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
330
+ int type_len)
331
+ {
332
+ void *context = wp_cuda_context_get_current();
333
+
334
+ const float *ptr_a = (const float *)(a);
335
+ const float *ptr_b = (const float *)(b);
336
+ float *ptr_out = (float *)(out);
337
+
338
+ array_inner_device_dispatch(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_len);
339
+ }
340
+
341
+ void wp_array_inner_double_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
342
+ int type_len)
343
+ {
344
+ const double *ptr_a = (const double *)(a);
345
+ const double *ptr_b = (const double *)(b);
346
+ double *ptr_out = (double *)(out);
347
+
348
+ array_inner_device_dispatch(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_len);
349
+ }
350
+
351
+ void wp_array_sum_float_device(uint64_t a, uint64_t out, int count, int byte_stride, int type_length)
352
+ {
353
+ const float *ptr_a = (const float *)(a);
354
+ float *ptr_out = (float *)(out);
355
+ array_sum_device_dispatch(ptr_a, ptr_out, count, byte_stride, type_length);
356
+ }
357
+
358
+ void wp_array_sum_double_device(uint64_t a, uint64_t out, int count, int byte_stride, int type_length)
359
+ {
360
+ const double *ptr_a = (const double *)(a);
361
+ double *ptr_out = (double *)(out);
362
+ array_sum_device_dispatch(ptr_a, ptr_out, count, byte_stride, type_length);
363
+ }
@@ -0,0 +1,79 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include "warp.h"
19
+
20
+ #include <cstdint>
21
+
22
// Run-length encodes `n` consecutive values on the host.
//
// Each maximal run of equal adjacent elements in `values` produces one entry:
// the run's value is written to run_values[i], its length to run_lengths[i],
// and the total number of runs to *run_count.
//
// n           : number of input elements; a non-positive value yields zero
//               runs and leaves run_values / run_lengths untouched.
// values      : input sequence of n elements.
// run_values  : output, must have room for one entry per run (at most n).
// run_lengths : output, same capacity requirement as run_values.
// run_count   : output, receives the number of runs.
template <typename T>
void runlength_encode_host(int n,
                           const T *values,
                           T *run_values,
                           int *run_lengths,
                           int *run_count)
{
    // Treat negative counts like an empty input: with n < 0 the `end`
    // pointer would sit before `values` and the loop below would never
    // reach it.
    if (n <= 0)
    {
        *run_count = 0;
        return;
    }

    const T *end = values + n;

    // The first element always opens the first run.
    *run_count = 1;
    *run_lengths = 1;
    *run_values = *values;

    while (++values != end)
    {
        if (*values == *run_values)
        {
            // Same value as the current run: extend it.
            ++*run_lengths;
        }
        else
        {
            // Value changed: advance both output cursors and open a new run.
            ++*run_count;
            *(++run_lengths) = 1;
            *(++run_values) = *values;
        }
    }
}
55
+
56
+ void wp_runlength_encode_int_host(
57
+ uint64_t values,
58
+ uint64_t run_values,
59
+ uint64_t run_lengths,
60
+ uint64_t run_count,
61
+ int n)
62
+ {
63
+ runlength_encode_host<int>(n,
64
+ reinterpret_cast<const int *>(values),
65
+ reinterpret_cast<int *>(run_values),
66
+ reinterpret_cast<int *>(run_lengths),
67
+ reinterpret_cast<int *>(run_count));
68
+ }
69
+
70
#if !WP_ENABLE_CUDA
// Stub compiled only when WP_ENABLE_CUDA is zero/undefined: it keeps the
// symbol defined but performs no work and writes to none of its arguments.
void wp_runlength_encode_int_device(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count,
                                    int n)
{
    // Intentionally empty.
}
#endif
@@ -0,0 +1,61 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include "warp.h"
19
+ #include "cuda_util.h"
20
+
21
+ #define THRUST_IGNORE_CUB_VERSION_CHECK
22
+ #include <cub/device/device_run_length_encode.cuh>
23
+
24
// Run-length encodes `n` values with cub::DeviceRunLengthEncode::Encode,
// issued on the stream returned by wp_cuda_stream_get_current().
//
// values      : input sequence of n elements.
// run_values  : output, one entry per run.
// run_lengths : output, length of each run.
// run_count   : output, number of runs found.
template <typename T>
void runlength_encode_device(int n, const T *values, T *run_values, int *run_lengths, int *run_count)
{
    ContextGuard guard(wp_cuda_context_get_current());
    cudaStream_t stream = static_cast<cudaStream_t>(wp_cuda_stream_get_current());

    // First call with a null workspace only asks CUB how many scratch bytes it needs.
    size_t temp_bytes = 0;
    check_cuda(cub::DeviceRunLengthEncode::Encode(nullptr, temp_bytes,
                                                  values, run_values, run_lengths, run_count, n, stream));

    void *temp_storage = wp_alloc_device(WP_CURRENT_CONTEXT, temp_bytes);

    // Second call performs the actual encoding using the scratch workspace.
    check_cuda(cub::DeviceRunLengthEncode::Encode(temp_storage, temp_bytes,
                                                  values, run_values, run_lengths, run_count, n, stream));

    wp_free_device(WP_CURRENT_CONTEXT, temp_storage);
}
47
+
48
// C entry point for run-length encoding of int data via
// runlength_encode_device<int>. All buffers arrive as 64-bit addresses and
// are reinterpreted as int pointers before the call.
void wp_runlength_encode_int_device(
    uint64_t values,
    uint64_t run_values,
    uint64_t run_lengths,
    uint64_t run_count,
    int n)
{
    const int *in_values = reinterpret_cast<const int *>(values);
    int *out_values = reinterpret_cast<int *>(run_values);
    int *out_lengths = reinterpret_cast<int *>(run_lengths);
    int *out_count = reinterpret_cast<int *>(run_count);

    runlength_encode_device<int>(n, in_values, out_values, out_lengths, out_count);
}
warp/native/scan.cpp ADDED
@@ -0,0 +1,47 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include "scan.h"
19
+
20
+ #include <numeric>
21
+
22
+ template<typename T>
23
+ void scan_host(const T* values_in, T* values_out, int n, bool inclusive)
24
+ {
25
+ static void* scan_temp_memory = NULL;
26
+ static size_t scan_temp_max_size = 0;
27
+
28
+ // compute temporary memory required
29
+ if (!inclusive && n > scan_temp_max_size)
30
+ {
31
+ wp_free_host(scan_temp_memory);
32
+ scan_temp_memory = wp_alloc_host(sizeof(T) * n);
33
+ scan_temp_max_size = n;
34
+ }
35
+
36
+ T* result = inclusive ? values_out : static_cast<T*>(scan_temp_memory);
37
+
38
+ // scan
39
+ std::partial_sum(values_in, values_in + n, result);
40
+ if (!inclusive) {
41
+ values_out[0] = (T)0;
42
+ wp_memcpy_h2h(values_out + 1, result, sizeof(T) * (n - 1));
43
+ }
44
+ }
45
+
46
+ template void scan_host(const int*, int*, int, bool);
47
+ template void scan_host(const float*, float*, int, bool);
warp/native/scan.cu ADDED
@@ -0,0 +1,55 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include "warp.h"
19
+ #include "scan.h"
20
+
21
+ #include "cuda_util.h"
22
+
23
+ #define THRUST_IGNORE_CUB_VERSION_CHECK
24
+
25
+ #include <cub/device/device_scan.cuh>
26
+
27
// Computes a prefix sum of `n` values with cub::DeviceScan, issued on the
// stream returned by wp_cuda_stream_get_current().
//
// values_in  : input sequence of n elements.
// values_out : output sequence of n elements.
// n          : element count.
// inclusive  : selects cub::DeviceScan::InclusiveSum (true) or
//              cub::DeviceScan::ExclusiveSum (false).
template<typename T>
void scan_device(const T* values_in, T* values_out, int n, bool inclusive)
{
    ContextGuard guard(wp_cuda_context_get_current());

    cudaStream_t stream = static_cast<cudaStream_t>(wp_cuda_stream_get_current());

    // compute temporary memory required
    // Initialized so that a failed size query cannot hand an indeterminate
    // byte count to the allocator below. The query is issued with the same
    // stream argument as the execution call for consistency.
    size_t scan_temp_size = 0;
    if (inclusive) {
        check_cuda(cub::DeviceScan::InclusiveSum(NULL, scan_temp_size, values_in, values_out, n, stream));
    } else {
        check_cuda(cub::DeviceScan::ExclusiveSum(NULL, scan_temp_size, values_in, values_out, n, stream));
    }

    void* temp_buffer = wp_alloc_device(WP_CURRENT_CONTEXT, scan_temp_size);

    // scan
    if (inclusive) {
        check_cuda(cub::DeviceScan::InclusiveSum(temp_buffer, scan_temp_size, values_in, values_out, n, stream));
    } else {
        check_cuda(cub::DeviceScan::ExclusiveSum(temp_buffer, scan_temp_size, values_in, values_out, n, stream));
    }

    wp_free_device(WP_CURRENT_CONTEXT, temp_buffer);
}
53
+
54
// Explicit instantiations of scan_device for the element types defined in
// this translation unit: int and float.
template void scan_device<int>(const int* values_in, int* values_out, int n, bool inclusive);
template void scan_device<float>(const float* values_in, float* values_out, int n, bool inclusive);
warp/native/scan.h ADDED
@@ -0,0 +1,23 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #pragma once
19
+
20
// Prefix sum over n values read from values_in and written to values_out.
// `inclusive` defaults to true; pass false for an exclusive scan.
template <class T>
void scan_host(const T* values_in,
               T* values_out,
               int n,
               bool inclusive = true);
22
// Prefix sum over n values read from values_in and written to values_out.
// `inclusive` defaults to true; pass false for an exclusive scan.
template <class T>
void scan_device(const T* values_in,
                 T* values_out,
                 int n,
                 bool inclusive = true);