warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
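
The listing covers the Python front end (warp/context.py, warp/types.py, warp/builtins.py), the examples and test suites, and the native CUDA/C++ runtime under warp/native/. For orientation, a minimal usage sketch against the public Warp API (kernel definition, array creation, launch) might look like the following; the kernel and values are illustrative only and are not taken from the package. The first file shown in full below, warp/native/array.h, is part of the native runtime that backs wp.array.

    import warp as wp

    wp.init()  # explicit init; recent Warp versions also initialize lazily

    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        i = wp.tid()           # one thread per array element
        a[i] = a[i] * s

    a = wp.array([1.0, 2.0, 3.0], dtype=float)
    wp.launch(scale, dim=a.shape[0], inputs=[a, 2.0])
    print(a.numpy())           # expected: [2. 4. 6.]
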
warp/native/array.h ADDED
@@ -0,0 +1,1145 @@
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ #pragma once
+
+ #include "builtin.h"
+
+ namespace wp
+ {
+
+ #if FP_CHECK
+
+ #define FP_ASSERT_FWD(value) \
+ print(value); \
+ printf(")\n"); \
+ assert(0); \
+
+ #define FP_ASSERT_ADJ(value, adj_value) \
+ print(value); \
+ printf(", "); \
+ print(adj_value); \
+ printf(")\n"); \
+ assert(0); \
+
+ #define FP_VERIFY_FWD(value) \
+ if (!isfinite(value)) { \
+ printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
+ FP_ASSERT_FWD(value) \
+ } \
+
+ #define FP_VERIFY_FWD_1(value) \
+ if (!isfinite(value)) { \
+ printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
+ FP_ASSERT_FWD(value) \
+ } \
+
+ #define FP_VERIFY_FWD_2(value) \
+ if (!isfinite(value)) { \
+ printf("%s:%d - %s(arr, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j); \
+ FP_ASSERT_FWD(value) \
+ } \
+
+ #define FP_VERIFY_FWD_3(value) \
+ if (!isfinite(value)) { \
+ printf("%s:%d - %s(arr, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k); \
+ FP_ASSERT_FWD(value) \
+ } \
+
+ #define FP_VERIFY_FWD_4(value) \
+ if (!isfinite(value)) { \
+ printf("%s:%d - %s(arr, %d, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k, l); \
+ FP_ASSERT_FWD(value) \
+ } \
+
+ #define FP_VERIFY_ADJ(value, adj_value) \
+ if (!isfinite(value) || !isfinite(adj_value)) \
+ { \
+ printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
+ FP_ASSERT_ADJ(value, adj_value); \
+ } \
+
+ #define FP_VERIFY_ADJ_1(value, adj_value) \
+ if (!isfinite(value) || !isfinite(adj_value)) \
+ { \
+ printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
+ FP_ASSERT_ADJ(value, adj_value); \
+ } \
+
+ #define FP_VERIFY_ADJ_2(value, adj_value) \
+ if (!isfinite(value) || !isfinite(adj_value)) \
+ { \
+ printf("%s:%d - %s(arr, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j); \
+ FP_ASSERT_ADJ(value, adj_value); \
+ } \
+
+ #define FP_VERIFY_ADJ_3(value, adj_value) \
+ if (!isfinite(value) || !isfinite(adj_value)) \
+ { \
+ printf("%s:%d - %s(arr, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k); \
+ FP_ASSERT_ADJ(value, adj_value); \
+ } \
+
+ #define FP_VERIFY_ADJ_4(value, adj_value) \
+ if (!isfinite(value) || !isfinite(adj_value)) \
+ { \
+ printf("%s:%d - %s(arr, %d, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k, l); \
+ FP_ASSERT_ADJ(value, adj_value); \
+ } \
+
+
+ #else
+
+ #define FP_VERIFY_FWD(value) {}
+ #define FP_VERIFY_FWD_1(value) {}
+ #define FP_VERIFY_FWD_2(value) {}
+ #define FP_VERIFY_FWD_3(value) {}
+ #define FP_VERIFY_FWD_4(value) {}
+
+ #define FP_VERIFY_ADJ(value, adj_value) {}
+ #define FP_VERIFY_ADJ_1(value, adj_value) {}
+ #define FP_VERIFY_ADJ_2(value, adj_value) {}
+ #define FP_VERIFY_ADJ_3(value, adj_value) {}
+ #define FP_VERIFY_ADJ_4(value, adj_value) {}
+
+ #endif // WP_FP_CHECK
+
+ const int ARRAY_MAX_DIMS = 4; // must match constant in types.py
+
+ // must match constants in types.py
+ const int ARRAY_TYPE_REGULAR = 0;
+ const int ARRAY_TYPE_INDEXED = 1;
+ const int ARRAY_TYPE_FABRIC = 2;
+ const int ARRAY_TYPE_FABRIC_INDEXED = 3;
+
+ struct shape_t
+ {
+ int dims[ARRAY_MAX_DIMS];
+
+ CUDA_CALLABLE inline shape_t()
+ : dims()
+ {}
+
+ CUDA_CALLABLE inline int operator[](int i) const
+ {
+ assert(i < ARRAY_MAX_DIMS);
+ return dims[i];
+ }
+
+ CUDA_CALLABLE inline int& operator[](int i)
+ {
+ assert(i < ARRAY_MAX_DIMS);
+ return dims[i];
+ }
+ };
+
+ CUDA_CALLABLE inline int extract(const shape_t& s, int i)
+ {
+ return s.dims[i];
+ }
+
+ CUDA_CALLABLE inline void adj_extract(const shape_t& s, int i, const shape_t& adj_s, int adj_i, int adj_ret) {}
+
+ inline CUDA_CALLABLE void print(shape_t s)
+ {
+ // todo: only print valid dims, currently shape has a fixed size
+ // but we don't know how many dims are valid (e.g.: 1d, 2d, etc)
+ // should probably store ndim with shape
+ printf("(%d, %d, %d, %d)\n", s.dims[0], s.dims[1], s.dims[2], s.dims[3]);
+ }
+ inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& shape_t) {}
+
+
+ template <typename T>
+ struct array_t
+ {
+ CUDA_CALLABLE inline array_t()
+ : data(nullptr),
+ grad(nullptr),
+ shape(),
+ strides(),
+ ndim(0)
+ {}
+
+ CUDA_CALLABLE array_t(T* data, int size, T* grad=nullptr) : data(data), grad(grad) {
+ // constructor for 1d array
+ shape.dims[0] = size;
+ shape.dims[1] = 0;
+ shape.dims[2] = 0;
+ shape.dims[3] = 0;
+ ndim = 1;
+ strides[0] = sizeof(T);
+ strides[1] = 0;
+ strides[2] = 0;
+ strides[3] = 0;
+ }
+ CUDA_CALLABLE array_t(T* data, int dim0, int dim1, T* grad=nullptr) : data(data), grad(grad) {
+ // constructor for 2d array
+ shape.dims[0] = dim0;
+ shape.dims[1] = dim1;
+ shape.dims[2] = 0;
+ shape.dims[3] = 0;
+ ndim = 2;
+ strides[0] = dim1 * sizeof(T);
+ strides[1] = sizeof(T);
+ strides[2] = 0;
+ strides[3] = 0;
+ }
+ CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, T* grad=nullptr) : data(data), grad(grad) {
+ // constructor for 3d array
+ shape.dims[0] = dim0;
+ shape.dims[1] = dim1;
+ shape.dims[2] = dim2;
+ shape.dims[3] = 0;
+ ndim = 3;
+ strides[0] = dim1 * dim2 * sizeof(T);
+ strides[1] = dim2 * sizeof(T);
+ strides[2] = sizeof(T);
+ strides[3] = 0;
+ }
+ CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, int dim3, T* grad=nullptr) : data(data), grad(grad) {
+ // constructor for 4d array
+ shape.dims[0] = dim0;
+ shape.dims[1] = dim1;
+ shape.dims[2] = dim2;
+ shape.dims[3] = dim3;
+ ndim = 4;
+ strides[0] = dim1 * dim2 * dim3 * sizeof(T);
+ strides[1] = dim2 * dim3 * sizeof(T);
+ strides[2] = dim3 * sizeof(T);
+ strides[3] = sizeof(T);
+ }
+
+ CUDA_CALLABLE array_t(uint64 data, int size, uint64 grad=0)
+ : array_t((T*)(data), size, (T*)(grad))
+ {}
+
+ CUDA_CALLABLE array_t(uint64 data, int dim0, int dim1, uint64 grad=0)
+ : array_t((T*)(data), dim0, dim1, (T*)(grad))
+ {}
+
+ CUDA_CALLABLE array_t(uint64 data, int dim0, int dim1, int dim2, uint64 grad=0)
+ : array_t((T*)(data), dim0, dim1, dim2, (T*)(grad))
+ {}
+
+ CUDA_CALLABLE array_t(uint64 data, int dim0, int dim1, int dim2, int dim3, uint64 grad=0)
+ : array_t((T*)(data), dim0, dim1, dim2, dim3, (T*)(grad))
+ {}
+
+ CUDA_CALLABLE inline bool empty() const { return !data; }
+
+ T* data;
+ T* grad;
+ shape_t shape;
+ int strides[ARRAY_MAX_DIMS];
+ int ndim;
+
+ CUDA_CALLABLE inline operator T*() const { return data; }
+ };
+
+
+ // TODO:
+ // - templated index type?
+ // - templated dimensionality? (also for array_t to save space when passing arrays to kernels)
+ template <typename T>
+ struct indexedarray_t
+ {
+ CUDA_CALLABLE inline indexedarray_t()
+ : arr(),
+ indices(),
+ shape()
+ {}
+
+ CUDA_CALLABLE inline bool empty() const { return !arr.data; }
+
+ array_t<T> arr;
+ int* indices[ARRAY_MAX_DIMS]; // index array per dimension (can be NULL)
+ shape_t shape; // element count per dimension (num. indices if indexed, array dim if not)
+ };
+
+
+ // return stride (in bytes) of the given index
+ template <typename T>
+ CUDA_CALLABLE inline size_t stride(const array_t<T>& a, int dim)
+ {
+ return size_t(a.strides[dim]);
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T* data_at_byte_offset(const array_t<T>& a, size_t byte_offset)
+ {
+ return reinterpret_cast<T*>(reinterpret_cast<char*>(a.data) + byte_offset);
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T* grad_at_byte_offset(const array_t<T>& a, size_t byte_offset)
+ {
+ return reinterpret_cast<T*>(reinterpret_cast<char*>(a.grad) + byte_offset);
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i)
+ {
+ assert(i >= 0 && i < arr.shape[0]);
+
+ return i*stride(arr, 0);
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j)
+ {
+ // if (i < 0 || i >= arr.shape[0])
+ // printf("i: %d > arr.shape[0]: %d\n", i, arr.shape[0]);
+
+ // if (j < 0 || j >= arr.shape[1])
+ // printf("j: %d > arr.shape[1]: %d\n", j, arr.shape[1]);
+
+
+ assert(i >= 0 && i < arr.shape[0]);
+ assert(j >= 0 && j < arr.shape[1]);
+
+ return i*stride(arr, 0) + j*stride(arr, 1);
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j, int k)
+ {
+ assert(i >= 0 && i < arr.shape[0]);
+ assert(j >= 0 && j < arr.shape[1]);
+ assert(k >= 0 && k < arr.shape[2]);
+
+ return i*stride(arr, 0) + j*stride(arr, 1) + k*stride(arr, 2);
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j, int k, int l)
+ {
+ assert(i >= 0 && i < arr.shape[0]);
+ assert(j >= 0 && j < arr.shape[1]);
+ assert(k >= 0 && k < arr.shape[2]);
+ assert(l >= 0 && l < arr.shape[3]);
+
+ return i*stride(arr, 0) + j*stride(arr, 1) + k*stride(arr, 2) + l*stride(arr, 3);
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i)
+ {
+ assert(arr.ndim == 1);
+ T& result = *data_at_byte_offset(arr, byte_offset(arr, i));
+ FP_VERIFY_FWD_1(result)
+
+ return result;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j)
+ {
+ assert(arr.ndim == 2);
+ T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j));
+ FP_VERIFY_FWD_2(result)
+
+ return result;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k)
+ {
+ assert(arr.ndim == 3);
+ T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k));
+ FP_VERIFY_FWD_3(result)
+
+ return result;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k, int l)
+ {
+ assert(arr.ndim == 4);
+ T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
+ FP_VERIFY_FWD_4(result)
+
+ return result;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i)
+ {
+ T& result = *grad_at_byte_offset(arr, byte_offset(arr, i));
+ FP_VERIFY_FWD_1(result)
+
+ return result;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j)
+ {
+ T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j));
+ FP_VERIFY_FWD_2(result)
+
+ return result;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k)
+ {
+ T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k));
+ FP_VERIFY_FWD_3(result)
+
+ return result;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k, int l)
+ {
+ T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
+ FP_VERIFY_FWD_4(result)
+
+ return result;
+ }
+
+
+ template <typename T>
+ CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i)
+ {
+ assert(iarr.arr.ndim == 1);
+ assert(i >= 0 && i < iarr.shape[0]);
+
+ if (iarr.indices[0])
+ {
+ i = iarr.indices[0][i];
+ assert(i >= 0 && i < iarr.arr.shape[0]);
+ }
+
+ T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i));
+ FP_VERIFY_FWD_1(result)
+
+ return result;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j)
+ {
+ assert(iarr.arr.ndim == 2);
+ assert(i >= 0 && i < iarr.shape[0]);
+ assert(j >= 0 && j < iarr.shape[1]);
+
+ if (iarr.indices[0])
+ {
+ i = iarr.indices[0][i];
+ assert(i >= 0 && i < iarr.arr.shape[0]);
+ }
+ if (iarr.indices[1])
+ {
+ j = iarr.indices[1][j];
+ assert(j >= 0 && j < iarr.arr.shape[1]);
+ }
+
+ T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j));
+ FP_VERIFY_FWD_1(result)
+
+ return result;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k)
+ {
+ assert(iarr.arr.ndim == 3);
+ assert(i >= 0 && i < iarr.shape[0]);
+ assert(j >= 0 && j < iarr.shape[1]);
+ assert(k >= 0 && k < iarr.shape[2]);
+
+ if (iarr.indices[0])
+ {
+ i = iarr.indices[0][i];
+ assert(i >= 0 && i < iarr.arr.shape[0]);
+ }
+ if (iarr.indices[1])
+ {
+ j = iarr.indices[1][j];
+ assert(j >= 0 && j < iarr.arr.shape[1]);
+ }
+ if (iarr.indices[2])
+ {
+ k = iarr.indices[2][k];
+ assert(k >= 0 && k < iarr.arr.shape[2]);
+ }
+
+ T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j, k));
+ FP_VERIFY_FWD_1(result)
+
+ return result;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k, int l)
+ {
+ assert(iarr.arr.ndim == 4);
+ assert(i >= 0 && i < iarr.shape[0]);
+ assert(j >= 0 && j < iarr.shape[1]);
+ assert(k >= 0 && k < iarr.shape[2]);
+ assert(l >= 0 && l < iarr.shape[3]);
+
+ if (iarr.indices[0])
+ {
+ i = iarr.indices[0][i];
+ assert(i >= 0 && i < iarr.arr.shape[0]);
+ }
+ if (iarr.indices[1])
+ {
+ j = iarr.indices[1][j];
+ assert(j >= 0 && j < iarr.arr.shape[1]);
+ }
+ if (iarr.indices[2])
+ {
+ k = iarr.indices[2][k];
+ assert(k >= 0 && k < iarr.arr.shape[2]);
+ }
+ if (iarr.indices[3])
+ {
+ l = iarr.indices[3][l];
+ assert(l >= 0 && l < iarr.arr.shape[3]);
+ }
+
+ T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j, k, l));
+ FP_VERIFY_FWD_1(result)
+
+ return result;
+ }
+
+
+ template <typename T>
+ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i)
+ {
+ assert(src.ndim > 1);
+ assert(i >= 0 && i < src.shape[0]);
+
+ array_t<T> a;
+ size_t offset = byte_offset(src, i);
+ a.data = data_at_byte_offset(src, offset);
+ if (src.grad)
+ a.grad = grad_at_byte_offset(src, offset);
+ a.shape[0] = src.shape[1];
+ a.shape[1] = src.shape[2];
+ a.shape[2] = src.shape[3];
+ a.strides[0] = src.strides[1];
+ a.strides[1] = src.strides[2];
+ a.strides[2] = src.strides[3];
+ a.ndim = src.ndim-1;
+
+ return a;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j)
+ {
+ assert(src.ndim > 2);
+ assert(i >= 0 && i < src.shape[0]);
+ assert(j >= 0 && j < src.shape[1]);
+
+ array_t<T> a;
+ size_t offset = byte_offset(src, i, j);
+ a.data = data_at_byte_offset(src, offset);
+ if (src.grad)
+ a.grad = grad_at_byte_offset(src, offset);
+ a.shape[0] = src.shape[2];
+ a.shape[1] = src.shape[3];
+ a.strides[0] = src.strides[2];
+ a.strides[1] = src.strides[3];
+ a.ndim = src.ndim-2;
+
+ return a;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j, int k)
+ {
+ assert(src.ndim > 3);
+ assert(i >= 0 && i < src.shape[0]);
+ assert(j >= 0 && j < src.shape[1]);
+ assert(k >= 0 && k < src.shape[2]);
+
+ array_t<T> a;
+ size_t offset = byte_offset(src, i, j, k);
+ a.data = data_at_byte_offset(src, offset);
+ if (src.grad)
+ a.grad = grad_at_byte_offset(src, offset);
+ a.shape[0] = src.shape[3];
+ a.strides[0] = src.strides[3];
+ a.ndim = src.ndim-3;
+
+ return a;
+ }
+
+
+ template <typename T>
+ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i)
+ {
+ assert(src.arr.ndim > 1);
+
+ if (src.indices[0])
+ {
+ assert(i >= 0 && i < src.shape[0]);
+ i = src.indices[0][i];
+ }
+
+ indexedarray_t<T> a;
+ a.arr = view(src.arr, i);
+ a.indices[0] = src.indices[1];
+ a.indices[1] = src.indices[2];
+ a.indices[2] = src.indices[3];
+ a.shape[0] = src.shape[1];
+ a.shape[1] = src.shape[2];
+ a.shape[2] = src.shape[3];
+
+ return a;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j)
+ {
+ assert(src.arr.ndim > 2);
+
+ if (src.indices[0])
+ {
+ assert(i >= 0 && i < src.shape[0]);
+ i = src.indices[0][i];
+ }
+ if (src.indices[1])
+ {
+ assert(j >= 0 && j < src.shape[1]);
+ j = src.indices[1][j];
+ }
+
+ indexedarray_t<T> a;
+ a.arr = view(src.arr, i, j);
+ a.indices[0] = src.indices[2];
+ a.indices[1] = src.indices[3];
+ a.shape[0] = src.shape[2];
+ a.shape[1] = src.shape[3];
+
+ return a;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j, int k)
+ {
+ assert(src.arr.ndim > 3);
+
+ if (src.indices[0])
+ {
+ assert(i >= 0 && i < src.shape[0]);
+ i = src.indices[0][i];
+ }
+ if (src.indices[1])
+ {
+ assert(j >= 0 && j < src.shape[1]);
+ j = src.indices[1][j];
+ }
+ if (src.indices[2])
+ {
+ assert(k >= 0 && k < src.shape[2]);
+ k = src.indices[2][k];
+ }
+
+ indexedarray_t<T> a;
+ a.arr = view(src.arr, i, j, k);
+ a.indices[0] = src.indices[3];
+ a.shape[0] = src.shape[3];
+
+ return a;
+ }
+
+ template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
+ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T> adj_ret) {}
+ template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
+ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T> adj_ret) {}
+ template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
+ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T> adj_ret) {}
+
+ // TODO: lower_bound() for indexed arrays?
+
+ template <typename T>
+ CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value)
+ {
+ assert(arr.ndim == 1);
+
+ int lower = arr_begin;
+ int upper = arr_end - 1;
+
+ while(lower < upper)
+ {
+ int mid = lower + (upper - lower) / 2;
+
+ if (arr[mid] < value)
+ {
+ lower = mid + 1;
+ }
+ else
+ {
+ upper = mid;
+ }
+ }
+
+ return lower;
+ }
+
+ template <typename T>
+ CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, T value)
+ {
+ return lower_bound(arr, 0, arr.shape[0], value);
+ }
+
+ template <typename T> inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, T value, array_t<T> adj_arr, T adj_value, int adj_ret) {}
+ template <typename T> inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value, array_t<T> adj_arr, int adj_arr_begin, int adj_arr_end, T adj_value, int adj_ret) {}
709
+
710
+ template<template<typename> class A, typename T>
711
+ inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, T value) { return atomic_add(&index(buf, i), value); }
712
+ template<template<typename> class A, typename T>
713
+ inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, T value) { return atomic_add(&index(buf, i, j), value); }
714
+ template<template<typename> class A, typename T>
715
+ inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, int k, T value) { return atomic_add(&index(buf, i, j, k), value); }
716
+ template<template<typename> class A, typename T>
717
+ inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_add(&index(buf, i, j, k, l), value); }
718
+
719
+ template<template<typename> class A, typename T>
720
+ inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, T value) { return atomic_add(&index(buf, i), -value); }
721
+ template<template<typename> class A, typename T>
722
+ inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, T value) { return atomic_add(&index(buf, i, j), -value); }
723
+ template<template<typename> class A, typename T>
724
+ inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, int k, T value) { return atomic_add(&index(buf, i, j, k), -value); }
725
+ template<template<typename> class A, typename T>
726
+ inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_add(&index(buf, i, j, k, l), -value); }
727
+
728
+ template<template<typename> class A, typename T>
729
+ inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, T value) { return atomic_min(&index(buf, i), value); }
730
+ template<template<typename> class A, typename T>
731
+ inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, T value) { return atomic_min(&index(buf, i, j), value); }
732
+ template<template<typename> class A, typename T>
733
+ inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, int k, T value) { return atomic_min(&index(buf, i, j, k), value); }
734
+ template<template<typename> class A, typename T>
735
+ inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_min(&index(buf, i, j, k, l), value); }
736
+
737
+ template<template<typename> class A, typename T>
738
+ inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, T value) { return atomic_max(&index(buf, i), value); }
739
+ template<template<typename> class A, typename T>
740
+ inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, T value) { return atomic_max(&index(buf, i, j), value); }
741
+ template<template<typename> class A, typename T>
742
+ inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, T value) { return atomic_max(&index(buf, i, j, k), value); }
743
+ template<template<typename> class A, typename T>
744
+ inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_max(&index(buf, i, j, k, l), value); }
745
+
746
+ template<template<typename> class A, typename T>
747
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i) { return &index(buf, i); }
748
+ template<template<typename> class A, typename T>
749
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j) { return &index(buf, i, j); }
750
+ template<template<typename> class A, typename T>
751
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k) { return &index(buf, i, j, k); }
752
+ template<template<typename> class A, typename T>
753
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k, int l) { return &index(buf, i, j, k, l); }
754
+
755
+ template<template<typename> class A, typename T>
756
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, T value)
757
+ {
758
+ FP_VERIFY_FWD_1(value)
759
+
760
+ index(buf, i) = value;
761
+ }
762
+ template<template<typename> class A, typename T>
763
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, T value)
764
+ {
765
+ FP_VERIFY_FWD_2(value)
766
+
767
+ index(buf, i, j) = value;
768
+ }
769
+ template<template<typename> class A, typename T>
770
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, T value)
771
+ {
772
+ FP_VERIFY_FWD_3(value)
773
+
774
+ index(buf, i, j, k) = value;
775
+ }
776
+ template<template<typename> class A, typename T>
777
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, int l, T value)
778
+ {
779
+ FP_VERIFY_FWD_4(value)
780
+
781
+ index(buf, i, j, k, l) = value;
782
+ }
783
+
784
+ template<typename T>
785
+ inline CUDA_CALLABLE void store(T* address, T value)
786
+ {
787
+ FP_VERIFY_FWD(value)
788
+
789
+ *address = value;
790
+ }
791
+
792
+ template<typename T>
793
+ inline CUDA_CALLABLE T load(T* address)
794
+ {
795
+ T value = *address;
796
+ FP_VERIFY_FWD(value)
797
+
798
+ return value;
799
+ }
800
+
801
+ // select operator to check for array being null
802
+ template <typename T1, typename T2>
803
+ CUDA_CALLABLE inline T2 select(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?b:a; }
804
+
805
+ template <typename T1, typename T2>
806
+ CUDA_CALLABLE inline void adj_select(const array_t<T1>& arr, const T2& a, const T2& b, const array_t<T1>& adj_cond, T2& adj_a, T2& adj_b, const T2& adj_ret)
807
+ {
808
+ if (arr.data)
809
+ adj_b += adj_ret;
810
+ else
811
+ adj_a += adj_ret;
812
+ }
813
+
814
+ // where operator to check for array being null, opposite convention compared to select
815
+ template <typename T1, typename T2>
816
+ CUDA_CALLABLE inline T2 where(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?a:b; }
817
+
818
+ template <typename T1, typename T2>
819
+ CUDA_CALLABLE inline void adj_where(const array_t<T1>& arr, const T2& a, const T2& b, const array_t<T1>& adj_cond, T2& adj_a, T2& adj_b, const T2& adj_ret)
820
+ {
821
+ if (arr.data)
822
+ adj_a += adj_ret;
823
+ else
824
+ adj_b += adj_ret;
825
+ }
826
+
827
+ // stub for the case where we have an nested array inside a struct and
828
+ // atomic add the whole struct onto an array (e.g.: during backwards pass)
829
+ template <typename T>
830
+ CUDA_CALLABLE inline void atomic_add(array_t<T>*, array_t<T>) {}
831
+
832
+ // for float and vector types this is just an alias for an atomic add
833
+ template <typename T>
834
+ CUDA_CALLABLE inline void adj_atomic_add(T* buf, T value) { atomic_add(buf, value); }
835
+
836
+
837
+ // for integral types we do not accumulate gradients
838
+ CUDA_CALLABLE inline void adj_atomic_add(int8* buf, int8 value) { }
839
+ CUDA_CALLABLE inline void adj_atomic_add(uint8* buf, uint8 value) { }
840
+ CUDA_CALLABLE inline void adj_atomic_add(int16* buf, int16 value) { }
841
+ CUDA_CALLABLE inline void adj_atomic_add(uint16* buf, uint16 value) { }
842
+ CUDA_CALLABLE inline void adj_atomic_add(int32* buf, int32 value) { }
843
+ CUDA_CALLABLE inline void adj_atomic_add(uint32* buf, uint32 value) { }
844
+ CUDA_CALLABLE inline void adj_atomic_add(int64* buf, int64 value) { }
845
+ CUDA_CALLABLE inline void adj_atomic_add(uint64* buf, uint64 value) { }
846
+
847
+ CUDA_CALLABLE inline void adj_atomic_add(bool* buf, bool value) { }
848
+
849
+ // only generate gradients for T types
850
+ template<typename T>
851
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, const array_t<T>& adj_buf, int adj_i, const T& adj_output)
852
+ {
853
+ if (adj_buf.data)
854
+ adj_atomic_add(&index(adj_buf, i), adj_output);
855
+ else if (buf.grad)
856
+ adj_atomic_add(&index_grad(buf, i), adj_output);
857
+ }
858
+ template<typename T>
859
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, const array_t<T>& adj_buf, int adj_i, int adj_j, const T& adj_output)
860
+ {
861
+ if (adj_buf.data)
862
+ adj_atomic_add(&index(adj_buf, i, j), adj_output);
863
+ else if (buf.grad)
864
+ adj_atomic_add(&index_grad(buf, i, j), adj_output);
865
+ }
866
+ template<typename T>
867
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, const T& adj_output)
868
+ {
869
+ if (adj_buf.data)
870
+ adj_atomic_add(&index(adj_buf, i, j, k), adj_output);
871
+ else if (buf.grad)
872
+ adj_atomic_add(&index_grad(buf, i, j, k), adj_output);
873
+ }
874
+ template<typename T>
875
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, int l, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, const T& adj_output)
876
+ {
877
+ if (adj_buf.data)
878
+ adj_atomic_add(&index(adj_buf, i, j, k, l), adj_output);
879
+ else if (buf.grad)
880
+ adj_atomic_add(&index_grad(buf, i, j, k, l), adj_output);
881
+ }
882
+
883
+ template<typename T>
884
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int adj_i, T& adj_value)
885
+ {
886
+ if (adj_buf.data)
887
+ adj_value += index(adj_buf, i);
888
+ else if (buf.grad)
889
+ adj_value += index_grad(buf, i);
890
+
891
+ FP_VERIFY_ADJ_1(value, adj_value)
892
+ }
893
+ template<typename T>
894
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, T& adj_value)
895
+ {
896
+ if (adj_buf.data)
897
+ adj_value += index(adj_buf, i, j);
898
+ else if (buf.grad)
899
+ adj_value += index_grad(buf, i, j);
900
+
901
+ FP_VERIFY_ADJ_2(value, adj_value)
902
+ }
903
+ template<typename T>
904
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value)
905
+ {
906
+ if (adj_buf.data)
907
+ adj_value += index(adj_buf, i, j, k);
908
+ else if (buf.grad)
909
+ adj_value += index_grad(buf, i, j, k);
910
+
911
+ FP_VERIFY_ADJ_3(value, adj_value)
912
+ }
913
+ template<typename T>
914
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value)
915
+ {
916
+ if (adj_buf.data)
917
+ adj_value += index(adj_buf, i, j, k, l);
918
+ else if (buf.grad)
919
+ adj_value += index_grad(buf, i, j, k, l);
920
+
921
+ FP_VERIFY_ADJ_4(value, adj_value)
922
+ }
923
+
924
+ template<typename T>
925
+ inline CUDA_CALLABLE void adj_store(const T* address, T value, const T& adj_address, T& adj_value)
926
+ {
927
+ // nop; generic store() operations are not differentiable, only array_store() is
928
+ FP_VERIFY_ADJ(value, adj_value)
929
+ }
930
+
931
+ template<typename T>
932
+ inline CUDA_CALLABLE void adj_load(const T* address, const T& adj_address, T& adj_value)
933
+ {
934
+ // nop; generic load() operations are not differentiable
935
+ }
936
+
937
+ template<typename T>
938
+ inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret)
939
+ {
940
+ if (adj_buf.data)
941
+ adj_value += index(adj_buf, i);
942
+ else if (buf.grad)
943
+ adj_value += index_grad(buf, i);
944
+
945
+ FP_VERIFY_ADJ_1(value, adj_value)
946
+ }
947
+ template<typename T>
948
+ inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret)
949
+ {
950
+ if (adj_buf.data)
951
+ adj_value += index(adj_buf, i, j);
952
+ else if (buf.grad)
953
+ adj_value += index_grad(buf, i, j);
954
+
955
+ FP_VERIFY_ADJ_2(value, adj_value)
956
+ }
957
+ template<typename T>
958
+ inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret)
959
+ {
960
+ if (adj_buf.data)
961
+ adj_value += index(adj_buf, i, j, k);
962
+ else if (buf.grad)
963
+ adj_value += index_grad(buf, i, j, k);
964
+
965
+ FP_VERIFY_ADJ_3(value, adj_value)
966
+ }
967
+ template<typename T>
968
+ inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret)
969
+ {
970
+ if (adj_buf.data)
971
+ adj_value += index(adj_buf, i, j, k, l);
972
+ else if (buf.grad)
973
+ adj_value += index_grad(buf, i, j, k, l);
974
+
975
+ FP_VERIFY_ADJ_4(value, adj_value)
976
+ }
977
+
978
+ template<typename T>
979
+ inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret)
980
+ {
981
+ if (adj_buf.data)
982
+ adj_value -= index(adj_buf, i);
983
+ else if (buf.grad)
984
+ adj_value -= index_grad(buf, i);
985
+
986
+ FP_VERIFY_ADJ_1(value, adj_value)
987
+ }
988
+ template<typename T>
989
+ inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret)
990
+ {
991
+ if (adj_buf.data)
992
+ adj_value -= index(adj_buf, i, j);
993
+ else if (buf.grad)
994
+ adj_value -= index_grad(buf, i, j);
995
+
996
+ FP_VERIFY_ADJ_2(value, adj_value)
997
+ }
998
+ template<typename T>
999
+ inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret)
1000
+ {
1001
+ if (adj_buf.data)
1002
+ adj_value -= index(adj_buf, i, j, k);
1003
+ else if (buf.grad)
1004
+ adj_value -= index_grad(buf, i, j, k);
1005
+
1006
+ FP_VERIFY_ADJ_3(value, adj_value)
1007
+ }
1008
+ template<typename T>
1009
+ inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret)
1010
+ {
1011
+ if (adj_buf.data)
1012
+ adj_value -= index(adj_buf, i, j, k, l);
1013
+ else if (buf.grad)
1014
+ adj_value -= index_grad(buf, i, j, k, l);
1015
+
1016
+ FP_VERIFY_ADJ_4(value, adj_value)
1017
+ }
1018
+
+ // generic array types that do not support gradient computation (indexedarray, etc.)
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, const A2<T>& adj_buf, int adj_i, const T& adj_output) {}
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, const A2<T>& adj_buf, int adj_i, int adj_j, const T& adj_output) {}
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, const T& adj_output) {}
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, int l, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, const T& adj_output) {}
+
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value) {}
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value) {}
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value) {}
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value) {}
+
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
+
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
+
+ // generic handler for scalar values
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {
+     if (adj_buf.data)
+         adj_atomic_minmax(&index(buf, i), &index(adj_buf, i), value, adj_value);
+     else if (buf.grad)
+         adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
+
+     FP_VERIFY_ADJ_1(value, adj_value)
+ }
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {
+     if (adj_buf.data)
+         adj_atomic_minmax(&index(buf, i, j), &index(adj_buf, i, j), value, adj_value);
+     else if (buf.grad)
+         adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
+
+     FP_VERIFY_ADJ_2(value, adj_value)
+ }
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {
+     if (adj_buf.data)
+         adj_atomic_minmax(&index(buf, i, j, k), &index(adj_buf, i, j, k), value, adj_value);
+     else if (buf.grad)
+         adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
+
+     FP_VERIFY_ADJ_3(value, adj_value)
+ }
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {
+     if (adj_buf.data)
+         adj_atomic_minmax(&index(buf, i, j, k, l), &index(adj_buf, i, j, k, l), value, adj_value);
+     else if (buf.grad)
+         adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
+
+     FP_VERIFY_ADJ_4(value, adj_value)
+ }
+
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {
+     if (adj_buf.data)
+         adj_atomic_minmax(&index(buf, i), &index(adj_buf, i), value, adj_value);
+     else if (buf.grad)
+         adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
+
+     FP_VERIFY_ADJ_1(value, adj_value)
+ }
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {
+     if (adj_buf.data)
+         adj_atomic_minmax(&index(buf, i, j), &index(adj_buf, i, j), value, adj_value);
+     else if (buf.grad)
+         adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
+
+     FP_VERIFY_ADJ_2(value, adj_value)
+ }
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {
+     if (adj_buf.data)
+         adj_atomic_minmax(&index(buf, i, j, k), &index(adj_buf, i, j, k), value, adj_value);
+     else if (buf.grad)
+         adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
+
+     FP_VERIFY_ADJ_3(value, adj_value)
+ }
+ template<template<typename> class A1, template<typename> class A2, typename T>
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {
+     if (adj_buf.data)
+         adj_atomic_minmax(&index(buf, i, j, k, l), &index(adj_buf, i, j, k, l), value, adj_value);
+     else if (buf.grad)
+         adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
+
+     FP_VERIFY_ADJ_4(value, adj_value)
+ }
+
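The `adj_atomic_min`/`adj_atomic_max` handlers above route the adjoint through `adj_atomic_minmax`, which compares the contributed `value` against the stored extremum so that gradient only flows back to contributions that produced the final result. A hedged sketch of the expected behaviour from Python (names are illustrative; how ties between equal contributions are resolved is an implementation detail of `adj_atomic_minmax`):

import numpy as np
import warp as wp

wp.init()

@wp.kernel
def running_min(x: wp.array(dtype=float), lo: wp.array(dtype=float)):
    i = wp.tid()
    # only contributions equal to the final minimum should receive gradient
    wp.atomic_min(lo, 0, x[i])

x = wp.array(np.array([3.0, 1.0, 2.0], dtype=np.float32), dtype=float, requires_grad=True)
lo = wp.array(np.array([1.0e9], dtype=np.float32), dtype=float, requires_grad=True)

tape = wp.Tape()
with tape:
    wp.launch(running_min, dim=3, inputs=[x, lo])

tape.backward(loss=lo)
print(lo.numpy())      # expected: [1.]
print(x.grad.numpy())  # expected: gradient only on x[1], the element holding the minimum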
+ template<template<typename> class A, typename T>
+ CUDA_CALLABLE inline int len(const A<T>& a)
+ {
+     return a.shape[0];
+ }
+
+ template<template<typename> class A, typename T>
+ CUDA_CALLABLE inline void adj_len(const A<T>& a, A<T>& adj_a, int& adj_ret)
+ {
+ }
+
+
+ } // namespace wp
+
+ #include "fabric.h"
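The `len()` overload at the end of the header returns `shape[0]`, and its adjoint is intentionally empty since an array's length is not differentiable. Assuming `len()` is exposed as a kernel builtin for arrays in this release (the kernel and names below are illustrative assumptions, not taken from the diff), usage would look like:

import warp as wp

wp.init()

@wp.kernel
def first_dim(x: wp.array(dtype=float), n: wp.array(dtype=int)):
    # maps to the C++ len() overload above: the size of the first dimension
    n[0] = len(x)

x = wp.zeros(16, dtype=float)
n = wp.zeros(1, dtype=int)
wp.launch(first_dim, dim=1, inputs=[x, n])
print(n.numpy())  # expected: [16]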