warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/native/array.h ADDED
@@ -0,0 +1,1687 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #pragma once
19
+
20
+ #include "builtin.h"
21
+
22
+ namespace wp
23
+ {
24
+
25
#if FP_CHECK

// Floating-point verification macros.
//
// When FP_CHECK is non-zero, each FP_VERIFY_* macro tests its argument(s) with
// isfinite(). On failure it prints the source file, line and enclosing
// function name, followed by the element indices (for the _1.._4 variants)
// and the offending value(s), and then calls assert(0).
//
// When FP_CHECK is zero or undefined, every FP_VERIFY_* macro expands to an
// empty block `{}` (see the #else branch below).
//
// NOTE(review): the _1.._4 variants print the identifiers `i`, `j`, `k`, `l`,
// which are NOT macro parameters. They must be in scope (and printable with
// %d) at every expansion site.
// NOTE(review): the checking variants expand to a bare `if { ... }` rather
// than a `do { ... } while (0)` wrapper, so do not use them as the body of an
// unbraced if/else. Call sites may rely on being able to omit the trailing
// semicolon; confirm before changing the wrapper.

// Shared tail of the forward checks: prints the value, closes the message
// line and aborts via assert(0). `print` is presumably the overload set from
// builtin.h (TODO confirm).
// NOTE(review): combined with FP_VERIFY_FWD the output reads
// "func(addr<value>)" with no separator after "addr". The indexed variants
// already close their parenthesis before the value, so their output ends
// with an unmatched ')'.
#define FP_ASSERT_FWD(value) \
    print(value); \
    printf(")\n"); \
    assert(0); \

// Shared tail of the adjoint checks: prints "value, adj_value", closes the
// message line and aborts via assert(0).
#define FP_ASSERT_ADJ(value, adj_value) \
    print(value); \
    printf(", "); \
    print(adj_value); \
    printf(")\n"); \
    assert(0); \

// Forward check on a value read through an address (no index context).
#define FP_VERIFY_FWD(value) \
    if (!isfinite(value)) { \
        printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
        FP_ASSERT_FWD(value) \
    } \

// Forward check on a 1-D element; reports index `i` from the caller's scope.
#define FP_VERIFY_FWD_1(value) \
    if (!isfinite(value)) { \
        printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
        FP_ASSERT_FWD(value) \
    } \

// Forward check on a 2-D element; reports `i`, `j` from the caller's scope.
#define FP_VERIFY_FWD_2(value) \
    if (!isfinite(value)) { \
        printf("%s:%d - %s(arr, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j); \
        FP_ASSERT_FWD(value) \
    } \

// Forward check on a 3-D element; reports `i`, `j`, `k` from the caller's scope.
#define FP_VERIFY_FWD_3(value) \
    if (!isfinite(value)) { \
        printf("%s:%d - %s(arr, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k); \
        FP_ASSERT_FWD(value) \
    } \

// Forward check on a 4-D element; reports `i`, `j`, `k`, `l` from the caller's scope.
#define FP_VERIFY_FWD_4(value) \
    if (!isfinite(value)) { \
        printf("%s:%d - %s(arr, %d, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k, l); \
        FP_ASSERT_FWD(value) \
    } \

// Adjoint check: fails if either the value or its adjoint is non-finite.
#define FP_VERIFY_ADJ(value, adj_value) \
    if (!isfinite(value) || !isfinite(adj_value)) \
    { \
        printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
        FP_ASSERT_ADJ(value, adj_value); \
    } \

// Adjoint check on a 1-D element; reports `i` from the caller's scope.
#define FP_VERIFY_ADJ_1(value, adj_value) \
    if (!isfinite(value) || !isfinite(adj_value)) \
    { \
        printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
        FP_ASSERT_ADJ(value, adj_value); \
    } \

// Adjoint check on a 2-D element; reports `i`, `j` from the caller's scope.
#define FP_VERIFY_ADJ_2(value, adj_value) \
    if (!isfinite(value) || !isfinite(adj_value)) \
    { \
        printf("%s:%d - %s(arr, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j); \
        FP_ASSERT_ADJ(value, adj_value); \
    } \

// Adjoint check on a 3-D element; reports `i`, `j`, `k` from the caller's scope.
#define FP_VERIFY_ADJ_3(value, adj_value) \
    if (!isfinite(value) || !isfinite(adj_value)) \
    { \
        printf("%s:%d - %s(arr, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k); \
        FP_ASSERT_ADJ(value, adj_value); \
    } \

// Adjoint check on a 4-D element; reports `i`, `j`, `k`, `l` from the caller's scope.
#define FP_VERIFY_ADJ_4(value, adj_value) \
    if (!isfinite(value) || !isfinite(adj_value)) \
    { \
        printf("%s:%d - %s(arr, %d, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k, l); \
        FP_ASSERT_ADJ(value, adj_value); \
    } \


#else

// Checking disabled: every verification macro is an empty block.
#define FP_VERIFY_FWD(value) {}
#define FP_VERIFY_FWD_1(value) {}
#define FP_VERIFY_FWD_2(value) {}
#define FP_VERIFY_FWD_3(value) {}
#define FP_VERIFY_FWD_4(value) {}

#define FP_VERIFY_ADJ(value, adj_value) {}
#define FP_VERIFY_ADJ_1(value, adj_value) {}
#define FP_VERIFY_ADJ_2(value, adj_value) {}
#define FP_VERIFY_ADJ_3(value, adj_value) {}
#define FP_VERIFY_ADJ_4(value, adj_value) {}

#endif // FP_CHECK
120
+
121
+
122
// Local equivalents of std::index_sequence / std::make_index_sequence.

// A compile-time pack of indices; carries no data.
template <size_t... Indices>
struct index_sequence
{
};

// Recursive builder: each step decrements N and prepends N-1 to the
// accumulated pack, so the pack grows from the back towards 0.
template <size_t N, size_t... Acc>
struct make_index_sequence_impl : make_index_sequence_impl<N - 1, N - 1, Acc...>
{
};

// Terminal case (N == 0): publish the accumulated pack as `type`.
template <size_t... Acc>
struct make_index_sequence_impl<0, Acc...>
{
    typedef index_sequence<Acc...> type;
};

// make_index_sequence<N> names index_sequence<0, 1, ..., N-1>.
template <size_t N>
using make_index_sequence = typename make_index_sequence_impl<N>::type;
136
+
137
+
138
// Maximum array rank supported by the native array types.
// Must match the corresponding constant in types.py.
constexpr int ARRAY_MAX_DIMS = 4;

// Array-kind tags. The numeric values must match the constants in types.py.
constexpr int ARRAY_TYPE_REGULAR        = 0;
constexpr int ARRAY_TYPE_INDEXED        = 1;
constexpr int ARRAY_TYPE_FABRIC         = 2;
constexpr int ARRAY_TYPE_FABRIC_INDEXED = 3;
145
+
146
+ struct shape_t
147
+ {
148
+ int dims[ARRAY_MAX_DIMS];
149
+
150
+ CUDA_CALLABLE inline shape_t()
151
+ : dims()
152
+ {}
153
+
154
+ CUDA_CALLABLE inline int operator[](int i) const
155
+ {
156
+ assert(i < ARRAY_MAX_DIMS);
157
+ return dims[i];
158
+ }
159
+
160
+ CUDA_CALLABLE inline int& operator[](int i)
161
+ {
162
+ assert(i < ARRAY_MAX_DIMS);
163
+ return dims[i];
164
+ }
165
+ };
166
+
167
+ CUDA_CALLABLE inline int extract(const shape_t& s, int i)
168
+ {
169
+ return s.dims[i];
170
+ }
171
+
172
+ CUDA_CALLABLE inline void adj_extract(const shape_t& s, int i, const shape_t& adj_s, int adj_i, int adj_ret) {}
173
+
174
+ inline CUDA_CALLABLE void print(shape_t s)
175
+ {
176
+ // todo: only print valid dims, currently shape has a fixed size
177
+ // but we don't know how many dims are valid (e.g.: 1d, 2d, etc)
178
+ // should probably store ndim with shape
179
+ printf("(%d, %d, %d, %d)\n", s.dims[0], s.dims[1], s.dims[2], s.dims[3]);
180
+ }
181
+ inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& adj_s) {}
182
+
183
+
184
+ template <typename T>
185
+ struct array_t
186
+ {
187
+ CUDA_CALLABLE inline array_t()
188
+ : data(nullptr),
189
+ grad(nullptr),
190
+ shape(),
191
+ strides(),
192
+ ndim(0)
193
+ {}
194
+
195
+ CUDA_CALLABLE array_t(T* data, int size, T* grad=nullptr) : data(data), grad(grad) {
196
+ // constructor for 1d array
197
+ shape.dims[0] = size;
198
+ shape.dims[1] = 0;
199
+ shape.dims[2] = 0;
200
+ shape.dims[3] = 0;
201
+ ndim = 1;
202
+ strides[0] = sizeof(T);
203
+ strides[1] = 0;
204
+ strides[2] = 0;
205
+ strides[3] = 0;
206
+ }
207
+ CUDA_CALLABLE array_t(T* data, int dim0, int dim1, T* grad=nullptr) : data(data), grad(grad) {
208
+ // constructor for 2d array
209
+ shape.dims[0] = dim0;
210
+ shape.dims[1] = dim1;
211
+ shape.dims[2] = 0;
212
+ shape.dims[3] = 0;
213
+ ndim = 2;
214
+ strides[0] = dim1 * sizeof(T);
215
+ strides[1] = sizeof(T);
216
+ strides[2] = 0;
217
+ strides[3] = 0;
218
+ }
219
+ CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, T* grad=nullptr) : data(data), grad(grad) {
220
+ // constructor for 3d array
221
+ shape.dims[0] = dim0;
222
+ shape.dims[1] = dim1;
223
+ shape.dims[2] = dim2;
224
+ shape.dims[3] = 0;
225
+ ndim = 3;
226
+ strides[0] = dim1 * dim2 * sizeof(T);
227
+ strides[1] = dim2 * sizeof(T);
228
+ strides[2] = sizeof(T);
229
+ strides[3] = 0;
230
+ }
231
+ CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, int dim3, T* grad=nullptr) : data(data), grad(grad) {
232
+ // constructor for 4d array
233
+ shape.dims[0] = dim0;
234
+ shape.dims[1] = dim1;
235
+ shape.dims[2] = dim2;
236
+ shape.dims[3] = dim3;
237
+ ndim = 4;
238
+ strides[0] = dim1 * dim2 * dim3 * sizeof(T);
239
+ strides[1] = dim2 * dim3 * sizeof(T);
240
+ strides[2] = dim3 * sizeof(T);
241
+ strides[3] = sizeof(T);
242
+ }
243
+
244
+ CUDA_CALLABLE array_t(uint64 data, int size, uint64 grad=0)
245
+ : array_t((T*)(data), size, (T*)(grad))
246
+ {}
247
+
248
+ CUDA_CALLABLE array_t(uint64 data, int dim0, int dim1, uint64 grad=0)
249
+ : array_t((T*)(data), dim0, dim1, (T*)(grad))
250
+ {}
251
+
252
+ CUDA_CALLABLE array_t(uint64 data, int dim0, int dim1, int dim2, uint64 grad=0)
253
+ : array_t((T*)(data), dim0, dim1, dim2, (T*)(grad))
254
+ {}
255
+
256
+ CUDA_CALLABLE array_t(uint64 data, int dim0, int dim1, int dim2, int dim3, uint64 grad=0)
257
+ : array_t((T*)(data), dim0, dim1, dim2, dim3, (T*)(grad))
258
+ {}
259
+
260
+ CUDA_CALLABLE inline bool empty() const { return !data; }
261
+
262
+ T* data;
263
+ T* grad;
264
+ shape_t shape;
265
+ int strides[ARRAY_MAX_DIMS];
266
+ int ndim;
267
+
268
+ CUDA_CALLABLE inline operator T*() const { return data; }
269
+ };
270
+
271
+
272
+ // Required when compiling adjoints.
273
+ template <typename T>
274
+ inline CUDA_CALLABLE array_t<T> add(
275
+ const array_t<T>& a, const array_t<T>& b
276
+ )
277
+ {
278
+ return array_t<T>();
279
+ }
280
+
281
+
282
+ // Stack‑allocated counterpart to `array_t<T>`.
283
+ // Useful for small buffers that have their shape known at compile-time,
284
+ // and that gain from having array semantics instead of vectors.
285
+ template <int Size, typename T>
286
+ struct fixedarray_t : array_t<T>
287
+ {
288
+ using Base = array_t<T>;
289
+
290
+ static_assert(Size > 0, "Expected Size > 0");
291
+
292
+ CUDA_CALLABLE inline fixedarray_t()
293
+ : Base(storage, Size), storage()
294
+ {}
295
+
296
+ CUDA_CALLABLE fixedarray_t(int dim0, T* grad=nullptr)
297
+ : Base(storage, dim0, grad), storage()
298
+ {
299
+ assert(Size == dim0);
300
+ }
301
+
302
+ CUDA_CALLABLE fixedarray_t(int dim0, int dim1, T* grad=nullptr)
303
+ : Base(storage, dim0, dim1, grad), storage()
304
+ {
305
+ assert(Size == dim0 * dim1);
306
+ }
307
+
308
+ CUDA_CALLABLE fixedarray_t(int dim0, int dim1, int dim2, T* grad=nullptr)
309
+ : Base(storage, dim0, dim1, dim2, grad), storage()
310
+ {
311
+ assert(Size == dim0 * dim1 * dim2);
312
+ }
313
+
314
+ CUDA_CALLABLE fixedarray_t(int dim0, int dim1, int dim2, int dim3, T* grad=nullptr)
315
+ : Base(storage, dim0, dim1, dim2, dim3, grad), storage()
316
+ {
317
+ assert(Size == dim0 * dim1 * dim2 * dim3);
318
+ }
319
+
320
+ CUDA_CALLABLE fixedarray_t<Size, T>& operator=(const fixedarray_t<Size, T>& other)
321
+ {
322
+ for (unsigned int i = 0; i < Size; ++i)
323
+ {
324
+ this->storage[i] = other.storage[i];
325
+ }
326
+
327
+ this->data = this->storage;
328
+ this->grad = nullptr;
329
+ this->shape = other.shape;
330
+
331
+ for (unsigned int i = 0; i < ARRAY_MAX_DIMS; ++i)
332
+ {
333
+ this->strides[i] = other.strides[i];
334
+ }
335
+
336
+ this->ndim = other.ndim;
337
+
338
+ return *this;
339
+ }
340
+
341
+ T storage[Size];
342
+ };
343
+
344
+
345
+ // Required when compiling adjoints.
346
+ template <int Size, typename T>
347
+ inline CUDA_CALLABLE fixedarray_t<Size, T> add(
348
+ const fixedarray_t<Size, T>& a, const fixedarray_t<Size, T>& b
349
+ )
350
+ {
351
+ return fixedarray_t<Size, T>();
352
+ }
353
+
354
+
355
+ // TODO:
356
+ // - templated index type?
357
+ // - templated dimensionality? (also for array_t to save space when passing arrays to kernels)
358
+ template <typename T>
359
+ struct indexedarray_t
360
+ {
361
+ CUDA_CALLABLE inline indexedarray_t()
362
+ : arr(),
363
+ indices(),
364
+ shape()
365
+ {}
366
+
367
+ CUDA_CALLABLE inline bool empty() const { return !arr.data; }
368
+
369
+ array_t<T> arr;
370
+ int* indices[ARRAY_MAX_DIMS]; // index array per dimension (can be NULL)
371
+ shape_t shape; // element count per dimension (num. indices if indexed, array dim if not)
372
+ };
373
+
374
+
375
+ // return stride (in bytes) of the given index
376
+ template <typename T>
377
+ CUDA_CALLABLE inline size_t stride(const array_t<T>& a, int dim)
378
+ {
379
+ return size_t(a.strides[dim]);
380
+ }
381
+
382
+ template <typename T>
383
+ CUDA_CALLABLE inline T* data_at_byte_offset(const array_t<T>& a, size_t byte_offset)
384
+ {
385
+ return reinterpret_cast<T*>(reinterpret_cast<char*>(a.data) + byte_offset);
386
+ }
387
+
388
+ template <typename T>
389
+ CUDA_CALLABLE inline T* grad_at_byte_offset(const array_t<T>& a, size_t byte_offset)
390
+ {
391
+ return reinterpret_cast<T*>(reinterpret_cast<char*>(a.grad) + byte_offset);
392
+ }
393
+
394
+ template <typename T>
395
+ CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i)
396
+ {
397
+ assert(i >= 0 && i < arr.shape[0]);
398
+
399
+ return i*stride(arr, 0);
400
+ }
401
+
402
+ template <typename T>
403
+ CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j)
404
+ {
405
+ // if (i < 0 || i >= arr.shape[0])
406
+ // printf("i: %d > arr.shape[0]: %d\n", i, arr.shape[0]);
407
+
408
+ // if (j < 0 || j >= arr.shape[1])
409
+ // printf("j: %d > arr.shape[1]: %d\n", j, arr.shape[1]);
410
+
411
+
412
+ assert(i >= 0 && i < arr.shape[0]);
413
+ assert(j >= 0 && j < arr.shape[1]);
414
+
415
+ return i*stride(arr, 0) + j*stride(arr, 1);
416
+ }
417
+
418
+ template <typename T>
419
+ CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j, int k)
420
+ {
421
+ assert(i >= 0 && i < arr.shape[0]);
422
+ assert(j >= 0 && j < arr.shape[1]);
423
+ assert(k >= 0 && k < arr.shape[2]);
424
+
425
+ return i*stride(arr, 0) + j*stride(arr, 1) + k*stride(arr, 2);
426
+ }
427
+
428
+ template <typename T>
429
+ CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j, int k, int l)
430
+ {
431
+ assert(i >= 0 && i < arr.shape[0]);
432
+ assert(j >= 0 && j < arr.shape[1]);
433
+ assert(k >= 0 && k < arr.shape[2]);
434
+ assert(l >= 0 && l < arr.shape[3]);
435
+
436
+ return i*stride(arr, 0) + j*stride(arr, 1) + k*stride(arr, 2) + l*stride(arr, 3);
437
+ }
438
+
439
+ template <typename T>
440
+ CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i)
441
+ {
442
+ assert(arr.ndim == 1);
443
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
444
+
445
+ if (i < 0)
446
+ {
447
+ i += arr.shape[0];
448
+ }
449
+
450
+ T& result = *data_at_byte_offset(arr, byte_offset(arr, i));
451
+ FP_VERIFY_FWD_1(result)
452
+
453
+ return result;
454
+ }
455
+
456
+ template <typename T>
457
+ CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j)
458
+ {
459
+ assert(arr.ndim == 2);
460
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
461
+ assert(j >= -arr.shape[1] && j < arr.shape[1]);
462
+
463
+ if (i < 0)
464
+ {
465
+ i += arr.shape[0];
466
+ }
467
+ if (j < 0)
468
+ {
469
+ j += arr.shape[1];
470
+ }
471
+
472
+ T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j));
473
+ FP_VERIFY_FWD_2(result)
474
+
475
+ return result;
476
+ }
477
+
478
+ template <typename T>
479
+ CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k)
480
+ {
481
+ assert(arr.ndim == 3);
482
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
483
+ assert(j >= -arr.shape[1] && j < arr.shape[1]);
484
+ assert(k >= -arr.shape[2] && k < arr.shape[2]);
485
+
486
+ if (i < 0)
487
+ {
488
+ i += arr.shape[0];
489
+ }
490
+ if (j < 0)
491
+ {
492
+ j += arr.shape[1];
493
+ }
494
+ if (k < 0)
495
+ {
496
+ k += arr.shape[2];
497
+ }
498
+
499
+ T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k));
500
+ FP_VERIFY_FWD_3(result)
501
+
502
+ return result;
503
+ }
504
+
505
+ template <typename T>
506
+ CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k, int l)
507
+ {
508
+ assert(arr.ndim == 4);
509
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
510
+ assert(j >= -arr.shape[1] && j < arr.shape[1]);
511
+ assert(k >= -arr.shape[2] && k < arr.shape[2]);
512
+ assert(l >= -arr.shape[3] && l < arr.shape[3]);
513
+
514
+ if (i < 0)
515
+ {
516
+ i += arr.shape[0];
517
+ }
518
+ if (j < 0)
519
+ {
520
+ j += arr.shape[1];
521
+ }
522
+ if (k < 0)
523
+ {
524
+ k += arr.shape[2];
525
+ }
526
+ if (l < 0)
527
+ {
528
+ l += arr.shape[3];
529
+ }
530
+
531
+ T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
532
+ FP_VERIFY_FWD_4(result)
533
+
534
+ return result;
535
+ }
536
+
537
+ template <typename T>
538
+ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i)
539
+ {
540
+ assert(arr.ndim == 1);
541
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
542
+
543
+ if (i < 0)
544
+ {
545
+ i += arr.shape[0];
546
+ }
547
+
548
+ T& result = *grad_at_byte_offset(arr, byte_offset(arr, i));
549
+ FP_VERIFY_FWD_1(result)
550
+
551
+ return result;
552
+ }
553
+
554
+ template <typename T>
555
+ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j)
556
+ {
557
+ assert(arr.ndim == 2);
558
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
559
+ assert(j >= -arr.shape[1] && j < arr.shape[1]);
560
+
561
+ if (i < 0)
562
+ {
563
+ i += arr.shape[0];
564
+ }
565
+ if (j < 0)
566
+ {
567
+ j += arr.shape[1];
568
+ }
569
+
570
+ T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j));
571
+ FP_VERIFY_FWD_2(result)
572
+
573
+ return result;
574
+ }
575
+
576
+ template <typename T>
577
+ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k)
578
+ {
579
+ assert(arr.ndim == 3);
580
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
581
+ assert(j >= -arr.shape[1] && j < arr.shape[1]);
582
+ assert(k >= -arr.shape[2] && k < arr.shape[2]);
583
+
584
+ if (i < 0)
585
+ {
586
+ i += arr.shape[0];
587
+ }
588
+ if (j < 0)
589
+ {
590
+ j += arr.shape[1];
591
+ }
592
+ if (k < 0)
593
+ {
594
+ k += arr.shape[2];
595
+ }
596
+
597
+ T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k));
598
+ FP_VERIFY_FWD_3(result)
599
+
600
+ return result;
601
+ }
602
+
603
+ template <typename T>
604
+ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k, int l)
605
+ {
606
+ assert(arr.ndim == 4);
607
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
608
+ assert(j >= -arr.shape[1] && j < arr.shape[1]);
609
+ assert(k >= -arr.shape[2] && k < arr.shape[2]);
610
+ assert(l >= -arr.shape[3] && l < arr.shape[3]);
611
+
612
+ if (i < 0)
613
+ {
614
+ i += arr.shape[0];
615
+ }
616
+ if (j < 0)
617
+ {
618
+ j += arr.shape[1];
619
+ }
620
+ if (k < 0)
621
+ {
622
+ k += arr.shape[2];
623
+ }
624
+ if (l < 0)
625
+ {
626
+ l += arr.shape[3];
627
+ }
628
+
629
+ T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
630
+ FP_VERIFY_FWD_4(result)
631
+
632
+ return result;
633
+ }
634
+
635
+
636
+ template <typename T>
637
+ CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i)
638
+ {
639
+ assert(iarr.arr.ndim == 1);
640
+ assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
641
+
642
+ if (i < 0)
643
+ {
644
+ i += iarr.shape[0];
645
+ }
646
+
647
+ if (iarr.indices[0])
648
+ {
649
+ i = iarr.indices[0][i];
650
+ assert(i >= 0 && i < iarr.arr.shape[0]);
651
+ }
652
+
653
+ T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i));
654
+ FP_VERIFY_FWD_1(result)
655
+
656
+ return result;
657
+ }
658
+
659
+ template <typename T>
660
+ CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j)
661
+ {
662
+ assert(iarr.arr.ndim == 2);
663
+ assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
664
+ assert(j >= -iarr.shape[1] && j < iarr.shape[1]);
665
+
666
+ if (i < 0)
667
+ {
668
+ i += iarr.shape[0];
669
+ }
670
+ if (j < 0)
671
+ {
672
+ j += iarr.shape[1];
673
+ }
674
+
675
+ if (iarr.indices[0])
676
+ {
677
+ i = iarr.indices[0][i];
678
+ assert(i >= 0 && i < iarr.arr.shape[0]);
679
+ }
680
+ if (iarr.indices[1])
681
+ {
682
+ j = iarr.indices[1][j];
683
+ assert(j >= 0 && j < iarr.arr.shape[1]);
684
+ }
685
+
686
+ T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j));
687
+ FP_VERIFY_FWD_1(result)
688
+
689
+ return result;
690
+ }
691
+
692
+ template <typename T>
693
+ CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k)
694
+ {
695
+ assert(iarr.arr.ndim == 3);
696
+ assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
697
+ assert(j >= -iarr.shape[1] && j < iarr.shape[1]);
698
+ assert(k >= -iarr.shape[2] && k < iarr.shape[2]);
699
+
700
+ if (i < 0)
701
+ {
702
+ i += iarr.shape[0];
703
+ }
704
+ if (j < 0)
705
+ {
706
+ j += iarr.shape[1];
707
+ }
708
+ if (k < 0)
709
+ {
710
+ k += iarr.shape[2];
711
+ }
712
+
713
+ if (iarr.indices[0])
714
+ {
715
+ i = iarr.indices[0][i];
716
+ assert(i >= 0 && i < iarr.arr.shape[0]);
717
+ }
718
+ if (iarr.indices[1])
719
+ {
720
+ j = iarr.indices[1][j];
721
+ assert(j >= 0 && j < iarr.arr.shape[1]);
722
+ }
723
+ if (iarr.indices[2])
724
+ {
725
+ k = iarr.indices[2][k];
726
+ assert(k >= 0 && k < iarr.arr.shape[2]);
727
+ }
728
+
729
+ T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j, k));
730
+ FP_VERIFY_FWD_1(result)
731
+
732
+ return result;
733
+ }
734
+
735
+ template <typename T>
736
+ CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k, int l)
737
+ {
738
+ assert(iarr.arr.ndim == 4);
739
+ assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
740
+ assert(j >= -iarr.shape[1] && j < iarr.shape[1]);
741
+ assert(k >= -iarr.shape[2] && k < iarr.shape[2]);
742
+ assert(l >= -iarr.shape[3] && l < iarr.shape[3]);
743
+
744
+ if (i < 0)
745
+ {
746
+ i += iarr.shape[0];
747
+ }
748
+ if (j < 0)
749
+ {
750
+ j += iarr.shape[1];
751
+ }
752
+ if (k < 0)
753
+ {
754
+ k += iarr.shape[2];
755
+ }
756
+ if (l < 0)
757
+ {
758
+ l += iarr.shape[3];
759
+ }
760
+
761
+ if (iarr.indices[0])
762
+ {
763
+ i = iarr.indices[0][i];
764
+ assert(i >= 0 && i < iarr.arr.shape[0]);
765
+ }
766
+ if (iarr.indices[1])
767
+ {
768
+ j = iarr.indices[1][j];
769
+ assert(j >= 0 && j < iarr.arr.shape[1]);
770
+ }
771
+ if (iarr.indices[2])
772
+ {
773
+ k = iarr.indices[2][k];
774
+ assert(k >= 0 && k < iarr.arr.shape[2]);
775
+ }
776
+ if (iarr.indices[3])
777
+ {
778
+ l = iarr.indices[3][l];
779
+ assert(l >= 0 && l < iarr.arr.shape[3]);
780
+ }
781
+
782
+ T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j, k, l));
783
+ FP_VERIFY_FWD_1(result)
784
+
785
+ return result;
786
+ }
787
+
788
+
789
+ template <typename T>
790
+ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i)
791
+ {
792
+ assert(src.ndim > 1);
793
+ assert(i >= -src.shape[0] && i < src.shape[0]);
794
+
795
+ if (i < 0)
796
+ {
797
+ i += src.shape[0];
798
+ }
799
+
800
+ array_t<T> a;
801
+ size_t offset = byte_offset(src, i);
802
+ a.data = data_at_byte_offset(src, offset);
803
+ if (src.grad)
804
+ a.grad = grad_at_byte_offset(src, offset);
805
+ a.shape[0] = src.shape[1];
806
+ a.shape[1] = src.shape[2];
807
+ a.shape[2] = src.shape[3];
808
+ a.strides[0] = src.strides[1];
809
+ a.strides[1] = src.strides[2];
810
+ a.strides[2] = src.strides[3];
811
+ a.ndim = src.ndim-1;
812
+
813
+ return a;
814
+ }
815
+
816
+ template <typename T>
817
+ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j)
818
+ {
819
+ assert(src.ndim > 2);
820
+ assert(i >= -src.shape[0] && i < src.shape[0]);
821
+ assert(j >= -src.shape[1] && j < src.shape[1]);
822
+
823
+ if (i < 0)
824
+ {
825
+ i += src.shape[0];
826
+ }
827
+ if (j < 0)
828
+ {
829
+ j += src.shape[1];
830
+ }
831
+
832
+ array_t<T> a;
833
+ size_t offset = byte_offset(src, i, j);
834
+ a.data = data_at_byte_offset(src, offset);
835
+ if (src.grad)
836
+ a.grad = grad_at_byte_offset(src, offset);
837
+ a.shape[0] = src.shape[2];
838
+ a.shape[1] = src.shape[3];
839
+ a.strides[0] = src.strides[2];
840
+ a.strides[1] = src.strides[3];
841
+ a.ndim = src.ndim-2;
842
+
843
+ return a;
844
+ }
845
+
846
+ template <typename T>
847
+ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j, int k)
848
+ {
849
+ assert(src.ndim > 3);
850
+ assert(i >= -src.shape[0] && i < src.shape[0]);
851
+ assert(j >= -src.shape[1] && j < src.shape[1]);
852
+ assert(k >= -src.shape[2] && k < src.shape[2]);
853
+
854
+ if (i < 0)
855
+ {
856
+ i += src.shape[0];
857
+ }
858
+ if (j < 0)
859
+ {
860
+ j += src.shape[1];
861
+ }
862
+ if (k < 0)
863
+ {
864
+ k += src.shape[2];
865
+ }
866
+
867
+ array_t<T> a;
868
+ size_t offset = byte_offset(src, i, j, k);
869
+ a.data = data_at_byte_offset(src, offset);
870
+ if (src.grad)
871
+ a.grad = grad_at_byte_offset(src, offset);
872
+ a.shape[0] = src.shape[3];
873
+ a.strides[0] = src.strides[3];
874
+ a.ndim = src.ndim-3;
875
+
876
+ return a;
877
+ }
878
+
879
+
880
+ template <typename T, size_t... Idxs>
881
+ size_t byte_offset_helper(
882
+ array_t<T>& src,
883
+ const slice_t (&slices)[sizeof...(Idxs)],
884
+ index_sequence<Idxs...>
885
+ )
886
+ {
887
+ return byte_offset(src, slices[Idxs].start...);
888
+ }
889
+
890
+
891
+ template <typename T, typename... Slices>
892
+ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, const Slices&... slice_args)
893
+ {
894
+ constexpr int N = sizeof...(Slices);
895
+ static_assert(N >= 1 && N <= 4, "view supports 1 to 4 slices");
896
+ assert(src.ndim >= N);
897
+
898
+ slice_t slices[N] = { slice_args... };
899
+ int slice_idxs[N];
900
+ int slice_count = 0;
901
+
902
+ for (int i = 0; i < N; ++i)
903
+ {
904
+ if (slices[i].step == 0)
905
+ {
906
+ // We have a slice representing an integer index.
907
+ if (slices[i].start < 0)
908
+ {
909
+ slices[i].start += src.shape[i];
910
+ }
911
+ }
912
+ else
913
+ {
914
+ slices[i] = slice_adjust_indices(slices[i], src.shape[i]);
915
+ slice_idxs[slice_count] = i;
916
+ ++slice_count;
917
+ }
918
+ }
919
+
920
+ size_t offset = byte_offset_helper(src, slices, make_index_sequence<N>{});
921
+
922
+ array_t<T> out;
923
+
924
+ out.data = data_at_byte_offset(src, offset);
925
+ if (src.grad)
926
+ {
927
+ out.grad = grad_at_byte_offset(src, offset);
928
+ }
929
+
930
+ int dim = 0;
931
+ for (; dim < slice_count; ++dim)
932
+ {
933
+ int idx = slice_idxs[dim];
934
+ out.shape[dim] = slice_get_length(slices[idx]);
935
+ out.strides[dim] = src.strides[idx] * slices[idx].step;
936
+ }
937
+ for (; dim < slice_count + 4 - N; ++dim)
938
+ {
939
+ out.shape[dim] = src.shape[dim - slice_count + N];
940
+ out.strides[dim] = src.strides[dim - slice_count + N];
941
+ }
942
+ for (; dim < 4; ++dim)
943
+ {
944
+ out.shape[dim] = 0;
945
+ out.strides[dim] = 0;
946
+ }
947
+
948
+ out.ndim = src.ndim + slice_count - N;
949
+ return out;
950
+ }
951
+
952
+ template <typename T>
953
+ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i)
954
+ {
955
+ assert(src.arr.ndim > 1);
956
+
957
+ if (src.indices[0])
958
+ {
959
+ assert(i >= -src.shape[0] && i < src.shape[0]);
960
+ if (i < 0)
961
+ {
962
+ i += src.shape[0];
963
+ }
964
+ i = src.indices[0][i];
965
+ }
966
+
967
+ indexedarray_t<T> a;
968
+ a.arr = view(src.arr, i);
969
+ a.indices[0] = src.indices[1];
970
+ a.indices[1] = src.indices[2];
971
+ a.indices[2] = src.indices[3];
972
+ a.shape[0] = src.shape[1];
973
+ a.shape[1] = src.shape[2];
974
+ a.shape[2] = src.shape[3];
975
+
976
+ return a;
977
+ }
978
+
979
+ template <typename T>
980
+ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j)
981
+ {
982
+ assert(src.arr.ndim > 2);
983
+
984
+ if (src.indices[0])
985
+ {
986
+ assert(i >= -src.shape[0] && i < src.shape[0]);
987
+ if (i < 0)
988
+ {
989
+ i += src.shape[0];
990
+ }
991
+ i = src.indices[0][i];
992
+ }
993
+ if (src.indices[1])
994
+ {
995
+ assert(j >= -src.shape[1] && j < src.shape[1]);
996
+ if (j < 0)
997
+ {
998
+ j += src.shape[1];
999
+ }
1000
+ j = src.indices[1][j];
1001
+ }
1002
+
1003
+ indexedarray_t<T> a;
1004
+ a.arr = view(src.arr, i, j);
1005
+ a.indices[0] = src.indices[2];
1006
+ a.indices[1] = src.indices[3];
1007
+ a.shape[0] = src.shape[2];
1008
+ a.shape[1] = src.shape[3];
1009
+
1010
+ return a;
1011
+ }
1012
+
1013
+ template <typename T>
1014
+ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j, int k)
1015
+ {
1016
+ assert(src.arr.ndim > 3);
1017
+
1018
+ if (src.indices[0])
1019
+ {
1020
+ assert(i >= -src.shape[0] && i < src.shape[0]);
1021
+ if (i < 0)
1022
+ {
1023
+ i += src.shape[0];
1024
+ }
1025
+ i = src.indices[0][i];
1026
+ }
1027
+ if (src.indices[1])
1028
+ {
1029
+ assert(j >= -src.shape[1] && j < src.shape[1]);
1030
+ if (j < 0)
1031
+ {
1032
+ j += src.shape[1];
1033
+ }
1034
+ j = src.indices[1][j];
1035
+ }
1036
+ if (src.indices[2])
1037
+ {
1038
+ assert(k >= -src.shape[2] && k < src.shape[2]);
1039
+ if (k < 0)
1040
+ {
1041
+ k += src.shape[2];
1042
+ }
1043
+ k = src.indices[2][k];
1044
+ }
1045
+
1046
+ indexedarray_t<T> a;
1047
+ a.arr = view(src.arr, i, j, k);
1048
+ a.indices[0] = src.indices[3];
1049
+ a.shape[0] = src.shape[3];
1050
+
1051
+ return a;
1052
+ }
1053
+
1054
+ template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
1055
+ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T>& adj_ret) {}
1056
+ template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
1057
+ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T>& adj_ret) {}
1058
+ template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
1059
+ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T>& adj_ret) {}
1060
+
1061
+ template <typename... Args>
1062
+ CUDA_CALLABLE inline void adj_view(Args&&...) { }
1063
+
1064
+ // TODO: lower_bound() for indexed arrays?
1065
+
1066
+ template <typename T>
1067
+ CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value)
1068
+ {
1069
+ assert(arr.ndim == 1);
1070
+
1071
+ int lower = arr_begin;
1072
+ int upper = arr_end - 1;
1073
+
1074
+ while(lower < upper)
1075
+ {
1076
+ int mid = lower + (upper - lower) / 2;
1077
+
1078
+ if (arr[mid] < value)
1079
+ {
1080
+ lower = mid + 1;
1081
+ }
1082
+ else
1083
+ {
1084
+ upper = mid;
1085
+ }
1086
+ }
1087
+
1088
+ return lower;
1089
+ }
1090
+
1091
+ template <typename T>
1092
+ CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, T value)
1093
+ {
1094
+ return lower_bound(arr, 0, arr.shape[0], value);
1095
+ }
1096
+
1097
+ template <typename T> inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, T value, array_t<T> adj_arr, T adj_value, int adj_ret) {}
1098
+ template <typename T> inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value, array_t<T> adj_arr, int adj_arr_begin, int adj_arr_end, T adj_value, int adj_ret) {}
1099
+
1100
+ template<template<typename> class A, typename T>
1101
+ inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, T value) { return atomic_add(&index(buf, i), value); }
1102
+ template<template<typename> class A, typename T>
1103
+ inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, T value) { return atomic_add(&index(buf, i, j), value); }
1104
+ template<template<typename> class A, typename T>
1105
+ inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, int k, T value) { return atomic_add(&index(buf, i, j, k), value); }
1106
+ template<template<typename> class A, typename T>
1107
+ inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_add(&index(buf, i, j, k, l), value); }
1108
+
1109
+ template<template<typename> class A, typename T>
1110
+ inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, T value) { return atomic_add(&index(buf, i), -value); }
1111
+ template<template<typename> class A, typename T>
1112
+ inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, T value) { return atomic_add(&index(buf, i, j), -value); }
1113
+ template<template<typename> class A, typename T>
1114
+ inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, int k, T value) { return atomic_add(&index(buf, i, j, k), -value); }
1115
+ template<template<typename> class A, typename T>
1116
+ inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_add(&index(buf, i, j, k, l), -value); }
1117
+
1118
+ template<template<typename> class A, typename T>
1119
+ inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, T value) { return atomic_min(&index(buf, i), value); }
1120
+ template<template<typename> class A, typename T>
1121
+ inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, T value) { return atomic_min(&index(buf, i, j), value); }
1122
+ template<template<typename> class A, typename T>
1123
+ inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, int k, T value) { return atomic_min(&index(buf, i, j, k), value); }
1124
+ template<template<typename> class A, typename T>
1125
+ inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_min(&index(buf, i, j, k, l), value); }
1126
+
1127
+ template<template<typename> class A, typename T>
1128
+ inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, T value) { return atomic_max(&index(buf, i), value); }
1129
+ template<template<typename> class A, typename T>
1130
+ inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, T value) { return atomic_max(&index(buf, i, j), value); }
1131
+ template<template<typename> class A, typename T>
1132
+ inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, T value) { return atomic_max(&index(buf, i, j, k), value); }
1133
+ template<template<typename> class A, typename T>
1134
+ inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_max(&index(buf, i, j, k, l), value); }
1135
+
1136
+ template<template<typename> class A, typename T>
1137
+ inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, T old_value, T new_value) { return atomic_cas(&index(buf, i), old_value, new_value); }
1138
+ template<template<typename> class A, typename T>
1139
+ inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, int j, T old_value, T new_value) { return atomic_cas(&index(buf, i, j), old_value, new_value); }
1140
+ template<template<typename> class A, typename T>
1141
+ inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, int j, int k, T old_value, T new_value) { return atomic_cas(&index(buf, i, j, k), old_value, new_value); }
1142
+ template<template<typename> class A, typename T>
1143
+ inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, int j, int k, int l, T old_value, T new_value) { return atomic_cas(&index(buf, i, j, k, l), old_value, new_value); }
1144
+
1145
+ template<template<typename> class A, typename T>
1146
+ inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, T value) { return atomic_exch(&index(buf, i), value); }
1147
+ template<template<typename> class A, typename T>
1148
+ inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, T value) { return atomic_exch(&index(buf, i, j), value); }
1149
+ template<template<typename> class A, typename T>
1150
+ inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, T value) { return atomic_exch(&index(buf, i, j, k), value); }
1151
+ template<template<typename> class A, typename T>
1152
+ inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_exch(&index(buf, i, j, k, l), value); }
1153
+
1154
+ template<template<typename> class A, typename T>
1155
+ inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, T value) { return atomic_and(&index(buf, i), value); }
1156
+ template<template<typename> class A, typename T>
1157
+ inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, int j, T value) { return atomic_and(&index(buf, i, j), value); }
1158
+ template<template<typename> class A, typename T>
1159
+ inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, int j, int k, T value) { return atomic_and(&index(buf, i, j, k), value); }
1160
+ template<template<typename> class A, typename T>
1161
+ inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_and(&index(buf, i, j, k, l), value); }
1162
+
1163
+ template<template<typename> class A, typename T>
1164
+ inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, T value) { return atomic_or(&index(buf, i), value); }
1165
+ template<template<typename> class A, typename T>
1166
+ inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, int j, T value) { return atomic_or(&index(buf, i, j), value); }
1167
+ template<template<typename> class A, typename T>
1168
+ inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, int j, int k, T value) { return atomic_or(&index(buf, i, j, k), value); }
1169
+ template<template<typename> class A, typename T>
1170
+ inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_or(&index(buf, i, j, k, l), value); }
1171
+
1172
+ template<template<typename> class A, typename T>
1173
+ inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, T value) { return atomic_xor(&index(buf, i), value); }
1174
+ template<template<typename> class A, typename T>
1175
+ inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, int j, T value) { return atomic_xor(&index(buf, i, j), value); }
1176
+ template<template<typename> class A, typename T>
1177
+ inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, int j, int k, T value) { return atomic_xor(&index(buf, i, j, k), value); }
1178
+ template<template<typename> class A, typename T>
1179
+ inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_xor(&index(buf, i, j, k, l), value); }
1180
+
1181
+ template<template<typename> class A, typename T>
1182
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i)
1183
+ {
1184
+ return &index(buf, i); // cppcheck-suppress returnDanglingLifetime
1185
+ }
1186
+ template<template<typename> class A, typename T>
1187
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j)
1188
+ {
1189
+ return &index(buf, i, j); // cppcheck-suppress returnDanglingLifetime
1190
+ }
1191
+ template<template<typename> class A, typename T>
1192
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k)
1193
+ {
1194
+ return &index(buf, i, j, k); // cppcheck-suppress returnDanglingLifetime
1195
+ }
1196
+ template<template<typename> class A, typename T>
1197
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k, int l)
1198
+ {
1199
+ return &index(buf, i, j, k, l); // cppcheck-suppress returnDanglingLifetime
1200
+ }
1201
+
1202
+ template<template<typename> class A, typename T>
1203
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, T value)
1204
+ {
1205
+ FP_VERIFY_FWD_1(value)
1206
+
1207
+ index(buf, i) = value;
1208
+ }
1209
+ template<template<typename> class A, typename T>
1210
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, T value)
1211
+ {
1212
+ FP_VERIFY_FWD_2(value)
1213
+
1214
+ index(buf, i, j) = value;
1215
+ }
1216
+ template<template<typename> class A, typename T>
1217
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, T value)
1218
+ {
1219
+ FP_VERIFY_FWD_3(value)
1220
+
1221
+ index(buf, i, j, k) = value;
1222
+ }
1223
+ template<template<typename> class A, typename T>
1224
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, int l, T value)
1225
+ {
1226
+ FP_VERIFY_FWD_4(value)
1227
+
1228
+ index(buf, i, j, k, l) = value;
1229
+ }
1230
+
1231
+ template<typename T>
1232
+ inline CUDA_CALLABLE void store(T* address, T value)
1233
+ {
1234
+ FP_VERIFY_FWD(value)
1235
+
1236
+ *address = value;
1237
+ }
1238
+
1239
+ template<typename T>
1240
+ inline CUDA_CALLABLE T load(T* address)
1241
+ {
1242
+ T value = *address;
1243
+ FP_VERIFY_FWD(value)
1244
+
1245
+ return value;
1246
+ }
1247
+
1248
+ // where() overload for array condition - returns a if array.data is non-null, otherwise returns b
1249
+ template <typename T1, typename T2>
1250
+ CUDA_CALLABLE inline T2 where(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?a:b; }
1251
+
1252
+ template <typename T1, typename T2>
1253
+ CUDA_CALLABLE inline void adj_where(const array_t<T1>& arr, const T2& a, const T2& b, const array_t<T1>& adj_cond, T2& adj_a, T2& adj_b, const T2& adj_ret)
1254
+ {
1255
+ if (arr.data)
1256
+ adj_a += adj_ret;
1257
+ else
1258
+ adj_b += adj_ret;
1259
+ }
1260
+
1261
+ // stub for the case where we have an nested array inside a struct and
1262
+ // atomic add the whole struct onto an array (e.g.: during backwards pass)
1263
+ template <typename T>
1264
+ CUDA_CALLABLE inline void atomic_add(array_t<T>*, array_t<T>) {}
1265
+
1266
+ // for float and vector types this is just an alias for an atomic add
1267
+ template <typename T>
1268
+ CUDA_CALLABLE inline void adj_atomic_add(T* buf, T value) { atomic_add(buf, value); }
1269
+
1270
+
1271
+ // for integral types we do not accumulate gradients
1272
+ CUDA_CALLABLE inline void adj_atomic_add(int8* buf, int8 value) { }
1273
+ CUDA_CALLABLE inline void adj_atomic_add(uint8* buf, uint8 value) { }
1274
+ CUDA_CALLABLE inline void adj_atomic_add(int16* buf, int16 value) { }
1275
+ CUDA_CALLABLE inline void adj_atomic_add(uint16* buf, uint16 value) { }
1276
+ CUDA_CALLABLE inline void adj_atomic_add(int32* buf, int32 value) { }
1277
+ CUDA_CALLABLE inline void adj_atomic_add(uint32* buf, uint32 value) { }
1278
+ CUDA_CALLABLE inline void adj_atomic_add(int64* buf, int64 value) { }
1279
+ CUDA_CALLABLE inline void adj_atomic_add(uint64* buf, uint64 value) { }
1280
+
1281
+ CUDA_CALLABLE inline void adj_atomic_add(bool* buf, bool value) { }
1282
+
1283
+ // only generate gradients for T types
1284
+ template<typename T>
1285
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, const array_t<T>& adj_buf, int adj_i, const T& adj_output)
1286
+ {
1287
+ if (adj_buf.data)
1288
+ adj_atomic_add(&index(adj_buf, i), adj_output);
1289
+ else if (buf.grad)
1290
+ adj_atomic_add(&index_grad(buf, i), adj_output);
1291
+ }
1292
+ template<typename T>
1293
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, const array_t<T>& adj_buf, int adj_i, int adj_j, const T& adj_output)
1294
+ {
1295
+ if (adj_buf.data)
1296
+ adj_atomic_add(&index(adj_buf, i, j), adj_output);
1297
+ else if (buf.grad)
1298
+ adj_atomic_add(&index_grad(buf, i, j), adj_output);
1299
+ }
1300
+ template<typename T>
1301
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, const T& adj_output)
1302
+ {
1303
+ if (adj_buf.data)
1304
+ adj_atomic_add(&index(adj_buf, i, j, k), adj_output);
1305
+ else if (buf.grad)
1306
+ adj_atomic_add(&index_grad(buf, i, j, k), adj_output);
1307
+ }
1308
+ template<typename T>
1309
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, int l, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, const T& adj_output)
1310
+ {
1311
+ if (adj_buf.data)
1312
+ adj_atomic_add(&index(adj_buf, i, j, k, l), adj_output);
1313
+ else if (buf.grad)
1314
+ adj_atomic_add(&index_grad(buf, i, j, k, l), adj_output);
1315
+ }
1316
+
1317
+ template<typename T>
1318
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int adj_i, T& adj_value)
1319
+ {
1320
+ if (adj_buf.data)
1321
+ adj_value += index(adj_buf, i);
1322
+ else if (buf.grad)
1323
+ adj_value += index_grad(buf, i);
1324
+
1325
+ FP_VERIFY_ADJ_1(value, adj_value)
1326
+ }
1327
+ template<typename T>
1328
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, T& adj_value)
1329
+ {
1330
+ if (adj_buf.data)
1331
+ adj_value += index(adj_buf, i, j);
1332
+ else if (buf.grad)
1333
+ adj_value += index_grad(buf, i, j);
1334
+
1335
+ FP_VERIFY_ADJ_2(value, adj_value)
1336
+ }
1337
+ template<typename T>
1338
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value)
1339
+ {
1340
+ if (adj_buf.data)
1341
+ adj_value += index(adj_buf, i, j, k);
1342
+ else if (buf.grad)
1343
+ adj_value += index_grad(buf, i, j, k);
1344
+
1345
+ FP_VERIFY_ADJ_3(value, adj_value)
1346
+ }
1347
+ template<typename T>
1348
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value)
1349
+ {
1350
+ if (adj_buf.data)
1351
+ adj_value += index(adj_buf, i, j, k, l);
1352
+ else if (buf.grad)
1353
+ adj_value += index_grad(buf, i, j, k, l);
1354
+
1355
+ FP_VERIFY_ADJ_4(value, adj_value)
1356
+ }
1357
+
1358
+ template<typename T>
1359
+ inline CUDA_CALLABLE void adj_store(const T* address, T value, const T& adj_address, T& adj_value)
1360
+ {
1361
+ // nop; generic store() operations are not differentiable, only array_store() is
1362
+ FP_VERIFY_ADJ(value, adj_value)
1363
+ }
1364
+
1365
+ template<typename T>
1366
+ inline CUDA_CALLABLE void adj_load(const T* address, const T& adj_address, T& adj_value)
1367
+ {
1368
+ // nop; generic load() operations are not differentiable
1369
+ }
1370
+
1371
+ template<typename T>
1372
+ inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret)
1373
+ {
1374
+ if (adj_buf.data)
1375
+ adj_value += index(adj_buf, i);
1376
+ else if (buf.grad)
1377
+ adj_value += index_grad(buf, i);
1378
+
1379
+ FP_VERIFY_ADJ_1(value, adj_value)
1380
+ }
1381
+ template<typename T>
1382
+ inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret)
1383
+ {
1384
+ if (adj_buf.data)
1385
+ adj_value += index(adj_buf, i, j);
1386
+ else if (buf.grad)
1387
+ adj_value += index_grad(buf, i, j);
1388
+
1389
+ FP_VERIFY_ADJ_2(value, adj_value)
1390
+ }
1391
+ template<typename T>
1392
+ inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret)
1393
+ {
1394
+ if (adj_buf.data)
1395
+ adj_value += index(adj_buf, i, j, k);
1396
+ else if (buf.grad)
1397
+ adj_value += index_grad(buf, i, j, k);
1398
+
1399
+ FP_VERIFY_ADJ_3(value, adj_value)
1400
+ }
1401
+ template<typename T>
1402
+ inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret)
1403
+ {
1404
+ if (adj_buf.data)
1405
+ adj_value += index(adj_buf, i, j, k, l);
1406
+ else if (buf.grad)
1407
+ adj_value += index_grad(buf, i, j, k, l);
1408
+
1409
+ FP_VERIFY_ADJ_4(value, adj_value)
1410
+ }
1411
+
1412
+ template<typename T>
1413
+ inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret)
1414
+ {
1415
+ if (adj_buf.data)
1416
+ adj_value -= index(adj_buf, i);
1417
+ else if (buf.grad)
1418
+ adj_value -= index_grad(buf, i);
1419
+
1420
+ FP_VERIFY_ADJ_1(value, adj_value)
1421
+ }
1422
+ template<typename T>
1423
+ inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret)
1424
+ {
1425
+ if (adj_buf.data)
1426
+ adj_value -= index(adj_buf, i, j);
1427
+ else if (buf.grad)
1428
+ adj_value -= index_grad(buf, i, j);
1429
+
1430
+ FP_VERIFY_ADJ_2(value, adj_value)
1431
+ }
1432
+ template<typename T>
1433
+ inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret)
1434
+ {
1435
+ if (adj_buf.data)
1436
+ adj_value -= index(adj_buf, i, j, k);
1437
+ else if (buf.grad)
1438
+ adj_value -= index_grad(buf, i, j, k);
1439
+
1440
+ FP_VERIFY_ADJ_3(value, adj_value)
1441
+ }
1442
+ template<typename T>
1443
+ inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret)
1444
+ {
1445
+ if (adj_buf.data)
1446
+ adj_value -= index(adj_buf, i, j, k, l);
1447
+ else if (buf.grad)
1448
+ adj_value -= index_grad(buf, i, j, k, l);
1449
+
1450
+ FP_VERIFY_ADJ_4(value, adj_value)
1451
+ }
1452
+
1453
+ // generic array types that do not support gradient computation (indexedarray, etc.)
1454
+ template<template<typename> class A1, template<typename> class A2, typename T>
1455
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, const A2<T>& adj_buf, int adj_i, const T& adj_output) {}
1456
+ template<template<typename> class A1, template<typename> class A2, typename T>
1457
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, const A2<T>& adj_buf, int adj_i, int adj_j, const T& adj_output) {}
1458
+ template<template<typename> class A1, template<typename> class A2, typename T>
1459
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, const T& adj_output) {}
1460
+ template<template<typename> class A1, template<typename> class A2, typename T>
1461
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, int l, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, const T& adj_output) {}
1462
+
1463
+ template<template<typename> class A1, template<typename> class A2, typename T>
1464
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value) {}
1465
+ template<template<typename> class A1, template<typename> class A2, typename T>
1466
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value) {}
1467
+ template<template<typename> class A1, template<typename> class A2, typename T>
1468
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value) {}
1469
+ template<template<typename> class A1, template<typename> class A2, typename T>
1470
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value) {}
1471
+
1472
+ template<template<typename> class A1, template<typename> class A2, typename T>
1473
+ inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
1474
+ template<template<typename> class A1, template<typename> class A2, typename T>
1475
+ inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
1476
+ template<template<typename> class A1, template<typename> class A2, typename T>
1477
+ inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
1478
+ template<template<typename> class A1, template<typename> class A2, typename T>
1479
+ inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
1480
+
1481
+ template<template<typename> class A1, template<typename> class A2, typename T>
1482
+ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
1483
+ template<template<typename> class A1, template<typename> class A2, typename T>
1484
+ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
1485
+ template<template<typename> class A1, template<typename> class A2, typename T>
1486
+ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
1487
+ template<template<typename> class A1, template<typename> class A2, typename T>
1488
+ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
1489
+
1490
+ // generic handler for scalar values
1491
+ template<template<typename> class A1, template<typename> class A2, typename T>
1492
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {
1493
+ if (adj_buf.data)
1494
+ adj_atomic_minmax(&index(buf, i), &index(adj_buf, i), value, adj_value);
1495
+ else if (buf.grad)
1496
+ adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
1497
+
1498
+ FP_VERIFY_ADJ_1(value, adj_value)
1499
+ }
1500
+ template<template<typename> class A1, template<typename> class A2, typename T>
1501
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {
1502
+ if (adj_buf.data)
1503
+ adj_atomic_minmax(&index(buf, i, j), &index(adj_buf, i, j), value, adj_value);
1504
+ else if (buf.grad)
1505
+ adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
1506
+
1507
+ FP_VERIFY_ADJ_2(value, adj_value)
1508
+ }
1509
+ template<template<typename> class A1, template<typename> class A2, typename T>
1510
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {
1511
+ if (adj_buf.data)
1512
+ adj_atomic_minmax(&index(buf, i, j, k), &index(adj_buf, i, j, k), value, adj_value);
1513
+ else if (buf.grad)
1514
+ adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
1515
+
1516
+ FP_VERIFY_ADJ_3(value, adj_value)
1517
+ }
1518
+ template<template<typename> class A1, template<typename> class A2, typename T>
1519
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {
1520
+ if (adj_buf.data)
1521
+ adj_atomic_minmax(&index(buf, i, j, k, l), &index(adj_buf, i, j, k, l), value, adj_value);
1522
+ else if (buf.grad)
1523
+ adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
1524
+
1525
+ FP_VERIFY_ADJ_4(value, adj_value)
1526
+ }
1527
+
1528
+ template<template<typename> class A1, template<typename> class A2, typename T>
1529
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {
1530
+ if (adj_buf.data)
1531
+ adj_atomic_minmax(&index(buf, i), &index(adj_buf, i), value, adj_value);
1532
+ else if (buf.grad)
1533
+ adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
1534
+
1535
+ FP_VERIFY_ADJ_1(value, adj_value)
1536
+ }
1537
+ template<template<typename> class A1, template<typename> class A2, typename T>
1538
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {
1539
+ if (adj_buf.data)
1540
+ adj_atomic_minmax(&index(buf, i, j), &index(adj_buf, i, j), value, adj_value);
1541
+ else if (buf.grad)
1542
+ adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
1543
+
1544
+ FP_VERIFY_ADJ_2(value, adj_value)
1545
+ }
1546
+ template<template<typename> class A1, template<typename> class A2, typename T>
1547
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {
1548
+ if (adj_buf.data)
1549
+ adj_atomic_minmax(&index(buf, i, j, k), &index(adj_buf, i, j, k), value, adj_value);
1550
+ else if (buf.grad)
1551
+ adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
1552
+
1553
+ FP_VERIFY_ADJ_3(value, adj_value)
1554
+ }
1555
+ template<template<typename> class A1, template<typename> class A2, typename T>
1556
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {
1557
+ if (adj_buf.data)
1558
+ adj_atomic_minmax(&index(buf, i, j, k, l), &index(adj_buf, i, j, k, l), value, adj_value);
1559
+ else if (buf.grad)
1560
+ adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
1561
+
1562
+ FP_VERIFY_ADJ_4(value, adj_value)
1563
+ }
1564
+
1565
+ template<template<typename> class A1, template<typename> class A2, typename T>
1566
+ inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, T compare, T value, const A2<T>& adj_buf, int adj_i, T& adj_compare, T& adj_value, const T& adj_ret) {
1567
+ if (adj_buf.data)
1568
+ adj_atomic_cas(&index(buf, i), compare, value, &index(adj_buf, i), adj_compare, adj_value, adj_ret);
1569
+ else if (buf.grad)
1570
+ adj_atomic_cas(&index(buf, i), compare, value, &index_grad(buf, i), adj_compare, adj_value, adj_ret);
1571
+
1572
+ FP_VERIFY_ADJ_1(value, adj_value)
1573
+ }
1574
+
1575
+ template<template<typename> class A1, template<typename> class A2, typename T>
1576
+ inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, int j, T compare, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_compare, T& adj_value, const T& adj_ret) {
1577
+ if (adj_buf.data)
1578
+ adj_atomic_cas(&index(buf, i, j), compare, value, &index(adj_buf, i, j), adj_compare, adj_value, adj_ret);
1579
+ else if (buf.grad)
1580
+ adj_atomic_cas(&index(buf, i, j), compare, value, &index_grad(buf, i, j), adj_compare, adj_value, adj_ret);
1581
+
1582
+ FP_VERIFY_ADJ_2(value, adj_value)
1583
+ }
1584
+
1585
+ template<template<typename> class A1, template<typename> class A2, typename T>
1586
+ inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, int j, int k, T compare, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_compare, T& adj_value, const T& adj_ret) {
1587
+ if (adj_buf.data)
1588
+ adj_atomic_cas(&index(buf, i, j, k), compare, value, &index(adj_buf, i, j, k), adj_compare, adj_value, adj_ret);
1589
+ else if (buf.grad)
1590
+ adj_atomic_cas(&index(buf, i, j, k), compare, value, &index_grad(buf, i, j, k), adj_compare, adj_value, adj_ret);
1591
+
1592
+ FP_VERIFY_ADJ_3(value, adj_value)
1593
+ }
1594
+
1595
+ template<template<typename> class A1, template<typename> class A2, typename T>
1596
+ inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, int j, int k, int l, T compare, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_compare, T& adj_value, const T& adj_ret) {
1597
+ if (adj_buf.data)
1598
+ adj_atomic_cas(&index(buf, i, j, k, l), compare, value, &index(adj_buf, i, j, k, l), adj_compare, adj_value, adj_ret);
1599
+ else if (buf.grad)
1600
+ adj_atomic_cas(&index(buf, i, j, k, l), compare, value, &index_grad(buf, i, j, k, l), adj_compare, adj_value, adj_ret);
1601
+
1602
+ FP_VERIFY_ADJ_4(value, adj_value)
1603
+ }
1604
+
1605
+ template<template<typename> class A1, template<typename> class A2, typename T>
1606
+ inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {
1607
+ if (adj_buf.data)
1608
+ adj_atomic_exch(&index(buf, i), value, &index(adj_buf, i), adj_value, adj_ret);
1609
+ else if (buf.grad)
1610
+ adj_atomic_exch(&index(buf, i), value, &index_grad(buf, i), adj_value, adj_ret);
1611
+
1612
+ FP_VERIFY_ADJ_1(value, adj_value)
1613
+ }
1614
+
1615
+ template<template<typename> class A1, template<typename> class A2, typename T>
1616
+ inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {
1617
+ if (adj_buf.data)
1618
+ adj_atomic_exch(&index(buf, i, j), value, &index(adj_buf, i, j), adj_value, adj_ret);
1619
+ else if (buf.grad)
1620
+ adj_atomic_exch(&index(buf, i, j), value, &index_grad(buf, i, j), adj_value, adj_ret);
1621
+
1622
+ FP_VERIFY_ADJ_2(value, adj_value)
1623
+ }
1624
+
1625
+ template<template<typename> class A1, template<typename> class A2, typename T>
1626
+ inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {
1627
+ if (adj_buf.data)
1628
+ adj_atomic_exch(&index(buf, i, j, k), value, &index(adj_buf, i, j, k), adj_value, adj_ret);
1629
+ else if (buf.grad)
1630
+ adj_atomic_exch(&index(buf, i, j, k), value, &index_grad(buf, i, j, k), adj_value, adj_ret);
1631
+
1632
+ FP_VERIFY_ADJ_3(value, adj_value)
1633
+ }
1634
+
1635
+ template<template<typename> class A1, template<typename> class A2, typename T>
1636
+ inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {
1637
+ if (adj_buf.data)
1638
+ adj_atomic_exch(&index(buf, i, j, k, l), value, &index(adj_buf, i, j, k, l), adj_value, adj_ret);
1639
+ else if (buf.grad)
1640
+ adj_atomic_exch(&index(buf, i, j, k, l), value, &index_grad(buf, i, j, k, l), adj_value, adj_ret);
1641
+
1642
+ FP_VERIFY_ADJ_4(value, adj_value)
1643
+ }
1644
+
1645
+ // for bitwise operations we do not accumulate gradients
1646
+ template<template<typename> class A1, template<typename> class A2, typename T>
1647
+ inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
1648
+ template<template<typename> class A1, template<typename> class A2, typename T>
1649
+ inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
1650
+ template<template<typename> class A1, template<typename> class A2, typename T>
1651
+ inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
1652
+ template<template<typename> class A1, template<typename> class A2, typename T>
1653
+ inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
1654
+
1655
+ template<template<typename> class A1, template<typename> class A2, typename T>
1656
+ inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
1657
+ template<template<typename> class A1, template<typename> class A2, typename T>
1658
+ inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
1659
+ template<template<typename> class A1, template<typename> class A2, typename T>
1660
+ inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
1661
+ template<template<typename> class A1, template<typename> class A2, typename T>
1662
+ inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
1663
+
1664
+ template<template<typename> class A1, template<typename> class A2, typename T>
1665
+ inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
1666
+ template<template<typename> class A1, template<typename> class A2, typename T>
1667
+ inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
1668
+ template<template<typename> class A1, template<typename> class A2, typename T>
1669
+ inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
1670
+ template<template<typename> class A1, template<typename> class A2, typename T>
1671
+ inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
1672
+
1673
+
1674
+ template<template<typename> class A, typename T>
1675
+ CUDA_CALLABLE inline int len(const A<T>& a)
1676
+ {
1677
+ return a.shape[0];
1678
+ }
1679
+
1680
+ template<template<typename> class A, typename T>
1681
+ CUDA_CALLABLE inline void adj_len(const A<T>& a, A<T>& adj_a, int& adj_ret)
1682
+ {
1683
+ }
1684
+
1685
+ } // namespace wp
1686
+
1687
+ #include "fabric.h"