warp-lang 1.9.0-py3-none-win_amd64.whl → 1.10.0rc2-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -921,7 +921,7 @@ template <typename K, typename V, typename KeyToUint>
 void radix_sort_pairs_cpu_core(K* keys, K* aux_keys, V* values, V* aux_values, int n)
 {
     KeyToUint converter;
-    static unsigned int tables[2][1 << 16];
+    unsigned int tables[2][1 << 16];
     memset(tables, 0, sizeof(tables));

     // build histograms
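The change above drops the `static` qualifier from the scratch histograms in radix_sort_pairs_cpu_core. A plausible motivation (the diff itself does not say) is re-entrancy: a function-local static buffer is one shared object, so two host threads sorting concurrently would clobber each other's tables; the trade-off is roughly 512 KiB of extra stack per call (2 × 65536 unsigned ints). The standalone sketch below is illustrative only (not Warp code, names are made up) and shows why the static variant is unsafe under concurrent calls.

#include <cstdio>
#include <thread>
#include <vector>

// Toggle 'static' on the buffer to reproduce the problem: with 'static' every call
// shares one histogram (data race, corrupted counts); without it each call gets its own.
int count_even(const std::vector<int>& v)
{
    /* static */ int histogram[2] = {0, 0};
    for (int x : v)
        histogram[x & 1]++;
    return histogram[0];
}

int main()
{
    std::vector<int> a(1000000, 2);  // all even
    std::vector<int> b(1000000, 3);  // all odd

    int even_a = 0, even_b = 0;
    std::thread ta([&] { even_a = count_even(a); });
    std::thread tb([&] { even_b = count_even(b); });
    ta.join();
    tb.join();

    // with the per-call buffer this reliably prints 1000000 and 0
    printf("even in a: %d, even in b: %d\n", even_a, even_b);
    return 0;
}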
warp/native/tile_reduce.h CHANGED
@@ -19,6 +19,12 @@

 #include "tile.h"

+#ifdef __clang__
+// disable warnings related to C++17 extensions on CPU JIT builds
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wc++17-extensions"
+#endif // __clang__
+
 #define WP_TILE_WARP_SIZE 32

 namespace wp
@@ -76,7 +82,7 @@ inline CUDA_CALLABLE T warp_shuffle_down(T val, int offset, int mask)
     return output;
 }

-// Vector overload
+// vector overload
 template <unsigned Length, typename T>
 inline CUDA_CALLABLE wp::vec_t<Length, T> warp_shuffle_down(wp::vec_t<Length, T> val, int offset, int mask)
 {
@@ -88,7 +94,7 @@ inline CUDA_CALLABLE wp::vec_t<Length, T> warp_shuffle_down(wp::vec_t<Length, T>
     return result;
 }

-// Matrix overload
+// matrix overload
 template <unsigned Rows, unsigned Cols, typename T>
 inline CUDA_CALLABLE wp::mat_t<Rows, Cols, T> warp_shuffle_down(wp::mat_t<Rows, Cols, T> val, int offset, int mask)
 {
@@ -117,7 +123,7 @@ inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
     }
     else
     {
-        // handle partial warp case
+        // handle partial warp case - works for contiguous masks
         for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
         {
             T shfl_val = warp_shuffle_down(sum, offset, mask);
@@ -175,6 +181,51 @@ inline CUDA_CALLABLE ValueAndIndex<T> warp_reduce_tracked(T val, int idx, Op f,
     return result;
 }

+// combines per-thread reduction results across warps and the entire block
+// assumes each thread has already reduced its local data to thread_sum
+// returns the block-wide reduced value (only valid in thread 0)
+template <typename T, typename Op>
+inline CUDA_CALLABLE T block_combine_thread_results(T thread_sum, bool thread_has_data, Op f,
+                                                    T* partials, int& active_warps)
+{
+    constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
+    const int warp_index = threadIdx.x / WP_TILE_WARP_SIZE;
+    const int lane_index = threadIdx.x % WP_TILE_WARP_SIZE;
+
+    // determine which threads have data
+    unsigned int mask = __ballot_sync(0xFFFFFFFF, thread_has_data);
+    bool warp_is_active = mask != 0;
+
+    // warp reduction
+    T warp_sum;
+    if (thread_has_data)
+        warp_sum = warp_reduce(thread_sum, f, mask);
+
+    // lane 0 of each active warp writes to shared memory and increments counter
+    if (lane_index == 0 && warp_is_active)
+    {
+        partials[warp_index] = warp_sum;
+        atomicAdd(&active_warps, 1);
+    }
+
+    // sync to ensure all warps have written their partials
+    WP_TILE_SYNC();
+
+    // thread 0 performs final reduction across active warps
+    T block_sum;
+    if (threadIdx.x == 0)
+    {
+        block_sum = partials[0];
+
+        for (int w = 1; w < active_warps; ++w)
+        {
+            block_sum = f(block_sum, partials[w]);
+        }
+    }
+
+    return block_sum;
+}
+
 // non-axis version which computes sum
 // across the entire tile using the whole block
 template <typename Tile, typename Op>
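The new block_combine_thread_results helper factors out the second and third stages of the block reduction (a shuffle-based warp reduction, then a final pass over per-warp partials by thread 0) so that both tile_reduce_impl and the new axis reduction below can reuse it. For readers unfamiliar with the primitives involved, the following standalone CUDA sketch is illustrative only (not Warp code; it sidesteps partial-mask handling by padding out-of-range lanes with the identity) and shows the same shuffle pattern in isolation.

#include <cstdio>
#include <cuda_runtime.h>

// Sum n floats with a shuffle-based warp reduction plus one atomicAdd per warp.
// Assumes the block size is a multiple of 32 so the full-warp mask is always valid.
__global__ void warp_sum_example(const float* in, float* out, int n)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    // out-of-range lanes contribute the identity (0) so the whole warp participates
    float v = (tid < n) ? in[tid] : 0.0f;

    // shuffle-down reduction: after the loop, lane 0 holds the warp's partial sum
    for (int offset = 16; offset > 0; offset /= 2)
        v += __shfl_down_sync(0xFFFFFFFF, v, offset);

    if ((threadIdx.x & 31) == 0)
        atomicAdd(out, v);
}

int main()
{
    const int n = 1000;
    float *in, *out;
    cudaMallocManaged(&in, n * sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < n; ++i) in[i] = 1.0f;
    *out = 0.0f;

    warp_sum_example<<<(n + 255) / 256, 256>>>(in, out, n);
    cudaDeviceSynchronize();
    printf("sum = %.0f (expected %d)\n", *out, n);  // prints 1000

    cudaFree(in);
    cudaFree(out);
    return 0;
}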
@@ -185,15 +236,14 @@ auto tile_reduce_impl(Op f, Tile& t)
     auto input = t.copy_to_register();
     auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();

-    const int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
-    const int warp_index = threadIdx.x/WP_TILE_WARP_SIZE;
-    const int lane_index = threadIdx.x%WP_TILE_WARP_SIZE;
-
-    T thread_sum = input.data[0];
+    constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;

     using Layout = typename decltype(input)::Layout;

-    // thread reduction
+    // step 1: each thread reduces its own registers locally
+    T thread_sum = input.data[0];
+    bool thread_has_data = Layout::valid(Layout::linear_from_register(0));
+
     WP_PRAGMA_UNROLL
     for (int i=1; i < Layout::NumRegs; ++i)
     {
@@ -204,48 +254,190 @@ auto tile_reduce_impl(Op f, Tile& t)
         thread_sum = f(thread_sum, input.data[i]);
     }

-    // ensure that only threads with at least one valid item participate in the reduction
-    unsigned int mask = __ballot_sync(__activemask(), Layout::valid(Layout::linear_from_register(0)));
-    bool warp_is_active = mask != 0;
-
-    // warp reduction
-    T warp_sum = warp_reduce(thread_sum, f, mask);
-
-    // fixed size scratch pad for partial results in shared memory
-    WP_TILE_SHARED T partials[warp_count];
-
-    // count of active warps
-    WP_TILE_SHARED int active_warps;
+    // shared memory for cross-warp reduction
+    __shared__ T partials[warp_count];
+    __shared__ int active_warps;
+
     if (threadIdx.x == 0)
         active_warps = 0;

-    // ensure active_warps is initialized
     WP_TILE_SYNC();

-    if (lane_index == 0 && warp_is_active)
-    {
-        partials[warp_index] = warp_sum;
-        atomicAdd(&active_warps, 1);
-    }
+    // step 2-3: combine thread results across warps and block
+    T block_sum = block_combine_thread_results(thread_sum, thread_has_data, f, partials, active_warps);

-    // ensure partials are ready
-    WP_TILE_SYNC();
-
-    // reduce across block, todo: use warp_reduce() here
     if (threadIdx.x == 0)
-    {
-        T block_sum = partials[0];
-
-        WP_PRAGMA_UNROLL
-        for (int i=1; i < active_warps; ++i)
-            block_sum = f(block_sum, partials[i]);
-
         output.data[0] = block_sum;
-    }

     return output;
 }

+template <int Axis, typename Op, typename Tile>
+auto tile_reduce_axis_impl(Op f, Tile& t)
+{
+    using T = typename Tile::Type;
+    using InputShape = typename Tile::Layout::Shape;
+    using OutputShape = typename tile_shape_remove_dim<Axis, InputShape>::type;
+
+    constexpr int reduce_dim_size = InputShape::dim(Axis);
+    constexpr int output_size = OutputShape::size();
+
+    // special case: 1D input delegates to block-wide tile_reduce_impl for optimal performance
+    if constexpr (InputShape::N == 1)
+    {
+        return tile_reduce_impl(f, t);
+    }
+
+    // shared memory buffer for the output (used by all tiers)
+    __shared__ T output_buffer[output_size];
+
+    // create output layout for coordinate conversion (used by all tiers)
+    using OutputLayout = tile_layout_strided_t<OutputShape>;
+
+    if constexpr (reduce_dim_size <= 32)
+    {
+        // Tier 1: Single thread per output element (optimal for small reductions)
+
+        // each thread processes output elements, performing reduction along the axis
+        for (int out_idx = WP_TILE_THREAD_IDX; out_idx < output_size; out_idx += WP_TILE_BLOCK_DIM)
+        {
+            // convert output linear index to output coordinates
+            auto out_coord = OutputLayout::coord_from_linear(out_idx);
+
+            // initialize accumulator with first element along the reduction axis
+            T accumulator = t.data(tile_coord_insert_axis<Axis>(out_coord, 0));
+
+            // reduce across the axis
+            for (int i = 1; i < reduce_dim_size; ++i)
+            {
+                accumulator = f(accumulator, t.data(tile_coord_insert_axis<Axis>(out_coord, i)));
+            }
+
+            // store to output buffer
+            output_buffer[out_idx] = accumulator;
+        }
+
+        // sync before reading output
+        WP_TILE_SYNC();
+    }
+    else if constexpr (reduce_dim_size <= 256)
+    {
+        // Tier 2: Warp-based reduction (one warp per output element)
+        constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
+        const int warp_index = threadIdx.x / WP_TILE_WARP_SIZE;
+        const int lane_index = threadIdx.x % WP_TILE_WARP_SIZE;
+
+        constexpr int chunks_per_slice = (reduce_dim_size + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
+
+        // shared memory: one accumulator per warp
+        __shared__ T warp_partials[warp_count];
+
+        // each warp processes output slices
+        for (int out_idx = warp_index; out_idx < output_size; out_idx += warp_count)
+        {
+            auto out_coord = OutputLayout::coord_from_linear(out_idx);
+
+            // process the reduction axis in chunks of 32
+            for (int chunk = 0; chunk < chunks_per_slice; ++chunk)
+            {
+                int axis_idx = chunk * WP_TILE_WARP_SIZE + lane_index;
+                bool valid = axis_idx < reduce_dim_size;
+
+                T val;
+                if (valid)
+                {
+                    auto in_coord = tile_coord_insert_axis<Axis>(out_coord, axis_idx);
+                    val = t.data(in_coord);
+                }
+
+                // warp reduce this chunk (only valid lanes participate)
+                unsigned int mask = __ballot_sync(0xFFFFFFFF, valid);
+                T chunk_result = warp_reduce(val, f, mask);
+
+                // lane 0 accumulates the chunk result
+                if (lane_index == 0)
+                {
+                    if (chunk == 0)
+                        warp_partials[warp_index] = chunk_result;
+                    else
+                        warp_partials[warp_index] = f(warp_partials[warp_index], chunk_result);
+                }
+            }
+
+            // lane 0 writes final result for this output element
+            if (lane_index == 0)
+                output_buffer[out_idx] = warp_partials[warp_index];
+        }
+
+        // sync before reading output
+        WP_TILE_SYNC();
+    }
+    else
+    {
+        // Tier 3: Block-level reduction (entire block collaborates on each output element)
+        constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
+
+        // shared memory for cross-warp reduction
+        __shared__ T partials[warp_count];
+        __shared__ int active_warps;
+
+        // process each output element sequentially with full block cooperation
+        for (int out_idx = 0; out_idx < output_size; ++out_idx)
+        {
+            auto out_coord = OutputLayout::coord_from_linear(out_idx);
+
+            // step 1: each thread reduces its strided subset of the slice locally
+            bool thread_has_data = threadIdx.x < reduce_dim_size;
+            T thread_sum;
+
+            if (thread_has_data)
+            {
+                // initialize with first element
+                auto in_coord = tile_coord_insert_axis<Axis>(out_coord, threadIdx.x);
+                thread_sum = t.data(in_coord);
+
+                // reduce remaining elements with stride
+                for (int i = threadIdx.x + WP_TILE_BLOCK_DIM; i < reduce_dim_size; i += WP_TILE_BLOCK_DIM)
+                {
+                    auto in_coord = tile_coord_insert_axis<Axis>(out_coord, i);
+                    T val = t.data(in_coord);
+                    thread_sum = f(thread_sum, val);
+                }
+            }
+
+            // initialize active warp counter
+            if (threadIdx.x == 0)
+                active_warps = 0;
+
+            WP_TILE_SYNC();
+
+            // step 2-3: combine thread results across warps and block
+            T block_sum = block_combine_thread_results(thread_sum, thread_has_data, f, partials, active_warps);
+
+            if (threadIdx.x == 0)
+                output_buffer[out_idx] = block_sum;
+
+            // sync before next output element
+            WP_TILE_SYNC();
+        }
+    }
+
+    // copy from shared memory buffer to register tile (common to all tiers)
+    auto output = tile_register_t<T, tile_layout_register_t<OutputShape>>();
+    using OutputRegLayout = typename decltype(output)::Layout;
+
+    WP_PRAGMA_UNROLL
+    for (int i = 0; i < OutputRegLayout::NumRegs; ++i)
+    {
+        int linear = OutputRegLayout::linear_from_register(i);
+        if (!OutputRegLayout::valid(linear))
+            break;
+
+        output.data[i] = output_buffer[linear];
+    }
+
+    return output;
+}

 // non-axis version which computes sum
 // across the entire tile using the whole block
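The CUDA implementation above picks one of three strategies from the size of the reduced dimension: with reduce_dim_size <= 32 a single thread handles each output element, with reduce_dim_size <= 256 a whole warp cooperates on each output slice in 32-wide chunks, and beyond that the entire block works on one output element at a time, reusing block_combine_thread_results. As a worked example, summing a tile of shape (8, 1024) over axis 1 reduces 1024 elements per output and takes the block-level path, while summing the same tile over axis 0 reduces only 8 elements per output and takes the one-thread-per-output path.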
@@ -286,11 +478,11 @@ auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
     ValueAndIndex<T> warp_sum = warp_reduce_tracked(thread_sum, champion_index, f, track, mask);

     // fixed size scratch pad for partial results in shared memory
-    WP_TILE_SHARED T partials[warp_count];
-    WP_TILE_SHARED int partials_idx[warp_count];
+    __shared__ T partials[warp_count];
+    __shared__ int partials_idx[warp_count];

     // count of active warps
-    WP_TILE_SHARED int active_warps;
+    __shared__ int active_warps;
     if (threadIdx.x == 0)
         active_warps = 0;

@@ -356,6 +548,65 @@ auto tile_reduce_impl(Op f, Tile& t)
     return output;
 }

+template <int Axis, typename Op, typename Tile>
+auto tile_reduce_axis_impl(Op f, Tile& t)
+{
+    using T = typename Tile::Type;
+    using InputShape = typename Tile::Layout::Shape;
+    using OutputShape = typename tile_shape_remove_dim<Axis, InputShape>::type;
+
+    constexpr int reduce_dim_size = InputShape::dim(Axis);
+
+    // CPU version - work directly with register tiles, no thread coordination needed
+    auto input = t.copy_to_register();
+    auto output = tile_register_t<T, tile_layout_register_t<OutputShape>>();
+    using OutputLayout = typename decltype(output)::Layout;
+
+    // iterate through each output element and reduce along the axis
+    constexpr int output_size = OutputShape::size();
+    for (int out_idx = 0; out_idx < output_size; ++out_idx)
+    {
+        T accumulator;
+
+        // special case for 1D input (reduces to single value)
+        if constexpr (InputShape::N == 1)
+        {
+            accumulator = input.data[0];
+            for (int i = 1; i < reduce_dim_size; ++i)
+            {
+                // input is in registers, linear access
+                accumulator = f(accumulator, input.data[i]);
+            }
+        }
+        else
+        {
+            // multi-dimensional case
+            auto out_coord = OutputLayout::coord_from_linear(out_idx);
+
+            // get input coordinates by inserting axis values
+            auto coord_0 = tile_coord_insert_axis<Axis>(out_coord, 0);
+            int input_linear_0 = tile_layout_register_t<InputShape>::linear_from_coord(coord_0);
+            int input_reg_0 = tile_layout_register_t<InputShape>::register_from_linear(input_linear_0);
+            accumulator = input.data[input_reg_0];
+
+            // reduce across the axis
+            for (int i = 1; i < reduce_dim_size; ++i)
+            {
+                auto coord_i = tile_coord_insert_axis<Axis>(out_coord, i);
+                int input_linear_i = tile_layout_register_t<InputShape>::linear_from_coord(coord_i);
+                int input_reg_i = tile_layout_register_t<InputShape>::register_from_linear(input_linear_i);
+                accumulator = f(accumulator, input.data[input_reg_i]);
+            }
+        }
+
+        // store to output register
+        int output_reg = OutputLayout::register_from_linear(out_idx);
+        output.data[output_reg] = accumulator;
+    }
+
+    return output;
+}
+
 template <typename Tile, typename Op, typename OpTrack>
 auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
 {
@@ -391,15 +642,25 @@ inline void adj_tile_reduce_impl()
     // todo: general purpose reduction gradients not implemented
 }

+inline void adj_tile_reduce_axis_impl()
+{
+    // todo: axis-specific reduction gradients not implemented
+}
+
 // entry point for Python code-gen, wraps op in a lambda to perform overload resolution
 #define tile_reduce(op, t) tile_reduce_impl([](auto x, auto y) { return op(x, y);}, t)
-#define adj_tile_reduce(op, a, adj_op, adj_a, adj_ret) adj_tile_reduce_impl()
+#define adj_tile_reduce(op, t, adj_op, adj_t, adj_ret) adj_tile_reduce_impl()

 #define tile_arg_reduce(op, opTrack, t) tile_arg_reduce_impl([](auto x, auto y) { return op(x, y);}, [](auto a, auto b, auto c, auto d) { return opTrack(a, b, c, d); }, t)
-#define adj_tile_arg_reduce(op, a, adj_op, adj_a, adj_ret) adj_tile_arg_reduce_impl()
+#define adj_tile_arg_reduce(op, t, adj_op, adj_t, adj_ret) adj_tile_arg_reduce_impl()
+
+// axis-specific reduction entry points
+#define tile_reduce_axis(op, t, axis) tile_reduce_axis_impl<axis>([](auto x, auto y) { return op(x, y);}, t)
+#define adj_tile_reduce_axis(op, t, axis, adj_op, adj_t, adj_axis, adj_ret) adj_tile_reduce_axis_impl()

 // convenience methods for specific reductions

+// whole-tile sum
 template <typename Tile>
 auto tile_sum(Tile& t)
 {
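These macros are the entry points emitted by Warp's Python code generator. As a rough, hand-written illustration of how they compose (hypothetical usage, not actual generated code; my_tile is a placeholder name):

// reduce a 2D tile over its second axis with an arbitrary binary op
auto row_mins = tile_reduce_axis(min, my_tile, 1);

// the additive case goes through the tile_sum<Axis>() specialization defined below,
// which also provides a hand-written adjoint
auto row_sums = tile_sum<1>(my_tile);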
@@ -418,7 +679,7 @@ void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
     T scratch = adj_reg.data[0];
 #else
     // broadcast incoming adjoint to block
-    WP_TILE_SHARED T scratch;
+    __shared__ T scratch;
     if (WP_TILE_THREAD_IDX == 0)
         scratch = adj_reg.data[0];

@@ -434,6 +695,90 @@ void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
     adj_t.grad_add(adj_ret_reg);
 }

+// axis-specific sum
+template <int Axis, typename Tile>
+auto tile_sum(Tile& t)
+{
+    return tile_reduce_axis_impl<Axis>([](auto x, auto y) { return add(x, y); }, t);
+}
+
+// special case adjoint for axis-specific summation
+template<int Axis, typename Tile, typename AdjTile>
+void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    using InputShape = typename Tile::Layout::Shape;
+
+    if constexpr (InputShape::N == 1)
+    {
+        // 1D -> scalar case: broadcast scalar to 1D
+        auto broadcasted = tile_broadcast<InputShape::dim(0), 0>(adj_ret);
+        tile_add_inplace(adj_t, broadcasted);
+    }
+    else if constexpr (InputShape::N == 2)
+    {
+        if constexpr (Axis == 0)
+        {
+            // broadcast from (D1,) to (D0, D1) with strides (0, 1)
+            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), 0, 1>(adj_ret);
+            tile_add_inplace(adj_t, broadcasted);
+        }
+        else // Axis == 1
+        {
+            // broadcast from (D0,) to (D0, D1) with strides (1, 0)
+            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), 1, 0>(adj_ret);
+            tile_add_inplace(adj_t, broadcasted);
+        }
+    }
+    else if constexpr (InputShape::N == 3)
+    {
+        if constexpr (Axis == 0)
+        {
+            // broadcast from (D1, D2) to (D0, D1, D2) with strides (0, D2, 1)
+            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), 0, InputShape::dim(2), 1>(adj_ret);
+            tile_add_inplace(adj_t, broadcasted);
+        }
+        else if constexpr (Axis == 1)
+        {
+            // broadcast from (D0, D2) to (D0, D1, D2) with strides (D2, 0, 1)
+            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(2), 0, 1>(adj_ret);
+            tile_add_inplace(adj_t, broadcasted);
+        }
+        else // Axis == 2
+        {
+            // broadcast from (D0, D1) to (D0, D1, D2) with strides (D1, 1, 0)
+            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(1), 1, 0>(adj_ret);
+            tile_add_inplace(adj_t, broadcasted);
+        }
+    }
+    else if constexpr (InputShape::N == 4)
+    {
+        if constexpr (Axis == 0)
+        {
+            // broadcast from (D1, D2, D3) to (D0, D1, D2, D3) with strides (0, D2*D3, D3, 1)
+            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), 0, InputShape::dim(2)*InputShape::dim(3), InputShape::dim(3), 1>(adj_ret);
+            tile_add_inplace(adj_t, broadcasted);
+        }
+        else if constexpr (Axis == 1)
+        {
+            // broadcast from (D0, D2, D3) to (D0, D1, D2, D3) with strides (D2*D3, 0, D3, 1)
+            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), InputShape::dim(2)*InputShape::dim(3), 0, InputShape::dim(3), 1>(adj_ret);
+            tile_add_inplace(adj_t, broadcasted);
+        }
+        else if constexpr (Axis == 2)
+        {
+            // broadcast from (D0, D1, D3) to (D0, D1, D2, D3) with strides (D1*D3, D3, 0, 1)
+            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), InputShape::dim(1)*InputShape::dim(3), InputShape::dim(3), 0, 1>(adj_ret);
+            tile_add_inplace(adj_t, broadcasted);
+        }
+        else // Axis == 3
+        {
+            // broadcast from (D0, D1, D2) to (D0, D1, D2, D3) with strides (D1*D2, D2, 1, 0)
+            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), InputShape::dim(1)*InputShape::dim(2), InputShape::dim(2), 1, 0>(adj_ret);
+            tile_add_inplace(adj_t, broadcasted);
+        }
+    }
+}
+
 template <typename Tile>
 auto tile_max(Tile& t)
 {
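The stride patterns above implement the textbook gradient of a sum over one axis: if y[j] = Σ_i x[i, j] (a 2D tile summed over Axis 0), then ∂L/∂x[i, j] = ∂L/∂y[j] for every i, so the incoming adjoint of shape (D1,) is simply replicated along the reduced axis. Re-reading it through a broadcast with strides (0, 1) achieves exactly that replication without materializing the larger tile.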
@@ -485,6 +830,9 @@ void adj_tile_argmin(Tile& t, Tile& adj_t, AdjTile& adj_ret)
 }


+} // namespace wp


-} // namespace wp
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
warp/native/tile_scan.h CHANGED
@@ -50,7 +50,7 @@ inline CUDA_CALLABLE T scan_warp_inclusive(int lane, T value)
 template<typename T>
 inline CUDA_CALLABLE T thread_block_scan_inclusive(int lane, int warp_index, int num_warps, T value)
 {
-    WP_TILE_SHARED T sums[1024 / WP_TILE_WARP_SIZE]; // 1024 is the maximum number of threads per block
+    __shared__ T sums[1024 / WP_TILE_WARP_SIZE]; // 1024 is the maximum number of threads per block

     value = scan_warp_inclusive(lane, value);

@@ -85,7 +85,7 @@ inline CUDA_CALLABLE void thread_block_scan(T* values, int num_elements)
     const int num_threads_in_block = blockDim.x;
     const int num_iterations = (num_elements + num_threads_in_block - 1) / num_threads_in_block;

-    WP_TILE_SHARED T offset;
+    __shared__ T offset;
     if (threadIdx.x == 0)
         offset = T(0);

@@ -124,7 +124,7 @@ inline CUDA_CALLABLE auto tile_scan_inclusive_impl(Tile& t)
     constexpr int num_elements_to_scan = Tile::Layout::Shape::size();

     // create a temporary shared tile to hold the input values
-    WP_TILE_SHARED T smem[num_elements_to_scan];
+    __shared__ T smem[num_elements_to_scan];
     tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape>, false> scratch(smem, nullptr);

     // copy input values to scratch space
@@ -147,7 +147,7 @@ inline CUDA_CALLABLE auto tile_scan_exclusive_impl(Tile& t)
     constexpr int num_elements_to_scan = Tile::Layout::Shape::size();

     // create a temporary shared tile to hold the input values
-    WP_TILE_SHARED T smem[num_elements_to_scan];
+    __shared__ T smem[num_elements_to_scan];
     tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape>, false> scratch(smem, nullptr);

     // copy input values to scratch space
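For reference, the two scan flavors named above differ only in whether an element is included in its own prefix: an inclusive scan of [3, 1, 4, 1] under addition yields [3, 4, 8, 9], while the exclusive scan of the same input yields [0, 3, 4, 8], i.e. each output shifted by one position and starting from the identity.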