warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,838 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #pragma once
19
+
20
+ #include "tile.h"
21
+
22
+ #ifdef __clang__
23
+ // disable warnings related to C++17 extensions on CPU JIT builds
24
+ #pragma clang diagnostic push
25
+ #pragma clang diagnostic ignored "-Wc++17-extensions"
26
+ #endif // __clang__
27
+
28
+ #define WP_TILE_WARP_SIZE 32
29
+
30
+ namespace wp
31
+ {
32
+
33
+
34
// Index-tracking helper for arg-max style reductions.
// Compares a challenger (current_value / current_index) against the running champion
// and returns the index that should survive. The challenger wins only on a strictly
// greater value, so ties (and comparisons that evaluate false) keep the champion.
template <typename T>
int argmax_tracker(T champion_value, T current_value, int champion_index, int current_index)
{
    if (current_value > champion_value)
        return current_index;

    return champion_index;
}
39
+
40
// Index-tracking helper for arg-min style reductions.
// Compares a challenger (current_value / current_index) against the running champion
// and returns the index that should survive. The challenger wins only on a strictly
// smaller value, so ties (and comparisons that evaluate false) keep the champion.
template <typename T>
int argmin_tracker(T champion_value, T current_value, int champion_index, int current_index)
{
    if (current_value < champion_value)
        return current_index;

    return champion_index;
}
45
+
46
+
47
+ #if defined(__CUDA_ARCH__)
48
+
49
+ template <typename T>
50
+ inline CUDA_CALLABLE T warp_shuffle_down(T val, int offset, int mask)
51
+ {
52
+ typedef unsigned int Word;
53
+
54
+ union
55
+ {
56
+ T output;
57
+ Word output_storage;
58
+ };
59
+
60
+ union
61
+ {
62
+ T input;
63
+ Word input_storage;
64
+ };
65
+
66
+ input = val;
67
+
68
+ Word* dest = reinterpret_cast<Word*>(&output);
69
+ Word* src = reinterpret_cast<Word*>(&input);
70
+
71
+ unsigned int shuffle_word;
72
+
73
+ constexpr int word_count = (sizeof(T) + sizeof(Word) - 1) / sizeof(Word);
74
+
75
+ WP_PRAGMA_UNROLL
76
+ for (int i=0; i < word_count; ++i)
77
+ {
78
+ shuffle_word = __shfl_down_sync(mask, src[i], offset, WP_TILE_WARP_SIZE);
79
+ dest[i] = shuffle_word;
80
+ }
81
+
82
+ return output;
83
+ }
84
+
85
+ // vector overload
86
+ template <unsigned Length, typename T>
87
+ inline CUDA_CALLABLE wp::vec_t<Length, T> warp_shuffle_down(wp::vec_t<Length, T> val, int offset, int mask)
88
+ {
89
+ wp::vec_t<Length, T> result;
90
+
91
+ for (unsigned i=0; i < Length; ++i)
92
+ result[i] = __shfl_down_sync(mask, val[i], offset, WP_TILE_WARP_SIZE);
93
+
94
+ return result;
95
+ }
96
+
97
+ // matrix overload
98
+ template <unsigned Rows, unsigned Cols, typename T>
99
+ inline CUDA_CALLABLE wp::mat_t<Rows, Cols, T> warp_shuffle_down(wp::mat_t<Rows, Cols, T> val, int offset, int mask)
100
+ {
101
+ wp::mat_t<Rows, Cols, T> result;
102
+
103
+ for (unsigned i=0; i < Rows; ++i)
104
+ for (unsigned j=0; j < Cols; ++j)
105
+ result.data[i][j] = __shfl_down_sync(mask, val.data[i][j], offset, WP_TILE_WARP_SIZE);
106
+
107
+ return result;
108
+ }
109
+
110
+
111
+ template <typename T, typename Op>
112
+ inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
113
+ {
114
+ T sum = val;
115
+
116
+ if (mask == 0xFFFFFFFF)
117
+ {
118
+ // handle case where entire warp is active
119
+ for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
120
+ {
121
+ sum = f(sum, warp_shuffle_down(sum, offset, mask));
122
+ }
123
+ }
124
+ else
125
+ {
126
+ // handle partial warp case - works for contiguous masks
127
+ for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
128
+ {
129
+ T shfl_val = warp_shuffle_down(sum, offset, mask);
130
+ if ((mask & (1 << ((threadIdx.x + offset)%WP_TILE_WARP_SIZE))) != 0)
131
+ sum = f(sum, shfl_val);
132
+ }
133
+ }
134
+
135
+ return sum;
136
+ }
137
+
138
// Result of warp_reduce_tracked: the reduced value together with the index chosen by the
// tracking functor during the same reduction (presumably argmax_tracker / argmin_tracker).
template <typename T>
struct ValueAndIndex
{
    T value;   // reduced value
    int index; // index selected by the tracking functor
};
144
+
145
+ template <typename T, typename Op, typename OpTrack>
146
+ inline CUDA_CALLABLE ValueAndIndex<T> warp_reduce_tracked(T val, int idx, Op f, OpTrack track, unsigned int mask)
147
+ {
148
+ T sum = val;
149
+ int index = idx;
150
+
151
+ if (mask == 0xFFFFFFFF)
152
+ {
153
+ // handle case where entire warp is active
154
+ for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
155
+ {
156
+ auto shfl_val = warp_shuffle_down(sum, offset, mask);
157
+ int shfl_idx = warp_shuffle_down(index, offset, mask);
158
+ index = track(sum, shfl_val, index, shfl_idx);
159
+ sum = f(sum, shfl_val);
160
+ }
161
+ }
162
+ else
163
+ {
164
+ // handle partial warp case
165
+ for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
166
+ {
167
+ T shfl_val = warp_shuffle_down(sum, offset, mask);
168
+ int shfl_index = warp_shuffle_down(index, offset, mask);
169
+ if ((mask & (1 << ((threadIdx.x + offset)%WP_TILE_WARP_SIZE))) != 0)
170
+ {
171
+ index = track(sum, shfl_val, index, shfl_index);
172
+ sum = f(sum, shfl_val);
173
+ }
174
+ }
175
+ }
176
+
177
+ ValueAndIndex<T> result;
178
+ result.value = sum;
179
+ result.index = index;
180
+
181
+ return result;
182
+ }
183
+
184
+ // combines per-thread reduction results across warps and the entire block
185
+ // assumes each thread has already reduced its local data to thread_sum
186
+ // returns the block-wide reduced value (only valid in thread 0)
187
+ template <typename T, typename Op>
188
+ inline CUDA_CALLABLE T block_combine_thread_results(T thread_sum, bool thread_has_data, Op f,
189
+ T* partials, int& active_warps)
190
+ {
191
+ constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
192
+ const int warp_index = threadIdx.x / WP_TILE_WARP_SIZE;
193
+ const int lane_index = threadIdx.x % WP_TILE_WARP_SIZE;
194
+
195
+ // determine which threads have data
196
+ unsigned int mask = __ballot_sync(0xFFFFFFFF, thread_has_data);
197
+ bool warp_is_active = mask != 0;
198
+
199
+ // warp reduction
200
+ T warp_sum;
201
+ if (thread_has_data)
202
+ warp_sum = warp_reduce(thread_sum, f, mask);
203
+
204
+ // lane 0 of each active warp writes to shared memory and increments counter
205
+ if (lane_index == 0 && warp_is_active)
206
+ {
207
+ partials[warp_index] = warp_sum;
208
+ atomicAdd(&active_warps, 1);
209
+ }
210
+
211
+ // sync to ensure all warps have written their partials
212
+ WP_TILE_SYNC();
213
+
214
+ // thread 0 performs final reduction across active warps
215
+ T block_sum;
216
+ if (threadIdx.x == 0)
217
+ {
218
+ block_sum = partials[0];
219
+
220
+ for (int w = 1; w < active_warps; ++w)
221
+ {
222
+ block_sum = f(block_sum, partials[w]);
223
+ }
224
+ }
225
+
226
+ return block_sum;
227
+ }
228
+
229
+ // non-axis version which computes sum
230
+ // across the entire tile using the whole block
231
+ template <typename Tile, typename Op>
232
+ auto tile_reduce_impl(Op f, Tile& t)
233
+ {
234
+ using T = typename Tile::Type;
235
+
236
+ auto input = t.copy_to_register();
237
+ auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();
238
+
239
+ constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
240
+
241
+ using Layout = typename decltype(input)::Layout;
242
+
243
+ // step 1: each thread reduces its own registers locally
244
+ T thread_sum = input.data[0];
245
+ bool thread_has_data = Layout::valid(Layout::linear_from_register(0));
246
+
247
+ WP_PRAGMA_UNROLL
248
+ for (int i=1; i < Layout::NumRegs; ++i)
249
+ {
250
+ int linear = Layout::linear_from_register(i);
251
+ if (!Layout::valid(linear))
252
+ break;
253
+
254
+ thread_sum = f(thread_sum, input.data[i]);
255
+ }
256
+
257
+ // shared memory for cross-warp reduction
258
+ __shared__ T partials[warp_count];
259
+ __shared__ int active_warps;
260
+
261
+ if (threadIdx.x == 0)
262
+ active_warps = 0;
263
+
264
+ WP_TILE_SYNC();
265
+
266
+ // step 2-3: combine thread results across warps and block
267
+ T block_sum = block_combine_thread_results(thread_sum, thread_has_data, f, partials, active_warps);
268
+
269
+ if (threadIdx.x == 0)
270
+ output.data[0] = block_sum;
271
+
272
+ return output;
273
+ }
274
+
275
+ template <int Axis, typename Op, typename Tile>
276
+ auto tile_reduce_axis_impl(Op f, Tile& t)
277
+ {
278
+ using T = typename Tile::Type;
279
+ using InputShape = typename Tile::Layout::Shape;
280
+ using OutputShape = typename tile_shape_remove_dim<Axis, InputShape>::type;
281
+
282
+ constexpr int reduce_dim_size = InputShape::dim(Axis);
283
+ constexpr int output_size = OutputShape::size();
284
+
285
+ // special case: 1D input delegates to block-wide tile_reduce_impl for optimal performance
286
+ if constexpr (InputShape::N == 1)
287
+ {
288
+ return tile_reduce_impl(f, t);
289
+ }
290
+
291
+ // shared memory buffer for the output (used by all tiers)
292
+ __shared__ T output_buffer[output_size];
293
+
294
+ // create output layout for coordinate conversion (used by all tiers)
295
+ using OutputLayout = tile_layout_strided_t<OutputShape>;
296
+
297
+ if constexpr (reduce_dim_size <= 32)
298
+ {
299
+ // Tier 1: Single thread per output element (optimal for small reductions)
300
+
301
+ // each thread processes output elements, performing reduction along the axis
302
+ for (int out_idx = WP_TILE_THREAD_IDX; out_idx < output_size; out_idx += WP_TILE_BLOCK_DIM)
303
+ {
304
+ // convert output linear index to output coordinates
305
+ auto out_coord = OutputLayout::coord_from_linear(out_idx);
306
+
307
+ // initialize accumulator with first element along the reduction axis
308
+ T accumulator = t.data(tile_coord_insert_axis<Axis>(out_coord, 0));
309
+
310
+ // reduce across the axis
311
+ for (int i = 1; i < reduce_dim_size; ++i)
312
+ {
313
+ accumulator = f(accumulator, t.data(tile_coord_insert_axis<Axis>(out_coord, i)));
314
+ }
315
+
316
+ // store to output buffer
317
+ output_buffer[out_idx] = accumulator;
318
+ }
319
+
320
+ // sync before reading output
321
+ WP_TILE_SYNC();
322
+ }
323
+ else if constexpr (reduce_dim_size <= 256)
324
+ {
325
+ // Tier 2: Warp-based reduction (one warp per output element)
326
+ constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
327
+ const int warp_index = threadIdx.x / WP_TILE_WARP_SIZE;
328
+ const int lane_index = threadIdx.x % WP_TILE_WARP_SIZE;
329
+
330
+ constexpr int chunks_per_slice = (reduce_dim_size + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
331
+
332
+ // shared memory: one accumulator per warp
333
+ __shared__ T warp_partials[warp_count];
334
+
335
+ // each warp processes output slices
336
+ for (int out_idx = warp_index; out_idx < output_size; out_idx += warp_count)
337
+ {
338
+ auto out_coord = OutputLayout::coord_from_linear(out_idx);
339
+
340
+ // process the reduction axis in chunks of 32
341
+ for (int chunk = 0; chunk < chunks_per_slice; ++chunk)
342
+ {
343
+ int axis_idx = chunk * WP_TILE_WARP_SIZE + lane_index;
344
+ bool valid = axis_idx < reduce_dim_size;
345
+
346
+ T val;
347
+ if (valid)
348
+ {
349
+ auto in_coord = tile_coord_insert_axis<Axis>(out_coord, axis_idx);
350
+ val = t.data(in_coord);
351
+ }
352
+
353
+ // warp reduce this chunk (only valid lanes participate)
354
+ unsigned int mask = __ballot_sync(0xFFFFFFFF, valid);
355
+ T chunk_result = warp_reduce(val, f, mask);
356
+
357
+ // lane 0 accumulates the chunk result
358
+ if (lane_index == 0)
359
+ {
360
+ if (chunk == 0)
361
+ warp_partials[warp_index] = chunk_result;
362
+ else
363
+ warp_partials[warp_index] = f(warp_partials[warp_index], chunk_result);
364
+ }
365
+ }
366
+
367
+ // lane 0 writes final result for this output element
368
+ if (lane_index == 0)
369
+ output_buffer[out_idx] = warp_partials[warp_index];
370
+ }
371
+
372
+ // sync before reading output
373
+ WP_TILE_SYNC();
374
+ }
375
+ else
376
+ {
377
+ // Tier 3: Block-level reduction (entire block collaborates on each output element)
378
+ constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
379
+
380
+ // shared memory for cross-warp reduction
381
+ __shared__ T partials[warp_count];
382
+ __shared__ int active_warps;
383
+
384
+ // process each output element sequentially with full block cooperation
385
+ for (int out_idx = 0; out_idx < output_size; ++out_idx)
386
+ {
387
+ auto out_coord = OutputLayout::coord_from_linear(out_idx);
388
+
389
+ // step 1: each thread reduces its strided subset of the slice locally
390
+ bool thread_has_data = threadIdx.x < reduce_dim_size;
391
+ T thread_sum;
392
+
393
+ if (thread_has_data)
394
+ {
395
+ // initialize with first element
396
+ auto in_coord = tile_coord_insert_axis<Axis>(out_coord, threadIdx.x);
397
+ thread_sum = t.data(in_coord);
398
+
399
+ // reduce remaining elements with stride
400
+ for (int i = threadIdx.x + WP_TILE_BLOCK_DIM; i < reduce_dim_size; i += WP_TILE_BLOCK_DIM)
401
+ {
402
+ auto in_coord = tile_coord_insert_axis<Axis>(out_coord, i);
403
+ T val = t.data(in_coord);
404
+ thread_sum = f(thread_sum, val);
405
+ }
406
+ }
407
+
408
+ // initialize active warp counter
409
+ if (threadIdx.x == 0)
410
+ active_warps = 0;
411
+
412
+ WP_TILE_SYNC();
413
+
414
+ // step 2-3: combine thread results across warps and block
415
+ T block_sum = block_combine_thread_results(thread_sum, thread_has_data, f, partials, active_warps);
416
+
417
+ if (threadIdx.x == 0)
418
+ output_buffer[out_idx] = block_sum;
419
+
420
+ // sync before next output element
421
+ WP_TILE_SYNC();
422
+ }
423
+ }
424
+
425
+ // copy from shared memory buffer to register tile (common to all tiers)
426
+ auto output = tile_register_t<T, tile_layout_register_t<OutputShape>>();
427
+ using OutputRegLayout = typename decltype(output)::Layout;
428
+
429
+ WP_PRAGMA_UNROLL
430
+ for (int i = 0; i < OutputRegLayout::NumRegs; ++i)
431
+ {
432
+ int linear = OutputRegLayout::linear_from_register(i);
433
+ if (!OutputRegLayout::valid(linear))
434
+ break;
435
+
436
+ output.data[i] = output_buffer[linear];
437
+ }
438
+
439
+ return output;
440
+ }
441
+
442
+ // non-axis version which computes sum
443
+ // across the entire tile using the whole block
444
+ template <typename Tile, typename Op, typename OpTrack>
445
+ auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
446
+ {
447
+ using T = typename Tile::Type;
448
+
449
+ auto input = t.copy_to_register();
450
+ auto output = tile_register_t<int, tile_layout_register_t<tile_shape_t<1>>>();
451
+
452
+ const int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
453
+ const int warp_index = threadIdx.x/WP_TILE_WARP_SIZE;
454
+ const int lane_index = threadIdx.x%WP_TILE_WARP_SIZE;
455
+
456
+ using Layout = typename decltype(input)::Layout;
457
+
458
+ int champion_index = Layout::NumRegs > 0 ? Layout::linear_from_register(0) : -1;
459
+ T thread_sum = input.data[0];
460
+
461
+ // thread reduction
462
+ WP_PRAGMA_UNROLL
463
+ for (int i=1; i < Layout::NumRegs; ++i)
464
+ {
465
+ int linear = Layout::linear_from_register(i);
466
+ if (!Layout::valid(linear))
467
+ break;
468
+
469
+ champion_index = track(thread_sum, input.data[i], champion_index, linear);
470
+ thread_sum = f(thread_sum, input.data[i]);
471
+ }
472
+
473
+ // ensure that only threads with at least one valid item participate in the reduction
474
+ unsigned int mask = __ballot_sync(__activemask(), Layout::valid(Layout::linear_from_register(0)));
475
+ bool warp_is_active = mask != 0;
476
+
477
+ // warp reduction
478
+ ValueAndIndex<T> warp_sum = warp_reduce_tracked(thread_sum, champion_index, f, track, mask);
479
+
480
+ // fixed size scratch pad for partial results in shared memory
481
+ __shared__ T partials[warp_count];
482
+ __shared__ int partials_idx[warp_count];
483
+
484
+ // count of active warps
485
+ __shared__ int active_warps;
486
+ if (threadIdx.x == 0)
487
+ active_warps = 0;
488
+
489
+ // ensure active_warps is initialized
490
+ WP_TILE_SYNC();
491
+
492
+ if (lane_index == 0 && warp_is_active)
493
+ {
494
+ partials[warp_index] = warp_sum.value;
495
+ partials_idx[warp_index] = warp_sum.index;
496
+ atomicAdd(&active_warps, 1);
497
+ }
498
+
499
+ // ensure partials are ready
500
+ WP_TILE_SYNC();
501
+
502
+ // reduce across block, todo: use warp_reduce() here
503
+ if (threadIdx.x == 0)
504
+ {
505
+ T block_sum = partials[0];
506
+ int block_champion_index = partials_idx[0];
507
+
508
+ WP_PRAGMA_UNROLL
509
+ for (int i=1; i < active_warps; ++i)
510
+ {
511
+ block_champion_index = track(block_sum, partials[i], block_champion_index, partials_idx[i]);
512
+ block_sum = f(block_sum, partials[i]);
513
+ }
514
+
515
+ output.data[0] = block_champion_index;
516
+ }
517
+
518
+ return output;
519
+ }
520
+
521
+ #else
522
+
523
+ // CPU implementation
524
+
525
+ template <typename Tile, typename Op>
526
+ auto tile_reduce_impl(Op f, Tile& t)
527
+ {
528
+ using T = typename Tile::Type;
529
+
530
+ auto input = t.copy_to_register();
531
+ auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();
532
+
533
+ using Layout = typename decltype(input)::Layout;
534
+
535
+ T sum = input.data[0];
536
+
537
+ WP_PRAGMA_UNROLL
538
+ for (int i=1; i < Layout::NumRegs; ++i)
539
+ {
540
+ int linear = Layout::linear_from_register(i);
541
+ if (!Layout::valid(linear))
542
+ break;
543
+
544
+ sum = f(sum, input.data[i]);
545
+ }
546
+
547
+ output.data[0] = sum;
548
+ return output;
549
+ }
550
+
551
+ template <int Axis, typename Op, typename Tile>
552
+ auto tile_reduce_axis_impl(Op f, Tile& t)
553
+ {
554
+ using T = typename Tile::Type;
555
+ using InputShape = typename Tile::Layout::Shape;
556
+ using OutputShape = typename tile_shape_remove_dim<Axis, InputShape>::type;
557
+
558
+ constexpr int reduce_dim_size = InputShape::dim(Axis);
559
+
560
+ // CPU version - work directly with register tiles, no thread coordination needed
561
+ auto input = t.copy_to_register();
562
+ auto output = tile_register_t<T, tile_layout_register_t<OutputShape>>();
563
+ using OutputLayout = typename decltype(output)::Layout;
564
+
565
+ // iterate through each output element and reduce along the axis
566
+ constexpr int output_size = OutputShape::size();
567
+ for (int out_idx = 0; out_idx < output_size; ++out_idx)
568
+ {
569
+ T accumulator;
570
+
571
+ // special case for 1D input (reduces to single value)
572
+ if constexpr (InputShape::N == 1)
573
+ {
574
+ accumulator = input.data[0];
575
+ for (int i = 1; i < reduce_dim_size; ++i)
576
+ {
577
+ // input is in registers, linear access
578
+ accumulator = f(accumulator, input.data[i]);
579
+ }
580
+ }
581
+ else
582
+ {
583
+ // multi-dimensional case
584
+ auto out_coord = OutputLayout::coord_from_linear(out_idx);
585
+
586
+ // get input coordinates by inserting axis values
587
+ auto coord_0 = tile_coord_insert_axis<Axis>(out_coord, 0);
588
+ int input_linear_0 = tile_layout_register_t<InputShape>::linear_from_coord(coord_0);
589
+ int input_reg_0 = tile_layout_register_t<InputShape>::register_from_linear(input_linear_0);
590
+ accumulator = input.data[input_reg_0];
591
+
592
+ // reduce across the axis
593
+ for (int i = 1; i < reduce_dim_size; ++i)
594
+ {
595
+ auto coord_i = tile_coord_insert_axis<Axis>(out_coord, i);
596
+ int input_linear_i = tile_layout_register_t<InputShape>::linear_from_coord(coord_i);
597
+ int input_reg_i = tile_layout_register_t<InputShape>::register_from_linear(input_linear_i);
598
+ accumulator = f(accumulator, input.data[input_reg_i]);
599
+ }
600
+ }
601
+
602
+ // store to output register
603
+ int output_reg = OutputLayout::register_from_linear(out_idx);
604
+ output.data[output_reg] = accumulator;
605
+ }
606
+
607
+ return output;
608
+ }
609
+
610
+ template <typename Tile, typename Op, typename OpTrack>
611
+ auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
612
+ {
613
+ using T = typename Tile::Type;
614
+
615
+ auto input = t.copy_to_register();
616
+ auto output = tile_register_t<int, tile_layout_register_t<tile_shape_t<1>>>();
617
+
618
+ using Layout = typename decltype(input)::Layout;
619
+
620
+ int champion_index = Layout::NumRegs > 0 ? Layout::linear_from_register(0) : -1;
621
+ T sum = input.data[0];
622
+
623
+ WP_PRAGMA_UNROLL
624
+ for (int i=1; i < Layout::NumRegs; ++i)
625
+ {
626
+ int linear = Layout::linear_from_register(i);
627
+ if (!Layout::valid(linear))
628
+ break;
629
+
630
+ champion_index = track(sum, input.data[i], champion_index, linear);
631
+ sum = f(sum, input.data[i]);
632
+ }
633
+
634
+ output.data[0] = champion_index;
635
+ return output;
636
+ }
637
+
638
+ #endif // !defined(__CUDA_ARCH__)
639
+
640
+ inline void adj_tile_reduce_impl()
641
+ {
642
+ // todo: general purpose reduction gradients not implemented
643
+ }
644
+
645
+ inline void adj_tile_reduce_axis_impl()
646
+ {
647
+ // todo: axis-specific reduction gradients not implemented
648
+ }
649
+
650
+ // entry point for Python code-gen, wraps op in a lambda to perform overload resolution
651
+ #define tile_reduce(op, t) tile_reduce_impl([](auto x, auto y) { return op(x, y);}, t)
652
+ #define adj_tile_reduce(op, t, adj_op, adj_t, adj_ret) adj_tile_reduce_impl()
653
+
654
+ #define tile_arg_reduce(op, opTrack, t) tile_arg_reduce_impl([](auto x, auto y) { return op(x, y);}, [](auto a, auto b, auto c, auto d) { return opTrack(a, b, c, d); }, t)
655
+ #define adj_tile_arg_reduce(op, t, adj_op, adj_t, adj_ret) adj_tile_arg_reduce_impl()
656
+
657
+ // axis-specific reduction entry points
658
+ #define tile_reduce_axis(op, t, axis) tile_reduce_axis_impl<axis>([](auto x, auto y) { return op(x, y);}, t)
659
+ #define adj_tile_reduce_axis(op, t, axis, adj_op, adj_t, adj_axis, adj_ret) adj_tile_reduce_axis_impl()
660
+
661
+ // convenience methods for specific reductions
662
+
663
+ // whole-tile sum
664
+ template <typename Tile>
665
+ auto tile_sum(Tile& t)
666
+ {
667
+ return tile_reduce(add, t);
668
+ }
669
+
670
+ // special case adjoint for summation
671
+ template <typename Tile, typename AdjTile>
672
+ void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
673
+ {
674
+ using T = typename Tile::Type;
675
+
676
+ auto adj_reg = adj_ret.grad_to_register();
677
+
678
+ #if !defined(__CUDA_ARCH__)
679
+ T scratch = adj_reg.data[0];
680
+ #else
681
+ // broadcast incoming adjoint to block
682
+ __shared__ T scratch;
683
+ if (WP_TILE_THREAD_IDX == 0)
684
+ scratch = adj_reg.data[0];
685
+
686
+ WP_TILE_SYNC();
687
+ #endif
688
+
689
+ auto adj_ret_reg = tile_register_like<Tile>();
690
+ using Layout = typename decltype(adj_ret_reg)::Layout;
691
+ for (int i=0; i < Layout::NumRegs; ++i)
692
+ {
693
+ adj_ret_reg.data[i] += scratch;
694
+ }
695
+ adj_t.grad_add(adj_ret_reg);
696
+ }
697
+
698
+ // axis-specific sum
699
+ template <int Axis, typename Tile>
700
+ auto tile_sum(Tile& t)
701
+ {
702
+ return tile_reduce_axis_impl<Axis>([](auto x, auto y) { return add(x, y); }, t);
703
+ }
704
+
705
+ // special case adjoint for axis-specific summation
706
+ template<int Axis, typename Tile, typename AdjTile>
707
+ void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
708
+ {
709
+ using InputShape = typename Tile::Layout::Shape;
710
+
711
+ if constexpr (InputShape::N == 1)
712
+ {
713
+ // 1D -> scalar case: broadcast scalar to 1D
714
+ auto broadcasted = tile_broadcast<InputShape::dim(0), 0>(adj_ret);
715
+ tile_add_inplace(adj_t, broadcasted);
716
+ }
717
+ else if constexpr (InputShape::N == 2)
718
+ {
719
+ if constexpr (Axis == 0)
720
+ {
721
+ // broadcast from (D1,) to (D0, D1) with strides (0, 1)
722
+ auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), 0, 1>(adj_ret);
723
+ tile_add_inplace(adj_t, broadcasted);
724
+ }
725
+ else // Axis == 1
726
+ {
727
+ // broadcast from (D0,) to (D0, D1) with strides (1, 0)
728
+ auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), 1, 0>(adj_ret);
729
+ tile_add_inplace(adj_t, broadcasted);
730
+ }
731
+ }
732
+ else if constexpr (InputShape::N == 3)
733
+ {
734
+ if constexpr (Axis == 0)
735
+ {
736
+ // broadcast from (D1, D2) to (D0, D1, D2) with strides (0, D2, 1)
737
+ auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), 0, InputShape::dim(2), 1>(adj_ret);
738
+ tile_add_inplace(adj_t, broadcasted);
739
+ }
740
+ else if constexpr (Axis == 1)
741
+ {
742
+ // broadcast from (D0, D2) to (D0, D1, D2) with strides (D2, 0, 1)
743
+ auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(2), 0, 1>(adj_ret);
744
+ tile_add_inplace(adj_t, broadcasted);
745
+ }
746
+ else // Axis == 2
747
+ {
748
+ // broadcast from (D0, D1) to (D0, D1, D2) with strides (D1, 1, 0)
749
+ auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(1), 1, 0>(adj_ret);
750
+ tile_add_inplace(adj_t, broadcasted);
751
+ }
752
+ }
753
+ else if constexpr (InputShape::N == 4)
754
+ {
755
+ if constexpr (Axis == 0)
756
+ {
757
+ // broadcast from (D1, D2, D3) to (D0, D1, D2, D3) with strides (0, D2*D3, D3, 1)
758
+ auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), 0, InputShape::dim(2)*InputShape::dim(3), InputShape::dim(3), 1>(adj_ret);
759
+ tile_add_inplace(adj_t, broadcasted);
760
+ }
761
+ else if constexpr (Axis == 1)
762
+ {
763
+ // broadcast from (D0, D2, D3) to (D0, D1, D2, D3) with strides (D2*D3, 0, D3, 1)
764
+ auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), InputShape::dim(2)*InputShape::dim(3), 0, InputShape::dim(3), 1>(adj_ret);
765
+ tile_add_inplace(adj_t, broadcasted);
766
+ }
767
+ else if constexpr (Axis == 2)
768
+ {
769
+ // broadcast from (D0, D1, D3) to (D0, D1, D2, D3) with strides (D1*D3, D3, 0, 1)
770
+ auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), InputShape::dim(1)*InputShape::dim(3), InputShape::dim(3), 0, 1>(adj_ret);
771
+ tile_add_inplace(adj_t, broadcasted);
772
+ }
773
+ else // Axis == 3
774
+ {
775
+ // broadcast from (D0, D1, D2) to (D0, D1, D2, D3) with strides (D1*D2, D2, 1, 0)
776
+ auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), InputShape::dim(1)*InputShape::dim(2), InputShape::dim(2), 1, 0>(adj_ret);
777
+ tile_add_inplace(adj_t, broadcasted);
778
+ }
779
+ }
780
+ }
781
+
782
+ template <typename Tile>
783
+ auto tile_max(Tile& t)
784
+ {
785
+ return tile_reduce(max, t);
786
+ }
787
+
788
+ template <typename Tile, typename AdjTile>
789
+ void adj_tile_max(Tile& t, Tile& adj_t, AdjTile& adj_ret)
790
+ {
791
+ // todo: not implemented
792
+ }
793
+
794
+ template <typename Tile>
795
+ auto tile_min(Tile& t)
796
+ {
797
+ return tile_reduce(min, t);
798
+ }
799
+
800
+ template <typename Tile, typename AdjTile>
801
+ void adj_tile_min(Tile& t, Tile& adj_t, AdjTile& adj_ret)
802
+ {
803
+ // todo: not implemented
804
+ }
805
+
806
+
807
+
808
+ template <typename Tile>
809
+ auto tile_argmax(Tile& t)
810
+ {
811
+ return tile_arg_reduce(max, argmax_tracker, t);
812
+ }
813
+
814
+ template <typename Tile, typename AdjTile>
815
+ void adj_tile_argmax(Tile& t, Tile& adj_t, AdjTile& adj_ret)
816
+ {
817
+ // todo: not implemented
818
+ }
819
+
820
+ template <typename Tile>
821
+ auto tile_argmin(Tile& t)
822
+ {
823
+ return tile_arg_reduce(min, argmin_tracker, t);
824
+ }
825
+
826
+ template <typename Tile, typename AdjTile>
827
+ void adj_tile_argmin(Tile& t, Tile& adj_t, AdjTile& adj_ret)
828
+ {
829
+ // todo: not implemented
830
+ }
831
+
832
+
833
+ } // namespace wp
834
+
835
+
836
+ #ifdef __clang__
837
+ #pragma clang diagnostic pop
838
+ #endif