warp_lang-1.9.0-py3-none-win_amd64.whl → warp_lang-1.10.0rc2-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang has been flagged as potentially problematic; see the advisory details in the registry listing for this release.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/array.h CHANGED
@@ -118,6 +118,23 @@ namespace wp
118
118
 
119
119
  #endif // WP_FP_CHECK
120
120
 
121
+
122
+ template<size_t... Is>
123
+ struct index_sequence {};
124
+
125
+ template<size_t N, size_t... Is>
126
+ struct make_index_sequence_impl : make_index_sequence_impl<N-1, N-1, Is...> {};
127
+
128
+ template<size_t... Is>
129
+ struct make_index_sequence_impl<0, Is...>
130
+ {
131
+ using type = index_sequence<Is...>;
132
+ };
133
+
134
+ template<size_t N>
135
+ using make_index_sequence = typename make_index_sequence_impl<N>::type;
136
+
137
+
121
138
  const int ARRAY_MAX_DIMS = 4; // must match constant in types.py
122
139
 
123
140
  // must match constants in types.py
@@ -423,6 +440,13 @@ template <typename T>
423
440
  CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i)
424
441
  {
425
442
  assert(arr.ndim == 1);
443
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
444
+
445
+ if (i < 0)
446
+ {
447
+ i += arr.shape[0];
448
+ }
449
+
426
450
  T& result = *data_at_byte_offset(arr, byte_offset(arr, i));
427
451
  FP_VERIFY_FWD_1(result)
428
452
 
@@ -433,6 +457,18 @@ template <typename T>
433
457
  CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j)
434
458
  {
435
459
  assert(arr.ndim == 2);
460
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
461
+ assert(j >= -arr.shape[1] && j < arr.shape[1]);
462
+
463
+ if (i < 0)
464
+ {
465
+ i += arr.shape[0];
466
+ }
467
+ if (j < 0)
468
+ {
469
+ j += arr.shape[1];
470
+ }
471
+
436
472
  T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j));
437
473
  FP_VERIFY_FWD_2(result)
438
474
 
@@ -443,6 +479,23 @@ template <typename T>
443
479
  CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k)
444
480
  {
445
481
  assert(arr.ndim == 3);
482
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
483
+ assert(j >= -arr.shape[1] && j < arr.shape[1]);
484
+ assert(k >= -arr.shape[2] && k < arr.shape[2]);
485
+
486
+ if (i < 0)
487
+ {
488
+ i += arr.shape[0];
489
+ }
490
+ if (j < 0)
491
+ {
492
+ j += arr.shape[1];
493
+ }
494
+ if (k < 0)
495
+ {
496
+ k += arr.shape[2];
497
+ }
498
+
446
499
  T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k));
447
500
  FP_VERIFY_FWD_3(result)
448
501
 
@@ -453,6 +506,28 @@ template <typename T>
453
506
  CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k, int l)
454
507
  {
455
508
  assert(arr.ndim == 4);
509
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
510
+ assert(j >= -arr.shape[1] && j < arr.shape[1]);
511
+ assert(k >= -arr.shape[2] && k < arr.shape[2]);
512
+ assert(l >= -arr.shape[3] && l < arr.shape[3]);
513
+
514
+ if (i < 0)
515
+ {
516
+ i += arr.shape[0];
517
+ }
518
+ if (j < 0)
519
+ {
520
+ j += arr.shape[1];
521
+ }
522
+ if (k < 0)
523
+ {
524
+ k += arr.shape[2];
525
+ }
526
+ if (l < 0)
527
+ {
528
+ l += arr.shape[3];
529
+ }
530
+
456
531
  T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
457
532
  FP_VERIFY_FWD_4(result)
458
533
 
@@ -462,6 +537,14 @@ CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k, int l)
462
537
  template <typename T>
463
538
  CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i)
464
539
  {
540
+ assert(arr.ndim == 1);
541
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
542
+
543
+ if (i < 0)
544
+ {
545
+ i += arr.shape[0];
546
+ }
547
+
465
548
  T& result = *grad_at_byte_offset(arr, byte_offset(arr, i));
466
549
  FP_VERIFY_FWD_1(result)
467
550
 
@@ -471,6 +554,19 @@ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i)
471
554
  template <typename T>
472
555
  CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j)
473
556
  {
557
+ assert(arr.ndim == 2);
558
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
559
+ assert(j >= -arr.shape[1] && j < arr.shape[1]);
560
+
561
+ if (i < 0)
562
+ {
563
+ i += arr.shape[0];
564
+ }
565
+ if (j < 0)
566
+ {
567
+ j += arr.shape[1];
568
+ }
569
+
474
570
  T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j));
