warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/array.h CHANGED
@@ -19,6 +19,12 @@ namespace wp
19
19
  printf(")\n"); \
20
20
  assert(0); \
21
21
 
22
+ #define FP_VERIFY_FWD(value) \
23
+ if (!isfinite(value)) { \
24
+ printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
25
+ FP_ASSERT_FWD(value) \
26
+ } \
27
+
22
28
  #define FP_VERIFY_FWD_1(value) \
23
29
  if (!isfinite(value)) { \
24
30
  printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
@@ -43,6 +49,13 @@ namespace wp
43
49
  FP_ASSERT_FWD(value) \
44
50
  } \
45
51
 
52
+ #define FP_VERIFY_ADJ(value, adj_value) \
53
+ if (!isfinite(value) || !isfinite(adj_value)) \
54
+ { \
55
+ printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
56
+ FP_ASSERT_ADJ(value, adj_value); \
57
+ } \
58
+
46
59
  #define FP_VERIFY_ADJ_1(value, adj_value) \
47
60
  if (!isfinite(value) || !isfinite(adj_value)) \
48
61
  { \
@@ -74,11 +87,13 @@ namespace wp
74
87
 
75
88
  #else
76
89
 
90
+ #define FP_VERIFY_FWD(value) {}
77
91
  #define FP_VERIFY_FWD_1(value) {}
78
92
  #define FP_VERIFY_FWD_2(value) {}
79
93
  #define FP_VERIFY_FWD_3(value) {}
80
94
  #define FP_VERIFY_FWD_4(value) {}
81
95
 
96
+ #define FP_VERIFY_ADJ(value, adj_value) {}
82
97
  #define FP_VERIFY_ADJ_1(value, adj_value) {}
83
98
  #define FP_VERIFY_ADJ_2(value, adj_value) {}
84
99
  #define FP_VERIFY_ADJ_3(value, adj_value) {}
@@ -88,14 +103,19 @@ namespace wp
88
103
 
89
104
  const int ARRAY_MAX_DIMS = 4; // must match constant in types.py
90
105
 
91
- const int ARRAY_TYPE_REGULAR = 0; // must match constant in types.py
92
- const int ARRAY_TYPE_INDEXED = 1; // must match constant in types.py
106
+ // must match constants in types.py
107
+ const int ARRAY_TYPE_REGULAR = 0;
108
+ const int ARRAY_TYPE_INDEXED = 1;
109
+ const int ARRAY_TYPE_FABRIC = 2;
110
+ const int ARRAY_TYPE_FABRIC_INDEXED = 3;
93
111
 
94
112
  struct shape_t
95
113
  {
96
114
  int dims[ARRAY_MAX_DIMS];
97
115
 
98
- CUDA_CALLABLE inline shape_t() : dims() {}
116
+ CUDA_CALLABLE inline shape_t()
117
+ : dims()
118
+ {}
99
119
 
100
120
  CUDA_CALLABLE inline int operator[](int i) const
101
121
  {
@@ -110,12 +130,12 @@ struct shape_t
110
130
  }
111
131
  };
112
132
 
113
- CUDA_CALLABLE inline int index(const shape_t& s, int i)
133
+ CUDA_CALLABLE inline int extract(const shape_t& s, int i)
114
134
  {
115
135
  return s.dims[i];
116
136
  }
117
137
 
118
- CUDA_CALLABLE inline void adj_index(const shape_t& s, int i, const shape_t& adj_s, int adj_i, int adj_ret) {}
138
+ CUDA_CALLABLE inline void adj_extract(const shape_t& s, int i, const shape_t& adj_s, int adj_i, int adj_ret) {}
119
139
 
120
140
  inline CUDA_CALLABLE void print(shape_t s)
121
141
  {
@@ -130,10 +150,15 @@ inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& shape_t) {}
130
150
  template <typename T>
131
151
  struct array_t
132
152
  {
133
- CUDA_CALLABLE inline array_t() {}
134
- CUDA_CALLABLE inline array_t(int) {} // for backward a = 0 initialization syntax
135
-
136
- array_t(T* data, int size, T* grad=nullptr) : data(data), grad(grad) {
153
+ CUDA_CALLABLE inline array_t()
154
+ : data(nullptr),
155
+ grad(nullptr),
156
+ shape(),
157
+ strides(),
158
+ ndim(0)
159
+ {}
160
+
161
+ CUDA_CALLABLE array_t(T* data, int size, T* grad=nullptr) : data(data), grad(grad) {
137
162
  // constructor for 1d array
138
163
  shape.dims[0] = size;
139
164
  shape.dims[1] = 0;
@@ -145,7 +170,7 @@ struct array_t
145
170
  strides[2] = 0;
146
171
  strides[3] = 0;
147
172
  }
148
- array_t(T* data, int dim0, int dim1, T* grad=nullptr) : data(data), grad(grad) {
173
+ CUDA_CALLABLE array_t(T* data, int dim0, int dim1, T* grad=nullptr) : data(data), grad(grad) {
149
174
  // constructor for 2d array
150
175
  shape.dims[0] = dim0;
151
176
  shape.dims[1] = dim1;
@@ -157,7 +182,7 @@ struct array_t
157
182
  strides[2] = 0;
158
183
  strides[3] = 0;
159
184
  }
160
- array_t(T* data, int dim0, int dim1, int dim2, T* grad=nullptr) : data(data), grad(grad) {
185
+ CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, T* grad=nullptr) : data(data), grad(grad) {
161
186
  // constructor for 3d array
162
187
  shape.dims[0] = dim0;
163
188
  shape.dims[1] = dim1;
@@ -169,7 +194,7 @@ struct array_t
169
194
  strides[2] = sizeof(T);
170
195
  strides[3] = 0;
171
196
  }
172
- array_t(T* data, int dim0, int dim1, int dim2, int dim3, T* grad=nullptr) : data(data), grad(grad) {
197
+ CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, int dim3, T* grad=nullptr) : data(data), grad(grad) {
173
198
  // constructor for 4d array
174
199
  shape.dims[0] = dim0;
175
200
  shape.dims[1] = dim1;
@@ -182,10 +207,10 @@ struct array_t
182
207
  strides[3] = sizeof(T);
183
208
  }
184
209
 
185
- inline bool empty() const { return !data; }
210
+ CUDA_CALLABLE inline bool empty() const { return !data; }
186
211
 
187
- T* data{nullptr};
188
- T* grad{nullptr};
212
+ T* data;
213
+ T* grad;
189
214
  shape_t shape;
190
215
  int strides[ARRAY_MAX_DIMS];
191
216
  int ndim;
@@ -200,10 +225,13 @@ struct array_t
200
225
  template <typename T>
201
226
  struct indexedarray_t
202
227
  {
203
- CUDA_CALLABLE inline indexedarray_t() {}
204
- CUDA_CALLABLE inline indexedarray_t(int) {} // for backward a = 0 initialization syntax
228
+ CUDA_CALLABLE inline indexedarray_t()
229
+ : arr(),
230
+ indices(),
231
+ shape()
232
+ {}
205
233
 
206
- inline bool empty() const { return !arr.data; }
234
+ CUDA_CALLABLE inline bool empty() const { return !arr.data; }
207
235
 
208
236
  array_t<T> arr;
209
237
  int* indices[ARRAY_MAX_DIMS]; // index array per dimension (can be NULL)
@@ -597,13 +625,12 @@ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_s
597
625
  // TODO: lower_bound() for indexed arrays?
598
626
 
599
627
  template <typename T>
600
- CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, T value)
628
+ CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value)
601
629
  {
602
630
  assert(arr.ndim == 1);
603
- int n = arr.shape[0];
604
631
 
605
- int lower = 0;
606
- int upper = n - 1;
632
+ int lower = arr_begin;
633
+ int upper = arr_end - 1;
607
634
 
608
635
  while(lower < upper)
609
636
  {
@@ -622,7 +649,14 @@ CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, T value)
622
649
  return lower;
623
650
  }
624
651
 
652
+ template <typename T>
653
+ CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, T value)
654
+ {
655
+ return lower_bound(arr, 0, arr.shape[0], value);
656
+ }
657
+
625
658
  template <typename T> inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, T value, array_t<T> adj_arr, T adj_value, int adj_ret) {}
659
+ template <typename T> inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value, array_t<T> adj_arr, int adj_arr_begin, int adj_arr_end, T adj_value, int adj_ret) {}
626
660
 
