warp-lang 1.9.1-py3-none-win_amd64.whl → 1.10.0rc2-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
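
Most of the churn in the list above comes from the 1.10 source reorganization: module implementations move under warp/_src/, while the original public module paths shrink to a few lines each (for example, warp/fem/types.py shows +6/-81 while warp/_src/fem/types.py is added with 112 lines), and the warp.sim package together with its examples is removed entirely. The stub contents are not shown in this diff, but the line counts are consistent with thin re-export shims along the lines of the hypothetical sketch below (illustration only, not the actual file contents):

    # Hypothetical sketch of a public shim module such as warp/fem/types.py.
    # The real stub is not included in this diff; this only illustrates the
    # re-export pattern implied by the +6/-81 style line counts above.
    from warp._src.fem.types import *  # noqa: F401,F403

If that is indeed the mechanism, existing import warp.fem statements should keep working unchanged, while anything reaching into the new warp._src namespace should be treated as private.
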
warp/native/array.h CHANGED
@@ -118,6 +118,23 @@ namespace wp
 
 #endif // WP_FP_CHECK
 
+
+template<size_t... Is>
+struct index_sequence {};
+
+template<size_t N, size_t... Is>
+struct make_index_sequence_impl : make_index_sequence_impl<N-1, N-1, Is...> {};
+
+template<size_t... Is>
+struct make_index_sequence_impl<0, Is...>
+{
+    using type = index_sequence<Is...>;
+};
+
+template<size_t N>
+using make_index_sequence = typename make_index_sequence_impl<N>::type;
+
+
 const int ARRAY_MAX_DIMS = 4; // must match constant in types.py
 
 // must match constants in types.py
@@ -423,6 +440,13 @@ template <typename T>
 CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i)
 {
     assert(arr.ndim == 1);
+    assert(i >= -arr.shape[0] && i < arr.shape[0]);
+
+    if (i < 0)
+    {
+        i += arr.shape[0];
+    }
+
     T& result = *data_at_byte_offset(arr, byte_offset(arr, i));
     FP_VERIFY_FWD_1(result)
 
@@ -433,6 +457,18 @@ template <typename T>
 CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j)
 {
     assert(arr.ndim == 2);
+    assert(i >= -arr.shape[0] && i < arr.shape[0]);
+    assert(j >= -arr.shape[1] && j < arr.shape[1]);
+
+    if (i < 0)
+    {
+        i += arr.shape[0];
+    }
+    if (j < 0)
+    {
+        j += arr.shape[1];
+    }
+
     T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j));
     FP_VERIFY_FWD_2(result)
 
@@ -443,6 +479,23 @@ template <typename T>
 CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k)
 {
     assert(arr.ndim == 3);
+    assert(i >= -arr.shape[0] && i < arr.shape[0]);
+    assert(j >= -arr.shape[1] && j < arr.shape[1]);
+    assert(k >= -arr.shape[2] && k < arr.shape[2]);
+
+    if (i < 0)
+    {
+        i += arr.shape[0];
+    }
+    if (j < 0)
+    {
+        j += arr.shape[1];
+    }
+    if (k < 0)
+    {
+        k += arr.shape[2];
+    }
+
     T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k));
     FP_VERIFY_FWD_3(result)
 
@@ -453,6 +506,28 @@ template <typename T>
 CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k, int l)
 {
     assert(arr.ndim == 4);
+    assert(i >= -arr.shape[0] && i < arr.shape[0]);
+    assert(j >= -arr.shape[1] && j < arr.shape[1]);
+    assert(k >= -arr.shape[2] && k < arr.shape[2]);
+    assert(l >= -arr.shape[3] && l < arr.shape[3]);
+
+    if (i < 0)
+    {
+        i += arr.shape[0];
+    }
+    if (j < 0)
+    {
+        j += arr.shape[1];
+    }
+    if (k < 0)
+    {
+        k += arr.shape[2];
+    }
+    if (l < 0)
+    {
+        l += arr.shape[3];
+    }
+
     T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
     FP_VERIFY_FWD_4(result)
 
@@ -462,6 +537,14 @@ CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k, int l)
 template <typename T>
 CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i)
 {
+    assert(arr.ndim == 1);
+    assert(i >= -arr.shape[0] && i < arr.shape[0]);
+
+    if (i < 0)
+    {
+        i += arr.shape[0];
+    }
+
     T& result = *grad_at_byte_offset(arr, byte_offset(arr, i));
     FP_VERIFY_FWD_1(result)
 
@@ -471,6 +554,19 @@ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i)
 template <typename T>
 CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j)
 {
+    assert(arr.ndim == 2);
+    assert(i >= -arr.shape[0] && i < arr.shape[0]);
+    assert(j >= -arr.shape[1] && j < arr.shape[1]);
+
+    if (i < 0)
+    {
+        i += arr.shape[0];
+    }
+    if (j < 0)
+    {
+        j += arr.shape[1];
+    }
+
     T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j));
     FP_VERIFY_FWD_2(result)
 
@@ -480,6 +576,24 @@ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j)
 template <typename T>
 CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k)
 {
+    assert(arr.ndim == 3);
+    assert(i >= -arr.shape[0] && i < arr.shape[0]);
+    assert(j >= -arr.shape[1] && j < arr.shape[1]);
+    assert(k >= -arr.shape[2] && k < arr.shape[2]);
+
+    if (i < 0)
+    {
+        i += arr.shape[0];
+    }
+    if (j < 0)
+    {
+        j += arr.shape[1];
+    }
+    if (k < 0)
+    {
+        k += arr.shape[2];
+    }
+
     T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k));
     FP_VERIFY_FWD_3(result)
 
@@ -489,6 +603,29 @@ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k)
 template <typename T>
 CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k, int l)
 {
+    assert(arr.ndim == 4);
+    assert(i >= -arr.shape[0] && i < arr.shape[0]);
+    assert(j >= -arr.shape[1] && j < arr.shape[1]);
+    assert(k >= -arr.shape[2] && k < arr.shape[2]);
+    assert(l >= -arr.shape[3] && l < arr.shape[3]);
+
+    if (i < 0)
+    {
+        i += arr.shape[0];
+    }
+    if (j < 0)
+    {
+        j += arr.shape[1];
+    }
+    if (k < 0)
+    {
+        k += arr.shape[2];
+    }
+    if (l < 0)
+    {
+        l += arr.shape[3];
+    }
+
     T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
     FP_VERIFY_FWD_4(result)
 
@@ -500,7 +637,12 @@ template <typename T>
 CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i)
 {
     assert(iarr.arr.ndim == 1);
-    assert(i >= 0 && i < iarr.shape[0]);
+    assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
+
+    if (i < 0)
+    {
+        i += iarr.shape[0];
+    }
 
     if (iarr.indices[0])
     {
@@ -518,8 +660,17 @@ template <typename T>
 CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j)
 {
     assert(iarr.arr.ndim == 2);
-    assert(i >= 0 && i < iarr.shape[0]);
-    assert(j >= 0 && j < iarr.shape[1]);
+    assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
+    assert(j >= -iarr.shape[1] && j < iarr.shape[1]);
+
+    if (i < 0)
+    {
+        i += iarr.shape[0];
+    }
+    if (j < 0)
+    {
+        j += iarr.shape[1];
+    }
 
     if (iarr.indices[0])
     {
@@ -542,9 +693,22 @@ template <typename T>
 CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k)
 {
     assert(iarr.arr.ndim == 3);
-    assert(i >= 0 && i < iarr.shape[0]);
-    assert(j >= 0 && j < iarr.shape[1]);
-    assert(k >= 0 && k < iarr.shape[2]);
+    assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
+    assert(j >= -iarr.shape[1] && j < iarr.shape[1]);
+    assert(k >= -iarr.shape[2] && k < iarr.shape[2]);
+
+    if (i < 0)
+    {
+        i += iarr.shape[0];
+    }
+    if (j < 0)
+    {
+        j += iarr.shape[1];
+    }
+    if (k < 0)
+    {
+        k += iarr.shape[2];
+    }
 
     if (iarr.indices[0])
     {
@@ -572,10 +736,27 @@ template <typename T>
 CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k, int l)
 {
     assert(iarr.arr.ndim == 4);
-    assert(i >= 0 && i < iarr.shape[0]);
-    assert(j >= 0 && j < iarr.shape[1]);
-    assert(k >= 0 && k < iarr.shape[2]);
-    assert(l >= 0 && l < iarr.shape[3]);
+    assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
+    assert(j >= -iarr.shape[1] && j < iarr.shape[1]);
+    assert(k >= -iarr.shape[2] && k < iarr.shape[2]);
+    assert(l >= -iarr.shape[3] && l < iarr.shape[3]);
+
+    if (i < 0)
+    {
+        i += iarr.shape[0];
+    }
+    if (j < 0)
+    {
+        j += iarr.shape[1];
+    }
+    if (k < 0)
+    {
+        k += iarr.shape[2];
+    }
+    if (l < 0)
+    {
+        l += iarr.shape[3];
+    }
 
     if (iarr.indices[0])
     {
@@ -609,7 +790,12 @@ template <typename T>
 CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i)
 {
     assert(src.ndim > 1);
-    assert(i >= 0 && i < src.shape[0]);
+    assert(i >= -src.shape[0] && i < src.shape[0]);
+
+    if (i < 0)
+    {
+        i += src.shape[0];
+    }
 
     array_t<T> a;
     size_t offset = byte_offset(src, i);
@@ -631,8 +817,17 @@ template <typename T>
 CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j)
 {
     assert(src.ndim > 2);
-    assert(i >= 0 && i < src.shape[0]);
-    assert(j >= 0 && j < src.shape[1]);
+    assert(i >= -src.shape[0] && i < src.shape[0]);
+    assert(j >= -src.shape[1] && j < src.shape[1]);
+
+    if (i < 0)
+    {
+        i += src.shape[0];
+    }
+    if (j < 0)
+    {
+        j += src.shape[1];
+    }
 
     array_t<T> a;
     size_t offset = byte_offset(src, i, j);
@@ -652,9 +847,22 @@ template <typename T>
 CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j, int k)
 {
     assert(src.ndim > 3);
-    assert(i >= 0 && i < src.shape[0]);
-    assert(j >= 0 && j < src.shape[1]);
-    assert(k >= 0 && k < src.shape[2]);
+    assert(i >= -src.shape[0] && i < src.shape[0]);
+    assert(j >= -src.shape[1] && j < src.shape[1]);
+    assert(k >= -src.shape[2] && k < src.shape[2]);
+
+    if (i < 0)
+    {
+        i += src.shape[0];
+    }
+    if (j < 0)
+    {
+        j += src.shape[1];
+    }
+    if (k < 0)
+    {
+        k += src.shape[2];
+    }
 
     array_t<T> a;
     size_t offset = byte_offset(src, i, j, k);
@@ -669,6 +877,78 @@ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j, int k)
 }
 
 
+template <typename T, size_t... Idxs>
+size_t byte_offset_helper(
+    array_t<T>& src,
+    const slice_t (&slices)[sizeof...(Idxs)],
+    index_sequence<Idxs...>
+)
+{
+    return byte_offset(src, slices[Idxs].start...);
+}
+
+
+template <typename T, typename... Slices>
+CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, const Slices&... slice_args)
+{
+    constexpr int N = sizeof...(Slices);
+    static_assert(N >= 1 && N <= 4, "view supports 1 to 4 slices");
+    assert(src.ndim >= N);
+
+    slice_t slices[N] = { slice_args... };
+    int slice_idxs[N];
+    int slice_count = 0;
+
+    for (int i = 0; i < N; ++i)
+    {
+        if (slices[i].step == 0)
+        {
+            // We have a slice representing an integer index.
+            if (slices[i].start < 0)
+            {
+                slices[i].start += src.shape[i];
+            }
+        }
+        else
+        {
+            slices[i] = slice_adjust_indices(slices[i], src.shape[i]);
+            slice_idxs[slice_count] = i;
+            ++slice_count;
+        }
+    }
+
+    size_t offset = byte_offset_helper(src, slices, make_index_sequence<N>{});
+
+    array_t<T> out;
+
+    out.data = data_at_byte_offset(src, offset);
+    if (src.grad)
+    {
+        out.grad = grad_at_byte_offset(src, offset);
+    }
+
+    int dim = 0;
+    for (; dim < slice_count; ++dim)
+    {
+        int idx = slice_idxs[dim];
+        out.shape[dim] = slice_get_length(slices[idx]);
+        out.strides[dim] = src.strides[idx] * slices[idx].step;
+    }
+    for (; dim < slice_count + 4 - N; ++dim)
+    {
+        out.shape[dim] = src.shape[dim - slice_count + N];
+        out.strides[dim] = src.strides[dim - slice_count + N];
+    }
+    for (; dim < 4; ++dim)
+    {
+        out.shape[dim] = 0;
+        out.strides[dim] = 0;
+    }
+
+    out.ndim = src.ndim + slice_count - N;
+    return out;
+}
+
 template <typename T>
 CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i)
 {
@@ -676,7 +956,11 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i)
 
     if (src.indices[0])
     {
-        assert(i >= 0 && i < src.shape[0]);
+        assert(i >= -src.shape[0] && i < src.shape[0]);
+        if (i < 0)
+        {
+            i += src.shape[0];
+        }
         i = src.indices[0][i];
     }
 
@@ -699,12 +983,20 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j
 
     if (src.indices[0])
    {
-        assert(i >= 0 && i < src.shape[0]);
+        assert(i >= -src.shape[0] && i < src.shape[0]);
+        if (i < 0)
+        {
+            i += src.shape[0];
+        }
         i = src.indices[0][i];
     }
     if (src.indices[1])
     {
-        assert(j >= 0 && j < src.shape[1]);
+        assert(j >= -src.shape[1] && j < src.shape[1]);
+        if (j < 0)
+        {
+            j += src.shape[1];
+        }
         j = src.indices[1][j];
     }
 
@@ -725,17 +1017,29 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j
 
     if (src.indices[0])
     {
-        assert(i >= 0 && i < src.shape[0]);
+        assert(i >= -src.shape[0] && i < src.shape[0]);
+        if (i < 0)
+        {
+            i += src.shape[0];
+        }
         i = src.indices[0][i];
     }
     if (src.indices[1])
     {
-        assert(j >= 0 && j < src.shape[1]);
+        assert(j >= -src.shape[1] && j < src.shape[1]);
+        if (j < 0)
+        {
+            j += src.shape[1];
+        }
         j = src.indices[1][j];
     }
     if (src.indices[2])
     {
-        assert(k >= 0 && k < src.shape[2]);
+        assert(k >= -src.shape[2] && k < src.shape[2]);
+        if (k < 0)
+        {
+            k += src.shape[2];
+        }
         k = src.indices[2][k];
     }
 
@@ -754,6 +1058,9 @@ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int
 template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
 inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T>& adj_ret) {}
 
+template <typename... Args>
+CUDA_CALLABLE inline void adj_view(Args&&...) { }
+
 // TODO: lower_bound() for indexed arrays?
 
 template <typename T>
@@ -844,6 +1151,33 @@ inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, T value
 template<template<typename> class A, typename T>
 inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_exch(&index(buf, i, j, k, l), value); }
 
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, T value) { return atomic_and(&index(buf, i), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, int j, T value) { return atomic_and(&index(buf, i, j), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, int j, int k, T value) { return atomic_and(&index(buf, i, j, k), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_and(&index(buf, i, j, k, l), value); }
+
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, T value) { return atomic_or(&index(buf, i), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, int j, T value) { return atomic_or(&index(buf, i, j), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, int j, int k, T value) { return atomic_or(&index(buf, i, j, k), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_or(&index(buf, i, j, k, l), value); }
+
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, T value) { return atomic_xor(&index(buf, i), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, int j, T value) { return atomic_xor(&index(buf, i, j), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, int j, int k, T value) { return atomic_xor(&index(buf, i, j, k), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_xor(&index(buf, i, j, k, l), value); }
+
 template<template<typename> class A, typename T>
 inline CUDA_CALLABLE T* address(const A<T>& buf, int i)
 {
@@ -911,20 +1245,7 @@ inline CUDA_CALLABLE T load(T* address)
     return value;
 }
 
-// select operator to check for array being null
-template <typename T1, typename T2>
-CUDA_CALLABLE inline T2 select(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?b:a; }
-
-template <typename T1, typename T2>
-CUDA_CALLABLE inline void adj_select(const array_t<T1>& arr, const T2& a, const T2& b, const array_t<T1>& adj_cond, T2& adj_a, T2& adj_b, const T2& adj_ret)
-{
-    if (arr.data)
-        adj_b += adj_ret;
-    else
-        adj_a += adj_ret;
-}
-
-// where operator to check for array being null, opposite convention compared to select
+// where() overload for array condition - returns a if array.data is non-null, otherwise returns b
 template <typename T1, typename T2>
 CUDA_CALLABLE inline T2 where(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?a:b; }
 
@@ -1321,6 +1642,34 @@ inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, int k,
     FP_VERIFY_ADJ_4(value, adj_value)
 }
 
+// for bitwise operations we do not accumulate gradients
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
+
 
 template<template<typename> class A, typename T>
 CUDA_CALLABLE inline int len(const A<T>& a)
@@ -1333,7 +1682,6 @@ CUDA_CALLABLE inline void adj_len(const A<T>& a, A<T>& adj_a, int& adj_ret)
 {
 }
 
-
 } // namespace wp
 
 #include "fabric.h"
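
The array.h changes above widen the bounds asserts in index(), index_grad(), and view() to accept negative indices and wrap them by the corresponding shape, and add a variadic, slice-aware view() overload built on the new index_sequence helpers (relying on the slice_t type and the slice_adjust_indices()/slice_get_length() helpers defined elsewhere in the native sources). At the Python level this corresponds to Python-style negative indexing (and slicing) of Warp arrays inside kernels. The kernel below is a hedged sketch of the negative-indexing usage; the exact set of supported slice forms is not visible from this file alone:

    import warp as wp

    @wp.kernel
    def tail_diff(a: wp.array(dtype=float), out: wp.array(dtype=float)):
        # negative indices wrap around, mirroring Python semantics:
        # a[-1] is the last element, a[-2] the second to last
        out[0] = a[-1] - a[-2]

    a = wp.array([1.0, 2.0, 4.0, 7.0], dtype=float)
    out = wp.zeros(1, dtype=float)
    wp.launch(tail_diff, dim=1, inputs=[a, out])
    print(out.numpy())  # expected: [3.]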
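
The new atomic_and/atomic_or/atomic_xor array overloads, together with their empty adjoints (the comment in the diff notes that bitwise operations do not accumulate gradients), back kernel-level bitwise atomics that the added test files (warp/tests/test_atomic_bitwise.py and warp/tests/tile/test_tile_atomic_bitwise.py) exercise from Python. A hedged sketch of the expected usage, assuming the builtins mirror the existing wp.atomic_add(arr, index, value) signature:

    import warp as wp

    @wp.kernel
    def set_flags(flags: wp.array(dtype=wp.int32), bits: wp.array(dtype=wp.int32)):
        tid = wp.tid()
        # assumed builtin, by analogy with wp.atomic_add:
        # OR each thread's bit pattern into a single flag word
        wp.atomic_or(flags, 0, bits[tid])

    bits = wp.array([1, 2, 4, 8], dtype=wp.int32)
    flags = wp.zeros(1, dtype=wp.int32)
    wp.launch(set_flags, dim=bits.shape[0], inputs=[flags, bits])
    print(flags.numpy())  # expected: [15]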
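
Finally, the hunk starting at old line 911 removes the select() overload that took an array condition (and its adjoint), keeping only the where() overload, which returns the first value when the array is non-null. This is consistent with the Python-level move from wp.select() to wp.where(), whose arguments are ordered condition, value_if_true, value_if_false. A hedged sketch of the array-condition form, assuming optional array arguments may be passed as None at launch time:

    import warp as wp

    @wp.kernel
    def fill(out: wp.array(dtype=float), opt: wp.array(dtype=float)):
        tid = wp.tid()
        # where(cond, value_if_true, value_if_false); with an array condition
        # the "true" branch is taken when the array is allocated (non-null)
        out[tid] = wp.where(opt, 2.0, 1.0)

    out = wp.zeros(4, dtype=float)
    wp.launch(fill, dim=4, inputs=[out, None])  # opt is null: fills with 1.0
    wp.launch(fill, dim=4, inputs=[out, out])   # opt non-null: fills with 2.0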