475
571
  FP_VERIFY_FWD_2(result)
476
572
 
@@ -480,6 +576,24 @@ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j)
480
576
  template <typename T>
481
577
  CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k)
482
578
  {
579
+ assert(arr.ndim == 3);
580
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
581
+ assert(j >= -arr.shape[1] && j < arr.shape[1]);
582
+ assert(k >= -arr.shape[2] && k < arr.shape[2]);
583
+
584
+ if (i < 0)
585
+ {
586
+ i += arr.shape[0];
587
+ }
588
+ if (j < 0)
589
+ {
590
+ j += arr.shape[1];
591
+ }
592
+ if (k < 0)
593
+ {
594
+ k += arr.shape[2];
595
+ }
596
+
483
597
  T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k));
484
598
  FP_VERIFY_FWD_3(result)
485
599
 
@@ -489,6 +603,29 @@ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k)
489
603
  template <typename T>
490
604
  CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k, int l)
491
605
  {
606
+ assert(arr.ndim == 4);
607
+ assert(i >= -arr.shape[0] && i < arr.shape[0]);
608
+ assert(j >= -arr.shape[1] && j < arr.shape[1]);
609
+ assert(k >= -arr.shape[2] && k < arr.shape[2]);
610
+ assert(l >= -arr.shape[3] && l < arr.shape[3]);
611
+
612
+ if (i < 0)
613
+ {
614
+ i += arr.shape[0];
615
+ }
616
+ if (j < 0)
617
+ {
618
+ j += arr.shape[1];
619
+ }
620
+ if (k < 0)
621
+ {
622
+ k += arr.shape[2];
623
+ }
624
+ if (l < 0)
625
+ {
626
+ l += arr.shape[3];
627
+ }
628
+
492
629
  T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
493
630
  FP_VERIFY_FWD_4(result)
494
631
 
@@ -500,7 +637,12 @@ template <typename T>
500
637
  CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i)
