warp-lang 0.9.0-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/reduce.cu ADDED
@@ -0,0 +1,348 @@
+
+ #include "cuda_util.h"
+ #include "warp.h"
+
+ #include "temp_buffer.h"
+
+ #define THRUST_IGNORE_CUB_VERSION_CHECK
+ #include <cub/device/device_reduce.cuh>
+ #include <cub/iterator/counting_input_iterator.cuh>
+
+ namespace
+ {
+
+ template <typename T>
+ __global__ void cwise_mult_kernel(int len, int stride_a, int stride_b, const T *a, const T *b, T *out)
+ {
+     int i = blockIdx.x * blockDim.x + threadIdx.x;
+     if (i >= len)
+         return;
+     out[i] = a[i * stride_a] * b[i * stride_b];
+ }
+
+ /// Custom iterator for allowing strided access with CUB
+ template <typename T> struct cub_strided_iterator
+ {
+     typedef cub_strided_iterator<T> self_type;
+     typedef std::ptrdiff_t difference_type;
+     typedef T value_type;
+     typedef T *pointer;
+     typedef T &reference;
+
+     typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
+
+     T *ptr = nullptr;
+     int stride = 1;
+
+     CUDA_CALLABLE self_type operator++(int)
+     {
+         return ++(self_type(*this));
+     }
+
+     CUDA_CALLABLE self_type &operator++()
+     {
+         ptr += stride;
+         return *this;
+     }
+
+     __host__ __device__ __forceinline__ reference operator*() const
+     {
+         return *ptr;
+     }
+
+     CUDA_CALLABLE self_type operator+(difference_type n) const
+     {
+         return self_type(*this) += n;
+     }
+
+     CUDA_CALLABLE self_type &operator+=(difference_type n)
+     {
+         ptr += n * stride;
+         return *this;
+     }
+
+     CUDA_CALLABLE self_type operator-(difference_type n) const
+     {
+         return self_type(*this) -= n;
+     }
+
+     CUDA_CALLABLE self_type &operator-=(difference_type n)
+     {
+         ptr -= n * stride;
+         return *this;
+     }
+
+     CUDA_CALLABLE difference_type operator-(const self_type &other) const
+     {
+         return (ptr - other.ptr) / stride;
+     }
+
+     CUDA_CALLABLE reference operator[](difference_type n) const
+     {
+         return *(ptr + n * stride);
+     }
+
+     CUDA_CALLABLE pointer operator->() const
+     {
+         return ptr;
+     }
+
+     CUDA_CALLABLE bool operator==(const self_type &rhs) const
+     {
+         return (ptr == rhs.ptr);
+     }
+
+     CUDA_CALLABLE bool operator!=(const self_type &rhs) const
+     {
+         return (ptr != rhs.ptr);
+     }
+ };
+
+ template <typename T> void array_sum_device(const T *ptr_a, T *ptr_out, int count, int byte_stride, int type_length)
+ {
+     assert((byte_stride % sizeof(T)) == 0);
+     const int stride = byte_stride / sizeof(T);
+
+     ContextGuard guard(cuda_context_get_current());
+     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
+
+     cub_strided_iterator<const T> ptr_strided{ptr_a, stride};
+
+     size_t buff_size = 0;
+     check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, ptr_strided, ptr_out, count, stream));
+     void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
+
+     for (int k = 0; k < type_length; ++k)
+     {
+         cub_strided_iterator<const T> ptr_strided{ptr_a + k, stride};
+         check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, ptr_strided, ptr_out + k, count, stream));
+     }
+
+     free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
+ }
+
+ template <typename T>
+ void array_sum_device_dispatch(const T *ptr_a, T *ptr_out, int count, int byte_stride, int type_length)
+ {
+     using vec2 = wp::vec_t<2, T>;
+     using vec3 = wp::vec_t<3, T>;
+     using vec4 = wp::vec_t<4, T>;
+
+     // specialized calls for common vector types
+
+     if ((type_length % 4) == 0 && (byte_stride % sizeof(vec4)) == 0)
+     {
+         return array_sum_device(reinterpret_cast<const vec4 *>(ptr_a), reinterpret_cast<vec4 *>(ptr_out), count,
+                                 byte_stride, type_length / 4);
+     }
+
+     if ((type_length % 3) == 0 && (byte_stride % sizeof(vec3)) == 0)
+     {
+         return array_sum_device(reinterpret_cast<const vec3 *>(ptr_a), reinterpret_cast<vec3 *>(ptr_out), count,
+                                 byte_stride, type_length / 3);
+     }
+
+     if ((type_length % 2) == 0 && (byte_stride % sizeof(vec2)) == 0)
+     {
+         return array_sum_device(reinterpret_cast<const vec2 *>(ptr_a), reinterpret_cast<vec2 *>(ptr_out), count,
+                                 byte_stride, type_length / 2);
+     }
+
+     return array_sum_device(ptr_a, ptr_out, count, byte_stride, type_length);
+ }
+
+ template <typename T> CUDA_CALLABLE T element_inner_product(const T &a, const T &b)
+ {
+     return a * b;
+ }
+
+ template <unsigned Length, typename T>
+ CUDA_CALLABLE T element_inner_product(const wp::vec_t<Length, T> &a, const wp::vec_t<Length, T> &b)
+ {
+     return wp::dot(a, b);
+ }
+
+ /// Custom iterator for allowing strided access with CUB
+ template <typename ElemT, typename ScalarT> struct cub_inner_product_iterator
+ {
+     typedef cub_inner_product_iterator<ElemT, ScalarT> self_type;
+     typedef std::ptrdiff_t difference_type;
+     typedef ScalarT value_type;
+     typedef ScalarT *pointer;
+     typedef ScalarT reference;
+
+     typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
+
+     const ElemT *ptr_a = nullptr;
+     const ElemT *ptr_b = nullptr;
+
+     int stride_a = 1;
+     int stride_b = 1;
+     int type_length = 1;
+
+     CUDA_CALLABLE self_type operator++(int)
+     {
+         return ++(self_type(*this));
+     }
+
+     CUDA_CALLABLE self_type &operator++()
+     {
+         ptr_a += stride_a;
+         ptr_b += stride_b;
+         return *this;
+     }
+
+     __host__ __device__ __forceinline__ reference operator*() const
+     {
+         return compute_value(0);
+     }
+
+     CUDA_CALLABLE self_type operator+(difference_type n) const
+     {
+         return self_type(*this) += n;
+     }
+
+     CUDA_CALLABLE self_type &operator+=(difference_type n)
+     {
+         ptr_a += n * stride_a;
+         ptr_b += n * stride_b;
+         return *this;
+     }
+
+     CUDA_CALLABLE self_type operator-(difference_type n) const
+     {
+         return self_type(*this) -= n;
+     }
+
+     CUDA_CALLABLE self_type &operator-=(difference_type n)
+     {
+         ptr_a -= n * stride_a;
+         ptr_b -= n * stride_b;
+         return *this;
+     }
+
+     CUDA_CALLABLE difference_type operator-(const self_type &other) const
+     {
+         return (ptr_a - other.ptr_a) / stride_a;
+     }
+
+     CUDA_CALLABLE reference operator[](difference_type n) const
+     {
+         return compute_value(n);
+     }
+
+     CUDA_CALLABLE bool operator==(const self_type &rhs) const
+     {
+         return (ptr_a == rhs.ptr_a);
+     }
+
+     CUDA_CALLABLE bool operator!=(const self_type &rhs) const
+     {
+         return (ptr_a != rhs.ptr_a);
+     }
+
+ private:
+     CUDA_CALLABLE ScalarT compute_value(difference_type n) const
+     {
+         ScalarT val(0);
+         const ElemT *a = ptr_a + n * stride_a;
+         const ElemT *b = ptr_b + n * stride_b;
+         for (int k = 0; k < type_length; ++k)
+         {
+             val += element_inner_product(a[k], b[k]);
+         }
+         return val;
+     }
+ };
+
+ template <typename ElemT, typename ScalarT>
+ void array_inner_device(const ElemT *ptr_a, const ElemT *ptr_b, ScalarT *ptr_out, int count, int byte_stride_a,
+                         int byte_stride_b, int type_length)
+ {
+     assert((byte_stride_a % sizeof(ElemT)) == 0);
+     assert((byte_stride_b % sizeof(ElemT)) == 0);
+     const int stride_a = byte_stride_a / sizeof(ElemT);
+     const int stride_b = byte_stride_b / sizeof(ElemT);
+
+     ContextGuard guard(cuda_context_get_current());
+     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
+
+     cub_inner_product_iterator<ElemT, ScalarT> inner_iterator{ptr_a, ptr_b, stride_a, stride_b, type_length};
+
+     size_t buff_size = 0;
+     check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, inner_iterator, ptr_out, count, stream));
+     void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
+
+     check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, inner_iterator, ptr_out, count, stream));
+
+     free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
+ }
+
+ template <typename T>
+ void array_inner_device_dispatch(const T *ptr_a, const T *ptr_b, T *ptr_out, int count, int byte_stride_a,
+                                  int byte_stride_b, int type_length)
+ {
+     using vec2 = wp::vec_t<2, T>;
+     using vec3 = wp::vec_t<3, T>;
+     using vec4 = wp::vec_t<4, T>;
+
+     // specialized calls for common vector types
+
+     if ((type_length % 4) == 0 && (byte_stride_a % sizeof(vec4)) == 0 && (byte_stride_b % sizeof(vec4)) == 0)
+     {
+         return array_inner_device(reinterpret_cast<const vec4 *>(ptr_a), reinterpret_cast<const vec4 *>(ptr_b), ptr_out,
+                                   count, byte_stride_a, byte_stride_b, type_length / 4);
+     }
+
+     if ((type_length % 3) == 0 && (byte_stride_a % sizeof(vec3)) == 0 && (byte_stride_b % sizeof(vec3)) == 0)
+     {
+         return array_inner_device(reinterpret_cast<const vec3 *>(ptr_a), reinterpret_cast<const vec3 *>(ptr_b), ptr_out,
+                                   count, byte_stride_a, byte_stride_b, type_length / 3);
+     }
+
+     if ((type_length % 2) == 0 && (byte_stride_a % sizeof(vec2)) == 0 && (byte_stride_b % sizeof(vec2)) == 0)
+     {
+         return array_inner_device(reinterpret_cast<const vec2 *>(ptr_a), reinterpret_cast<const vec2 *>(ptr_b), ptr_out,
+                                   count, byte_stride_a, byte_stride_b, type_length / 2);
+     }
+
+     return array_inner_device(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_length);
+ }
+
+ } // anonymous namespace
+
+ void array_inner_float_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
+                               int type_len)
+ {
+     void *context = cuda_context_get_current();
+
+     const float *ptr_a = (const float *)(a);
+     const float *ptr_b = (const float *)(b);
+     float *ptr_out = (float *)(out);
+
+     array_inner_device_dispatch(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_len);
+ }
+
+ void array_inner_double_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
+                                int type_len)
+ {
+     const double *ptr_a = (const double *)(a);
+     const double *ptr_b = (const double *)(b);
+     double *ptr_out = (double *)(out);
+
+     array_inner_device_dispatch(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_len);
+ }
+
+ void array_sum_float_device(uint64_t a, uint64_t out, int count, int byte_stride, int type_length)
+ {
+     const float *ptr_a = (const float *)(a);
+     float *ptr_out = (float *)(out);
+     array_sum_device_dispatch(ptr_a, ptr_out, count, byte_stride, type_length);
+ }
+
+ void array_sum_double_device(uint64_t a, uint64_t out, int count, int byte_stride, int type_length)
+ {
+     const double *ptr_a = (const double *)(a);
+     double *ptr_out = (double *)(out);
+     array_sum_device_dispatch(ptr_a, ptr_out, count, byte_stride, type_length);
+ }
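
These CUDA entry points back the array reductions exposed in warp/utils.py, also new in this release (see test_array_reduce.py in the file list). A minimal usage sketch from Python — the wrapper names array_sum and array_inner and their signatures are inferred from the file list, not shown in this diff, so treat them as assumptions:

    # Hedged sketch: array_sum/array_inner are assumed from warp/utils.py
    # and test_array_reduce.py; exact signatures are not shown in this diff.
    import numpy as np
    import warp as wp
    from warp.utils import array_sum, array_inner

    wp.init()

    a = wp.array(np.arange(16, dtype=np.float32), device="cuda:0")
    b = wp.array(np.ones(16, dtype=np.float32), device="cuda:0")

    total = array_sum(a)       # dispatches to array_sum_float_device above
    inner = array_inner(a, b)  # dispatches to array_inner_float_device above

Note the dispatch helpers: when the element count and byte stride are divisible by a vec2/vec3/vec4 size, the data is reinterpreted as Warp vector types (whose addition is elementwise), cutting the number of per-component cub::DeviceReduce::Sum passes.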
warp/native/runlength_encode.cpp ADDED
@@ -0,0 +1,62 @@
+ #include "warp.h"
+
+ #include <cstdint>
+
+ template <typename T>
+ void runlength_encode_host(int n,
+                            const T *values,
+                            T *run_values,
+                            int *run_lengths,
+                            int *run_count)
+ {
+     if (n == 0)
+     {
+         *run_count = 0;
+         return;
+     }
+
+     const T *end = values + n;
+
+     *run_count = 1;
+     *run_lengths = 1;
+     *run_values = *values;
+
+     while (++values != end)
+     {
+         if (*values == *run_values)
+         {
+             ++*run_lengths;
+         }
+         else
+         {
+             ++*run_count;
+             *(++run_lengths) = 1;
+             *(++run_values) = *values;
+         }
+     }
+ }
+
+ void runlength_encode_int_host(
+     uint64_t values,
+     uint64_t run_values,
+     uint64_t run_lengths,
+     uint64_t run_count,
+     int n)
+ {
+     runlength_encode_host<int>(n,
+                                reinterpret_cast<const int *>(values),
+                                reinterpret_cast<int *>(run_values),
+                                reinterpret_cast<int *>(run_lengths),
+                                reinterpret_cast<int *>(run_count));
+ }
+
+ #if !WP_ENABLE_CUDA
+ void runlength_encode_int_device(
+     uint64_t values,
+     uint64_t run_values,
+     uint64_t run_lengths,
+     uint64_t run_count,
+     int n)
+ {
+ }
+ #endif
warp/native/runlength_encode.cu ADDED
@@ -0,0 +1,46 @@
+
+
+ #include "warp.h"
+ #include "cuda_util.h"
+
+ #define THRUST_IGNORE_CUB_VERSION_CHECK
+ #include <cub/device/device_run_length_encode.cuh>
+
+ template <typename T>
+ void runlength_encode_device(int n,
+                              const T *values,
+                              T *run_values,
+                              int *run_lengths,
+                              int *run_count)
+ {
+     ContextGuard guard(cuda_context_get_current());
+     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
+
+     size_t buff_size = 0;
+     check_cuda(cub::DeviceRunLengthEncode::Encode(
+         nullptr, buff_size, values, run_values, run_lengths, run_count,
+         n, stream));
+
+     void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
+
+     check_cuda(cub::DeviceRunLengthEncode::Encode(
+         temp_buffer, buff_size, values, run_values, run_lengths, run_count,
+         n, stream));
+
+     free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
+ }
+
+ void runlength_encode_int_device(
+     uint64_t values,
+     uint64_t run_values,
+     uint64_t run_lengths,
+     uint64_t run_count,
+     int n)
+ {
+     return runlength_encode_device<int>(
+         n,
+         reinterpret_cast<const int *>(values),
+         reinterpret_cast<int *>(run_values),
+         reinterpret_cast<int *>(run_lengths),
+         reinterpret_cast<int *>(run_count));
+ }
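
The host and device variants share one Python entry point, exercised by the new test_runlength_encode.py. A minimal sketch, assuming a warp.utils.runlength_encode(values, run_values, run_lengths) helper that returns the number of runs — the Python-side signature is an assumption, as this diff only shows the native code:

    # Hedged sketch: runlength_encode is assumed from test_runlength_encode.py.
    # Output arrays must hold one entry per run (worst case: n runs).
    import numpy as np
    import warp as wp
    from warp.utils import runlength_encode

    wp.init()

    values = wp.array(np.array([1, 1, 2, 2, 2, 3], dtype=np.int32))
    run_values = wp.empty(values.shape[0], dtype=wp.int32)
    run_lengths = wp.empty(values.shape[0], dtype=wp.int32)

    run_count = runlength_encode(values, run_values, run_lengths)
    # expected: run_count == 3,
    #           run_values[:3]  == [1, 2, 3],
    #           run_lengths[:3] == [2, 3, 1]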
warp/native/scan.cu CHANGED
@@ -3,35 +3,33 @@
  
  #define THRUST_IGNORE_CUB_VERSION_CHECK
  
- #include <cub/cub.cuh>
+ #include <cub/device/device_scan.cuh>
  
  template<typename T>
  void scan_device(const T* values_in, T* values_out, int n, bool inclusive)
  {
-     static void* scan_temp_memory = NULL;
-     static size_t scan_temp_max_size = 0;
+     ContextGuard guard(cuda_context_get_current());
+
+     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
  
      // compute temporary memory required
      size_t scan_temp_size;
      if (inclusive) {
-         cub::DeviceScan::InclusiveSum(NULL, scan_temp_size, values_in, values_out, n);
+         check_cuda(cub::DeviceScan::InclusiveSum(NULL, scan_temp_size, values_in, values_out, n));
      } else {
-         cub::DeviceScan::ExclusiveSum(NULL, scan_temp_size, values_in, values_out, n);
+         check_cuda(cub::DeviceScan::ExclusiveSum(NULL, scan_temp_size, values_in, values_out, n));
      }
  
-     if (scan_temp_size > scan_temp_max_size)
-     {
-         free_device(WP_CURRENT_CONTEXT, scan_temp_memory);
-         scan_temp_memory = alloc_device(WP_CURRENT_CONTEXT, scan_temp_size);
-         scan_temp_max_size = scan_temp_size;
-     }
+     void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, scan_temp_size);
  
      // scan
      if (inclusive) {
-         cub::DeviceScan::InclusiveSum(scan_temp_memory, scan_temp_size, values_in, values_out, n, (cudaStream_t)cuda_stream_get_current());
+         check_cuda(cub::DeviceScan::InclusiveSum(temp_buffer, scan_temp_size, values_in, values_out, n, stream));
      } else {
-         cub::DeviceScan::ExclusiveSum(scan_temp_memory, scan_temp_size, values_in, values_out, n, (cudaStream_t)cuda_stream_get_current());
+         check_cuda(cub::DeviceScan::ExclusiveSum(temp_buffer, scan_temp_size, values_in, values_out, n, stream));
      }
+
+     free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
  }
  
  template void scan_device(const int*, int*, int, bool);
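
This change drops the function-local static buffer — which was never freed, only ever grew, and was shared unsafely across CUDA contexts and streams — in favor of a per-call temporary from the new temp_buffer.h allocator, and it routes both CUB calls through the current stream under a ContextGuard with check_cuda error checking. From Python the scan is reachable through warp.utils.array_scan; a minimal sketch, with the (in_array, out_array, inclusive) signature assumed from its use elsewhere in this release:

    # Hedged sketch: array_scan signature assumed, not shown in this diff.
    import numpy as np
    import warp as wp
    from warp.utils import array_scan

    wp.init()

    values = wp.array(np.ones(8, dtype=np.int32), device="cuda:0")
    out = wp.empty_like(values)

    array_scan(values, out, inclusive=True)
    # out.numpy() -> [1, 2, 3, 4, 5, 6, 7, 8]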
warp/native/scan.h CHANGED
@@ -4,3 +4,4 @@ template<typename T>
  void scan_host(const T* values_in, T* values_out, int n, bool inclusive = true);
  template<typename T>
  void scan_device(const T* values_in, T* values_out, int n, bool inclusive = true);
+