627
661
  template<template<typename> class A, typename T>
628
662
  inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, T value) { return atomic_add(&index(buf, i), value); }
@@ -661,43 +695,60 @@ template<template<typename> class A, typename T>
661
695
  inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_max(&index(buf, i, j, k, l), value); }
662
696
 
663
697
  template<template<typename> class A, typename T>
664
- inline CUDA_CALLABLE T load(const A<T>& buf, int i) { return index(buf, i); }
698
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i) { return &index(buf, i); }
665
699
  template<template<typename> class A, typename T>
666
- inline CUDA_CALLABLE T load(const A<T>& buf, int i, int j) { return index(buf, i, j); }
700
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j) { return &index(buf, i, j); }
667
701
  template<template<typename> class A, typename T>
668
- inline CUDA_CALLABLE T load(const A<T>& buf, int i, int j, int k) { return index(buf, i, j, k); }
702
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k) { return &index(buf, i, j, k); }
669
703
  template<template<typename> class A, typename T>
670
- inline CUDA_CALLABLE T load(const A<T>& buf, int i, int j, int k, int l) { return index(buf, i, j, k, l); }
704
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k, int l) { return &index(buf, i, j, k, l); }
671
705
 
672
706
  template<template<typename> class A, typename T>
673
- inline CUDA_CALLABLE void store(const A<T>& buf, int i, T value)
707
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, T value)
674
708
  {
675
709
  FP_VERIFY_FWD_1(value)
676
710
 
677
711
  index(buf, i) = value;
678
712
  }