501
638
  {
502
639
  assert(iarr.arr.ndim == 1);
503
- assert(i >= 0 && i < iarr.shape[0]);
640
+ assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
641
+
642
+ if (i < 0)
643
+ {
644
+ i += iarr.shape[0];
645
+ }
504
646
 
505
647
  if (iarr.indices[0])
506
648
  {
@@ -518,8 +660,17 @@ template <typename T>
518
660
  CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j)
519
661
  {
520
662
  assert(iarr.arr.ndim == 2);
521
- assert(i >= 0 && i < iarr.shape[0]);
522
- assert(j >= 0 && j < iarr.shape[1]);
663
+ assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
664
+ assert(j >= -iarr.shape[1] && j < iarr.shape[1]);
665
+
666
+ if (i < 0)
667
+ {
668
+ i += iarr.shape[0];
669
+ }
670
+ if (j < 0)
671
+ {
672
+ j += iarr.shape[1];
673
+ }
523
674
 
524
675
  if (iarr.indices[0])
525
676
  {
@@ -542,9 +693,22 @@ template <typename T>
542
693
  CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k)
543
694
  {
544
695
  assert(iarr.arr.ndim == 3);
545
- assert(i >= 0 && i < iarr.shape[0]);
546
- assert(j >= 0 && j < iarr.shape[1]);
547
- assert(k >= 0 && k < iarr.shape[2]);
696
+ assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
697
+ assert(j >= -iarr.shape[1] && j < iarr.shape[1]);
698
+ assert(k >= -iarr.shape[2] && k < iarr.shape[2]);
699
+
700
+ if (i < 0)
701
+ {
702
+ i += iarr.shape[0];
703
+ }
704
+ if (j < 0)
705
+ {
706
+ j += iarr.shape[1];
707
+ }
708
+ if (k < 0)
709
+ {
710
+ k += iarr.shape[2];
711
+ }
548
712
 
549
713
  if (iarr.indices[0])
550
714
  {
@@ -572,10 +736,27 @@ template <typename T>
572
736
  CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k, int l)
573
737
  {
574
738
  assert(iarr.arr.ndim == 4);
575
- assert(i >= 0 && i < iarr.shape[0]);
576
- assert(j >= 0 && j < iarr.shape[1]);
577
- assert(k >= 0 && k < iarr.shape[2]);
578
- assert(l >= 0 && l < iarr.shape[3]);
739
+ assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
740
+ assert(j >= -iarr.shape[1] && j < iarr.shape[1]);
741
+ assert(k >= -iarr.shape[2] && k < iarr.shape[2]);
742
+ assert(l >= -iarr.shape[3] && l < iarr.shape[3]);
743
+
744
+ if (i < 0)
745
+ {
746
+ i += iarr.shape[0];
747
+ }
748
+ if (j < 0)
749
+ {
750
+ j += iarr.shape[1];
751
+ }
752
+ if (k < 0)
753
+ {
754
+ k += iarr.shape[2];
755
+ }
756
+ if (l < 0)
757
+ {
758
+ l += iarr.shape[3];
759
+ }
579
760
 
580
761
  if (iarr.indices[0])
581
762
  {
@@ -609,7 +790,12 @@ template <typename T>
609
790
  CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i)
610
791
  {
611
792
  assert(src.ndim > 1);
612
- assert(i >= 0 && i < src.shape[0]);
793
+ assert(i >= -src.shape[0] && i < src.shape[0]);
794
+
795
+ if (i < 0)
796
+ {
797
+ i += src.shape[0];
798
+ }
613
799
 
614
800
  array_t<T> a;
615
801
  size_t offset = byte_offset(src, i);
@@ -631,8 +817,17 @@ template <typename T>
631
817
  CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j)
632
818
  {
633
819
  assert(src.ndim > 2);
634
- assert(i >= 0 && i < src.shape[0]);
635
- assert(j >= 0 && j < src.shape[1]);
820
+ assert(i >= -src.shape[0] && i < src.shape[0]);
821
+ assert(j >= -src.shape[1] && j < src.shape[1]);
822
+
823
+ if (i < 0)
824
+ {
825
+ i += src.shape[0];
826
+ }
827
+ if (j < 0)
828
+ {
829
+ j += src.shape[1];
830
+ }
636
831
 
637
832
  array_t<T> a;
638
833
  size_t offset = byte_offset(src, i, j);
@@ -652,9 +847,22 @@ template <typename T>
652
847
  CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j, int k)
653
848
  {
654
849
  assert(src.ndim > 3);
655
- assert(i >= 0 && i < src.shape[0]);
656
- assert(j >= 0 && j < src.shape[1]);
657
- assert(k >= 0 && k < src.shape[2]);
850
+ assert(i >= -src.shape[0] && i < src.shape[0]);
851
+ assert(j >= -src.shape[1] && j < src.shape[1]);
852
+ assert(k >= -src.shape[2] && k < src.shape[2]);
853
+
854
+ if (i < 0)
855
+ {
856
+ i += src.shape[0];
857
+ }
858
+ if (j < 0)
859
+ {
860
+ j += src.shape[1];
861
+ }
862
+ if (k < 0)
863
+ {
864
+ k += src.shape[2];
865
+ }
658
866
 
659
867
  array_t<T> a;
660
868
  size_t offset = byte_offset(src, i, j, k);
@@ -669,6 +877,78 @@ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j, int k)
669
877
  }
670
878
 
671
879
 
880
+ template <typename T, size_t... Idxs>
881
+ size_t byte_offset_helper(
882
+ array_t<T>& src,
883
+ const slice_t (&slices)[sizeof...(Idxs)],
884
+ index_sequence<Idxs...>
885
+ )
886
+ {
887
+ return byte_offset(src, slices[Idxs].start...);
888
+ }
889
+
890
+
891
+ template <typename T, typename... Slices>
892
+ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, const Slices&... slice_args)
893
+ {
894
+ constexpr int N = sizeof...(Slices);
895
+ static_assert(N >= 1 && N <= 4, "view supports 1 to 4 slices");
896
+ assert(src.ndim >= N);
897
+
898
+ slice_t slices[N] = { slice_args... };
899
+ int slice_idxs[N];
900
+ int slice_count = 0;
901
+
902
+ for (int i = 0; i < N; ++i)
903
+ {
904
+ if (slices[i].step == 0)
905
+ {
906
+ // We have a slice representing an integer index.
907
+ if (slices[i].start < 0)
908
+ {
909
+ slices[i].start += src.shape[i];
910
+ }
911
+ }
912
+ else
913
+ {
914
+ slices[i] = slice_adjust_indices(slices[i], src.shape[i]);
915
+ slice_idxs[slice_count] = i;
916
+ ++slice_count;
917
+ }
918
+ }
919
+
920
+ size_t offset = byte_offset_helper(src, slices, make_index_sequence<N>{});
921
+
922
+ array_t<T> out;
923
+
924
+ out.data = data_at_byte_offset(src, offset);
925
+ if (src.grad)
926
+ {
927
+ out.grad = grad_at_byte_offset(src, offset);
928
+ }
929
+
930
+ int dim = 0;
931
+ for (; dim < slice_count; ++dim)
932
+ {
933
+ int idx = slice_idxs[dim];
934
+ out.shape[dim] = slice_get_length(slices[idx]);
935
+ out.strides[dim] = src.strides[idx] * slices[idx].step;
936
+ }
937
+ for (; dim < slice_count + 4 - N; ++dim)
938
+ {
939
+ out.shape[dim] = src.shape[dim - slice_count + N];
940
+ out.strides[dim] = src.strides[dim - slice_count + N];
941
+ }
942
+ for (; dim < 4; ++dim)
943
+ {
944
+ out.shape[dim] = 0;
945
+ out.strides[dim] = 0;
946
+ }
947
+
948
+ out.ndim = src.ndim + slice_count - N;
949
+ return out;
950
+ }
951
+
672
952
  template <typename T>
673
953
  CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i)
674
954
  {
@@ -676,7 +956,11 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i)
676
956
 
677
957
  if (src.indices[0])
678
958
  {
679
- assert(i >= 0 && i < src.shape[0]);
959
+ assert(i >= -src.shape[0] && i < src.shape[0]);
960
+ if (i < 0)
961
+ {
962
+ i += src.shape[0];
963
+ }
680
964
  i = src.indices[0][i];
681
965
  }
682
966
 
@@ -699,12 +983,20 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j
699
983
 
700
984
  if (src.indices[0])
701
985
  {
702
- assert(i >= 0 && i < src.shape[0]);
986
+ assert(i >= -src.shape[0] && i < src.shape[0]);
987
+ if (i < 0)
988
+ {
989
+ i += src.shape[0];
990
+ }
703
991
  i = src.indices[0][i];
704
992
  }
705
993
  if (src.indices[1])
706
994
  {
707
- assert(j >= 0 && j < src.shape[1]);
995
+ assert(j >= -src.shape[1] && j < src.shape[1]);
996
+ if (j < 0)
997
+ {
998
+ j += src.shape[1];
999
+ }
708
1000
  j = src.indices[1][j];
709
1001
  }
710
1002
 
@@ -725,17 +1017,29 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j
725
1017
 
726
1018
  if (src.indices[0])
727
1019
  {
728
- assert(i >= 0 && i < src.shape[0]);
1020
+ assert(i >= -src.shape[0] && i < src.shape[0]);
1021
+ if (i < 0)
1022
+ {
1023
+ i += src.shape[0];
1024
+ }
729
1025
  i = src.indices[0][i];
730
1026
  }
731
1027
  if (src.indices[1])
732
1028
  {
733
- assert(j >= 0 && j < src.shape[1]);
1029
+ assert(j >= -src.shape[1] && j < src.shape[1]);
1030
+ if (j < 0)
1031
+ {
1032
+ j += src.shape[1];
1033
+ }
734
1034
  j = src.indices[1][j];
735
1035
  }
736
1036
  if (src.indices[2])
737
1037
  {
738
- assert(k >= 0 && k < src.shape[2]);
1038
+ assert(k >= -src.shape[2] && k < src.shape[2]);
1039
+ if (k < 0)
1040
+ {
1041
+ k += src.shape[2];
1042
+ }
739
1043
  k = src.indices[2][k];
740
1044
  }
741
1045
 
@@ -754,6 +1058,9 @@ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int
754
1058
  template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
755
1059
  inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T>& adj_ret) {}