679
713
  template<template<typename> class A, typename T>
680
- inline CUDA_CALLABLE void store(const A<T>& buf, int i, int j, T value)
714
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, T value)
681
715
  {
682
716
  FP_VERIFY_FWD_2(value)
683
717
 
684
718
  index(buf, i, j) = value;
685
719
  }
686
720
  template<template<typename> class A, typename T>
687
- inline CUDA_CALLABLE void store(const A<T>& buf, int i, int j, int k, T value)
721
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, T value)
688
722
  {
689
723
  FP_VERIFY_FWD_3(value)
690
724
 
691
725
  index(buf, i, j, k) = value;
692
726
  }
693
727
  template<template<typename> class A, typename T>
694
- inline CUDA_CALLABLE void store(const A<T>& buf, int i, int j, int k, int l, T value)
728
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, int l, T value)
695
729
  {
696
730
  FP_VERIFY_FWD_4(value)
697
731
 
698
732
  index(buf, i, j, k, l) = value;
699
733
  }
700
734
 
735
+ template<typename T>
736
+ inline CUDA_CALLABLE void store(T* address, T value)
737
+ {
738
+ FP_VERIFY_FWD(value)
739
+
740
+ *address = value;
741
+ }
742
+
743
+ template<typename T>
744
+ inline CUDA_CALLABLE T load(T* address)
745
+ {
746
+ T value = *address;
747
+ FP_VERIFY_FWD(value)
748
+
749
+ return value;
750
+ }
751
+
701
752
  // select operator to check for array being null
702
753
  template <typename T1, typename T2>
703
754
  CUDA_CALLABLE inline T2 select(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?b:a; }
@@ -731,34 +782,36 @@ CUDA_CALLABLE inline void adj_atomic_add(uint32* buf, uint32 value) { }
731
782
  CUDA_CALLABLE inline void adj_atomic_add(int64* buf, int64 value) { }
732
783
  CUDA_CALLABLE inline void adj_atomic_add(uint64* buf, uint64 value) { }
733
784
 
785
+ CUDA_CALLABLE inline void adj_atomic_add(bool* buf, bool value) { }
786
+
734
787
  // only generate gradients for T types
735
788
  template<typename T>
736
- inline CUDA_CALLABLE void adj_load(const array_t<T>& buf, int i, const array_t<T>& adj_buf, int& adj_i, const T& adj_output)
789
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, const array_t<T>& adj_buf, int& adj_i, const T& adj_output)
737
790
  {
738
791
  if (buf.grad)
739
792
  adj_atomic_add(&index_grad(buf, i), adj_output);
740
793
  }
741
794
  template<typename T>
742
- inline CUDA_CALLABLE void adj_load(const array_t<T>& buf, int i, int j, const array_t<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output)
795
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, const array_t<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output)
743
796
  {
744
797
  if (buf.grad)
745
798
  adj_atomic_add(&index_grad(buf, i, j), adj_output);
746
799
  }
747
800
  template<typename T>
748
- inline CUDA_CALLABLE void adj_load(const array_t<T>& buf, int i, int j, int k, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output)
801
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output)
749
802
  {
750
803
  if (buf.grad)
751
804
  adj_atomic_add(&index_grad(buf, i, j, k), adj_output);
752
805
  }
753
806
  template<typename T>
754
- inline CUDA_CALLABLE void adj_load(const array_t<T>& buf, int i, int j, int k, int l, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output)
807
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, int l, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output)
755
808
  {
756
809
  if (buf.grad)
757
810
  adj_atomic_add(&index_grad(buf, i, j, k, l), adj_output);
758
811
  }
759
812
 
760
813
  template<typename T>
761
- inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value)
814
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value)
762
815
  {
763
816
  if (buf.grad)
764
817
  adj_value += index_grad(buf, i);
@@ -766,7 +819,7 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, T value, const
766
819
  FP_VERIFY_ADJ_1(value, adj_value)
767
820
  }
768
821
  template<typename T>
769
- inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value)
822
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value)
770
823
  {
771
824
  if (buf.grad)
772
825
  adj_value += index_grad(buf, i, j);
@@ -775,7 +828,7 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, T value
775
828
 
776
829
  }
777
830
  template<typename T>
778
- inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value)
831
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value)
779
832
  {
780
833
  if (buf.grad)
781
834
  adj_value += index_grad(buf, i, j, k);
@@ -783,7 +836,7 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k,
783
836
  FP_VERIFY_ADJ_3(value, adj_value)
784
837
  }
785
838
  template<typename T>
786
- inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value)
839
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value)
787
840
  {
788
841
  if (buf.grad)
789
842
  adj_value += index_grad(buf, i, j, k, l);
@@ -791,6 +844,19 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k,
791
844
  FP_VERIFY_ADJ_4(value, adj_value)
792
845
  }
793
846
 
847
+ template<typename T>
848
+ inline CUDA_CALLABLE void adj_store(const T* address, T value, const T& adj_address, T& adj_value)
849
+ {
850
+ // nop; generic store() operations are not differentiable, only array_store() is
851
+ FP_VERIFY_ADJ(value, adj_value)
852
+ }
853
+
854
+ template<typename T>
855
+ inline CUDA_CALLABLE void adj_load(const T* address, const T& adj_address, T& adj_value)
856
+ {
857
+ // nop; generic load() operations are not differentiable
858
+ }
859
+
794
860
  template<typename T>
795
861
  inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret)
796
862
  {
@@ -860,22 +926,22 @@ inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, in
860
926
 
861
927
  // generic array types that do not support gradient computation (indexedarray, etc.)
862
928
  template<template<typename> class A1, template<typename> class A2, typename T>
863
- inline CUDA_CALLABLE void adj_load(const A1<T>& buf, int i, const A2<T>& adj_buf, int& adj_i, const T& adj_output) {}
929
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, const A2<T>& adj_buf, int& adj_i, const T& adj_output) {}
864
930
  template<template<typename> class A1, template<typename> class A2, typename T>