756
1060
 
1061
+ template <typename... Args>
1062
+ CUDA_CALLABLE inline void adj_view(Args&&...) { }
1063
+
757
1064
  // TODO: lower_bound() for indexed arrays?
758
1065
 
759
1066
  template <typename T>
@@ -844,6 +1151,33 @@ inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, T value
844
1151
  template<template<typename> class A, typename T>
845
1152
  inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_exch(&index(buf, i, j, k, l), value); }
846
1153
 
1154
+ template<template<typename> class A, typename T>
1155
+ inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, T value) { return atomic_and(&index(buf, i), value); }
1156
+ template<template<typename> class A, typename T>
1157
+ inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, int j, T value) { return atomic_and(&index(buf, i, j), value); }
1158
+ template<template<typename> class A, typename T>
1159
+ inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, int j, int k, T value) { return atomic_and(&index(buf, i, j, k), value); }
1160
+ template<template<typename> class A, typename T>
1161
+ inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_and(&index(buf, i, j, k, l), value); }
1162
+
1163
+ template<template<typename> class A, typename T>
1164
+ inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, T value) { return atomic_or(&index(buf, i), value); }
1165
+ template<template<typename> class A, typename T>
1166
+ inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, int j, T value) { return atomic_or(&index(buf, i, j), value); }
1167
+ template<template<typename> class A, typename T>
1168
+ inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, int j, int k, T value) { return atomic_or(&index(buf, i, j, k), value); }
1169
+ template<template<typename> class A, typename T>
1170
+ inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_or(&index(buf, i, j, k, l), value); }
1171
+
1172
+ template<template<typename> class A, typename T>
1173
+ inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, T value) { return atomic_xor(&index(buf, i), value); }
1174
+ template<template<typename> class A, typename T>
1175
+ inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, int j, T value) { return atomic_xor(&index(buf, i, j), value); }
1176
+ template<template<typename> class A, typename T>
1177
+ inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, int j, int k, T value) { return atomic_xor(&index(buf, i, j, k), value); }
1178
+ template<template<typename> class A, typename T>
1179
+ inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_xor(&index(buf, i, j, k, l), value); }
1180
+
847
1181
  template<template<typename> class A, typename T>
848
1182
  inline CUDA_CALLABLE T* address(const A<T>& buf, int i)
849
1183
  {
@@ -911,20 +1245,7 @@ inline CUDA_CALLABLE T load(T* address)
911
1245
  return value;
912
1246
  }
913
1247
 
914
- // select operator to check for array being null
915
- template <typename T1, typename T2>
916
- CUDA_CALLABLE inline T2 select(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?b:a; }
917
-
918
- template <typename T1, typename T2>
919
- CUDA_CALLABLE inline void adj_select(const array_t<T1>& arr, const T2& a, const T2& b, const array_t<T1>& adj_cond, T2& adj_a, T2& adj_b, const T2& adj_ret)
920
- {
921
- if (arr.data)
922
- adj_b += adj_ret;
923
- else
924
- adj_a += adj_ret;
925
- }
926
-
927
- // where operator to check for array being null, opposite convention compared to select
1248
+ // where() overload for array condition - returns a if array.data is non-null, otherwise returns b
928
1249
  template <typename T1, typename T2>
929
1250
  CUDA_CALLABLE inline T2 where(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?a:b; }
930
1251
 
@@ -1321,6 +1642,34 @@ inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, int k,
1321
1642
  FP_VERIFY_ADJ_4(value, adj_value)
1322
1643
  }
1323
1644
 
1645
+ // for bitwise operations we do not accumulate gradients
1646
+ template<template<typename> class A1, template<typename> class A2, typename T>
1647
+ inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
1648
+ template<template<typename> class A1, template<typename> class A2, typename T>
1649
+ inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
1650
+ template<template<typename> class A1, template<typename> class A2, typename T>
1651
+ inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
1652
+ template<template<typename> class A1, template<typename> class A2, typename T>
1653
+ inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
1654
+
1655
+ template<template<typename> class A1, template<typename> class A2, typename T>
1656
+ inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
1657
+ template<template<typename> class A1, template<typename> class A2, typename T>
1658
+ inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
1659
+ template<template<typename> class A1, template<typename> class A2, typename T>
1660
+ inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
1661
+ template<template<typename> class A1, template<typename> class A2, typename T>
1662
+ inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
1663
+
1664
+ template<template<typename> class A1, template<typename> class A2, typename T>
1665
+ inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
1666
+ template<template<typename> class A1, template<typename> class A2, typename T>
1667
+ inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
1668
+ template<template<typename> class A1, template<typename> class A2, typename T>
1669
+ inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
1670
+ template<template<typename> class A1, template<typename> class A2, typename T>
1671
+ inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
1672
+
1324
1673
 
1325
1674
  template<template<typename> class A, typename T>
1326
1675
  CUDA_CALLABLE inline int len(const A<T>& a)
@@ -1333,7 +1682,6 @@ CUDA_CALLABLE inline void adj_len(const A<T>& a, A<T>& adj_a, int& adj_ret)
1333
1682
  {
1334
1683
  }
1335
1684
 
1336
-
1337
1685
  } // namespace wp
1338
1686
 
1339
1687
  #include "fabric.h"