865
- inline CUDA_CALLABLE void adj_load(const A1<T>& buf, int i, int j, const A2<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output) {}
931
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, const A2<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output) {}
866
932
  template<template<typename> class A1, template<typename> class A2, typename T>
867
- inline CUDA_CALLABLE void adj_load(const A1<T>& buf, int i, int j, int k, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output) {}
933
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output) {}
868
934
  template<template<typename> class A1, template<typename> class A2, typename T>
869
- inline CUDA_CALLABLE void adj_load(const A1<T>& buf, int i, int j, int k, int l, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output) {}
935
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, int l, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output) {}
870
936
 
871
937
  template<template<typename> class A1, template<typename> class A2, typename T>
872
- inline CUDA_CALLABLE void adj_store(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value) {}
938
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value) {}
873
939
  template<template<typename> class A1, template<typename> class A2, typename T>
874
- inline CUDA_CALLABLE void adj_store(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value) {}
940
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value) {}
875
941
  template<template<typename> class A1, template<typename> class A2, typename T>
876
- inline CUDA_CALLABLE void adj_store(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value) {}
942
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value) {}
877
943
  template<template<typename> class A1, template<typename> class A2, typename T>
878
- inline CUDA_CALLABLE void adj_store(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value) {}
944
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value) {}
879
945
 
880
946
  template<template<typename> class A1, template<typename> class A2, typename T>
881
947
  inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
@@ -895,22 +961,65 @@ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k,
895
961
  template<template<typename> class A1, template<typename> class A2, typename T>
896
962
  inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
897
963
 
964
+ // generic handler for scalar values
898
965
  template<template<typename> class A1, template<typename> class A2, typename T>
899
- inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
966
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
967
+ if (buf.grad)
968
+ adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
969
+
970
+ FP_VERIFY_ADJ_1(value, adj_value)
971
+ }
900
972
  template<template<typename> class A1, template<typename> class A2, typename T>
901
- inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {}
973
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
974
+ if (buf.grad)
975
+ adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
976
+
977
+ FP_VERIFY_ADJ_2(value, adj_value)
978
+ }
902
979
  template<template<typename> class A1, template<typename> class A2, typename T>
903
- inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {}
980
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
981
+ if (buf.grad)
982
+ adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
983
+
984
+ FP_VERIFY_ADJ_3(value, adj_value)
985
+ }
904
986
  template<template<typename> class A1, template<typename> class A2, typename T>
905
- inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
987
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
988
+ if (buf.grad)
989
+ adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
990
+
991
+ FP_VERIFY_ADJ_4(value, adj_value)
992
+ }
906
993
 
907
994
  template<template<typename> class A1, template<typename> class A2, typename T>
908
- inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
995
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
996
+ if (buf.grad)
997
+ adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
998
+
999
+ FP_VERIFY_ADJ_1(value, adj_value)
1000
+ }
909
1001
  template<template<typename> class A1, template<typename> class A2, typename T>
910
- inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {}
1002
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
1003
+ if (buf.grad)
1004
+ adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
1005
+
1006
+ FP_VERIFY_ADJ_2(value, adj_value)
1007
+ }
911
1008
  template<template<typename> class A1, template<typename> class A2, typename T>
912
- inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {}
1009
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
1010
+ if (buf.grad)
1011
+ adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
1012
+
1013
+ FP_VERIFY_ADJ_3(value, adj_value)
1014
+ }
913
1015
  template<template<typename> class A1, template<typename> class A2, typename T>
914
- inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
1016
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
1017
+ if (buf.grad)
1018
+ adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
1019
+
1020
+ FP_VERIFY_ADJ_4(value, adj_value)
1021
+ }
1022
+
1023
+ } // namespace wp
915
1024
 
916
- } // namespace wp
1025
+ #include "fabric.h"