warp-lang 0.9.0-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of warp-lang has been flagged as a potentially problematic release.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthrough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/builtin.h CHANGED
@@ -46,7 +46,6 @@ __device__ void __debugbreak() {}
 namespace wp
 {
 
-
 // numeric types (used from generated kernels)
 typedef float float32;
 typedef double float64;
@@ -141,7 +140,7 @@ static_assert(sizeof(half) == 2, "Size of half / float16 type must be 2-bytes");
 
 typedef half float16;
 
-#if __CUDA_ARCH__
+#if defined(__CUDA_ARCH__)
 
 CUDA_CALLABLE inline half float_to_half(float x)
 {
@@ -157,95 +156,38 @@ CUDA_CALLABLE inline float half_to_float(half x)
     return val;
 }
 
-#else
+#elif defined(__clang__)
 
-// adapted from Fabien Giesen's post: https://gist.github.com/rygorous/2156668
+// _Float16 is Clang's native half-precision floating-point type
 inline half float_to_half(float x)
 {
-    union fp32
-    {
-        uint32 u;
-        float f;
 
-        struct
-        {
-            unsigned int mantissa : 23;
-            unsigned int exponent : 8;
-            unsigned int sign : 1;
-        };
-    };
-
-    fp32 f;
-    f.f = x;
-
-    fp32 f32infty = { 255 << 23 };
-    fp32 f16infty = { 31 << 23 };
-    fp32 magic = { 15 << 23 };
-    uint32 sign_mask = 0x80000000u;
-    uint32 round_mask = ~0xfffu;
-    half o;
-
-    uint32 sign = f.u & sign_mask;
-    f.u ^= sign;
-
-    // NOTE all the integer compares in this function can be safely
-    // compiled into signed compares since all operands are below
-    // 0x80000000. Important if you want fast straight SSE2 code
-    // (since there's no unsigned PCMPGTD).
-
-    if (f.u >= f32infty.u) // Inf or NaN (all exponent bits set)
-        o.u = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
-    else // (De)normalized number or zero
-    {
-        f.u &= round_mask;
-        f.f *= magic.f;
-        f.u -= round_mask;
-        if (f.u > f16infty.u) f.u = f16infty.u; // Clamp to signed infinity if overflowed
-
-        o.u = f.u >> 13; // Take the bits!
-    }
-
-    o.u |= sign >> 16;
-    return o;
+    _Float16 f16 = static_cast<_Float16>(x);
+    return *reinterpret_cast<half*>(&f16);
 }
 
-
 inline float half_to_float(half h)
 {
-    union fp32
-    {
-        uint32 u;
-        float f;
+    _Float16 f16 = *reinterpret_cast<_Float16*>(&h);
+    return static_cast<float>(f16);
+}
 
-        struct
-        {
-            unsigned int mantissa : 23;
-            unsigned int exponent : 8;
-            unsigned int sign : 1;
-        };
-    };
-
-    static const fp32 magic = { 113 << 23 };
-    static const uint32 shifted_exp = 0x7c00 << 13; // exponent mask after shift
-    fp32 o;
-
-    o.u = (h.u & 0x7fff) << 13; // exponent/mantissa bits
-    uint32 exp = shifted_exp & o.u; // just the exponent
-    o.u += (127 - 15) << 23; // exponent adjust
-
-    // handle exponent special cases
-    if (exp == shifted_exp) // Inf/NaN?
-        o.u += (128 - 16) << 23; // extra exp adjust
-    else if (exp == 0) // Zero/Denormal?
-    {
-        o.u += 1 << 23; // extra exp adjust
-        o.f -= magic.f; // renormalize
-    }
+#else // Native C++ for Warp builtins outside of kernels
+
+extern "C" WP_API uint16_t float_to_half_bits(float x);
+extern "C" WP_API float half_bits_to_float(uint16_t u);
 
-    o.u |= (h.u & 0x8000) << 16; // sign bit
-    return o.f;
+inline half float_to_half(float x)
+{
+    half h;
+    h.u = float_to_half_bits(x);
+    return h;
 }
 
+inline float half_to_float(half h)
+{
+    return half_bits_to_float(h.u);
+}
 
 #endif
 
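The new Clang branch above replaces the bit-twiddling converters with the compiler's native `_Float16` type, reinterpreting its 16 bits as Warp's `half` storage type. A minimal standalone sketch of the same round-trip (plain C++ for a Clang target with `_Float16` support; the `half` struct here is a stand-in for Warp's type, and `std::memcpy` replaces the header's `reinterpret_cast` to keep the sketch strict-aliasing clean):

#include <cstdint>
#include <cstdio>
#include <cstring>

struct half { uint16_t u; };  // stand-in for Warp's 16-bit storage type

inline half float_to_half(float x)
{
    _Float16 f16 = static_cast<_Float16>(x);  // compiler/hardware conversion
    half h;
    std::memcpy(&h.u, &f16, sizeof(h.u));     // reinterpret the 16 bits
    return h;
}

inline float half_to_float(half h)
{
    _Float16 f16;
    std::memcpy(&f16, &h.u, sizeof(f16));
    return static_cast<float>(f16);
}

int main()
{
    half h = float_to_half(3.14159f);
    // prints roughly "0x4248 -> 3.140625", the nearest representable half
    printf("0x%04x -> %f\n", (unsigned)h.u, half_to_float(h));
    return 0;
}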
@@ -353,7 +295,7 @@ inline CUDA_CALLABLE T rshift(T a, T b) { return a>>b; } \
 inline CUDA_CALLABLE T invert(T x) { return ~x; } \
 inline CUDA_CALLABLE bool isfinite(T x) { return true; } \
 inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
-inline CUDA_CALLABLE void adj_div(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
+inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret) { } \
 inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
 inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
 inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
@@ -491,11 +433,6 @@ inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b,
     else\
         adj_x += adj_ret;\
 }\
-inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
-inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
-inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
-inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
-inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
 inline CUDA_CALLABLE T div(T a, T b)\
 {\
     DO_IF_FPCHECK(\
@@ -506,10 +443,10 @@ inline CUDA_CALLABLE T div(T a, T b)\
     })\
     return a/b;\
 }\
-inline CUDA_CALLABLE void adj_div(T a, T b, T& adj_a, T& adj_b, T adj_ret)\
+inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
 {\
     adj_a += adj_ret/b;\
-    adj_b -= adj_ret*(a/b)/b;\
+    adj_b -= adj_ret*(ret)/b;\
     DO_IF_FPCHECK(\
     if (!isfinite(adj_a) || !isfinite(adj_b))\
     {\
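The reworked `adj_div` threads the forward result `ret = a/b` into the adjoint so the backward pass reuses the already-computed quotient instead of re-dividing. Both accumulation lines follow directly from the partial derivatives of the quotient:

\frac{\partial}{\partial a}\left(\frac{a}{b}\right) = \frac{1}{b}, \qquad
\frac{\partial}{\partial b}\left(\frac{a}{b}\right) = -\frac{a}{b^{2}} = -\frac{ret}{b}

which is exactly `adj_a += adj_ret/b` and `adj_b -= adj_ret*(ret)/b` above.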
@@ -788,16 +725,16 @@ inline CUDA_CALLABLE double floordiv(double a, double b)
 inline CUDA_CALLABLE float leaky_min(float a, float b, float r) { return min(a, b); }
 inline CUDA_CALLABLE float leaky_max(float a, float b, float r) { return max(a, b); }
 
-inline CUDA_CALLABLE half abs(half x) { return ::fabs(float(x)); }
-inline CUDA_CALLABLE float abs(float x) { return ::fabs(x); }
+inline CUDA_CALLABLE half abs(half x) { return ::fabsf(float(x)); }
+inline CUDA_CALLABLE float abs(float x) { return ::fabsf(x); }
 inline CUDA_CALLABLE double abs(double x) { return ::fabs(x); }
 
-inline CUDA_CALLABLE float acos(float x){ return ::acos(min(max(x, -1.0f), 1.0f)); }
-inline CUDA_CALLABLE float asin(float x){ return ::asin(min(max(x, -1.0f), 1.0f)); }
-inline CUDA_CALLABLE float atan(float x) { return ::atan(x); }
-inline CUDA_CALLABLE float atan2(float y, float x) { return ::atan2(y, x); }
-inline CUDA_CALLABLE float sin(float x) { return ::sin(x); }
-inline CUDA_CALLABLE float cos(float x) { return ::cos(x); }
+inline CUDA_CALLABLE float acos(float x){ return ::acosf(min(max(x, -1.0f), 1.0f)); }
+inline CUDA_CALLABLE float asin(float x){ return ::asinf(min(max(x, -1.0f), 1.0f)); }
+inline CUDA_CALLABLE float atan(float x) { return ::atanf(x); }
+inline CUDA_CALLABLE float atan2(float y, float x) { return ::atan2f(y, x); }
+inline CUDA_CALLABLE float sin(float x) { return ::sinf(x); }
+inline CUDA_CALLABLE float cos(float x) { return ::cosf(x); }
 
 inline CUDA_CALLABLE double acos(double x){ return ::acos(min(max(x, -1.0), 1.0)); }
 inline CUDA_CALLABLE double asin(double x){ return ::asin(min(max(x, -1.0), 1.0)); }
@@ -806,12 +743,12 @@ inline CUDA_CALLABLE double atan2(double y, double x) { return ::atan2(y, x); }
 inline CUDA_CALLABLE double sin(double x) { return ::sin(x); }
 inline CUDA_CALLABLE double cos(double x) { return ::cos(x); }
 
-inline CUDA_CALLABLE half acos(half x){ return ::acos(min(max(float(x), -1.0f), 1.0f)); }
-inline CUDA_CALLABLE half asin(half x){ return ::asin(min(max(float(x), -1.0f), 1.0f)); }
-inline CUDA_CALLABLE half atan(half x) { return ::atan(float(x)); }
-inline CUDA_CALLABLE half atan2(half y, half x) { return ::atan2(float(y), float(x)); }
-inline CUDA_CALLABLE half sin(half x) { return ::sin(float(x)); }
-inline CUDA_CALLABLE half cos(half x) { return ::cos(float(x)); }
+inline CUDA_CALLABLE half acos(half x){ return ::acosf(min(max(float(x), -1.0f), 1.0f)); }
+inline CUDA_CALLABLE half asin(half x){ return ::asinf(min(max(float(x), -1.0f), 1.0f)); }
+inline CUDA_CALLABLE half atan(half x) { return ::atanf(float(x)); }
+inline CUDA_CALLABLE half atan2(half y, half x) { return ::atan2f(float(y), float(x)); }
+inline CUDA_CALLABLE half sin(half x) { return ::sinf(float(x)); }
+inline CUDA_CALLABLE half cos(half x) { return ::cosf(float(x)); }
 
 
 inline CUDA_CALLABLE float sqrt(float x)
@@ -823,7 +760,7 @@ inline CUDA_CALLABLE float sqrt(float x)
         assert(0);
     }
 #endif
-    return ::sqrt(x);
+    return ::sqrtf(x);
 }
 inline CUDA_CALLABLE double sqrt(double x)
 {
@@ -845,10 +782,14 @@ inline CUDA_CALLABLE half sqrt(half x)
         assert(0);
     }
 #endif
-    return ::sqrt(float(x));
+    return ::sqrtf(float(x));
 }
 
-inline CUDA_CALLABLE float tan(float x) { return ::tan(x); }
+inline CUDA_CALLABLE float cbrt(float x) { return ::cbrtf(x); }
+inline CUDA_CALLABLE double cbrt(double x) { return ::cbrt(x); }
+inline CUDA_CALLABLE half cbrt(half x) { return ::cbrtf(float(x)); }
+
+inline CUDA_CALLABLE float tan(float x) { return ::tanf(x); }
 inline CUDA_CALLABLE float sinh(float x) { return ::sinhf(x);}
 inline CUDA_CALLABLE float cosh(float x) { return ::coshf(x);}
 inline CUDA_CALLABLE float tanh(float x) { return ::tanhf(x);}
@@ -862,7 +803,7 @@ inline CUDA_CALLABLE double tanh(double x) { return ::tanh(x);}
 inline CUDA_CALLABLE double degrees(double x) { return x * RAD_TO_DEG;}
 inline CUDA_CALLABLE double radians(double x) { return x * DEG_TO_RAD;}
 
-inline CUDA_CALLABLE half tan(half x) { return ::tan(float(x)); }
+inline CUDA_CALLABLE half tan(half x) { return ::tanf(float(x)); }
 inline CUDA_CALLABLE half sinh(half x) { return ::sinhf(float(x));}
 inline CUDA_CALLABLE half cosh(half x) { return ::coshf(float(x));}
 inline CUDA_CALLABLE half tanh(half x) { return ::tanhf(float(x));}
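The blanket switch from the double-precision C library calls (`::sin`, `::fabs`, ...) to their f-suffixed single-precision counterparts avoids a silent float-to-double promotion on every call, which matters especially on CUDA devices where double-precision throughput is much lower. A small host-side illustration of the promotion, independent of this header:

#include <cmath>
#include <cstdio>

int main()
{
    float x = 0.5f;
    // ::sin takes double: x is promoted, the work is done in double
    // precision, and the result is narrowed back on assignment
    float a = ::sin(x);
    // ::sinf stays in single precision end to end
    float b = ::sinf(x);
    printf("%.9f %.9f\n", a, b);  // usually the same value; what changes
    return 0;                     // is the width of the intermediate math
}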
@@ -874,6 +815,21 @@ inline CUDA_CALLABLE float rint(float x) { return ::rintf(x); }
 inline CUDA_CALLABLE float trunc(float x) { return ::truncf(x); }
 inline CUDA_CALLABLE float floor(float x) { return ::floorf(x); }
 inline CUDA_CALLABLE float ceil(float x) { return ::ceilf(x); }
+inline CUDA_CALLABLE float frac(float x) { return x - trunc(x); }
+
+inline CUDA_CALLABLE double round(double x) { return ::round(x); }
+inline CUDA_CALLABLE double rint(double x) { return ::rint(x); }
+inline CUDA_CALLABLE double trunc(double x) { return ::trunc(x); }
+inline CUDA_CALLABLE double floor(double x) { return ::floor(x); }
+inline CUDA_CALLABLE double ceil(double x) { return ::ceil(x); }
+inline CUDA_CALLABLE double frac(double x) { return x - trunc(x); }
+
+inline CUDA_CALLABLE half round(half x) { return ::roundf(float(x)); }
+inline CUDA_CALLABLE half rint(half x) { return ::rintf(float(x)); }
+inline CUDA_CALLABLE half trunc(half x) { return ::truncf(float(x)); }
+inline CUDA_CALLABLE half floor(half x) { return ::floorf(float(x)); }
+inline CUDA_CALLABLE half ceil(half x) { return ::ceilf(float(x)); }
+inline CUDA_CALLABLE half frac(half x) { return float(x) - trunc(float(x)); }
 
 #define DECLARE_ADJOINTS(T)\
 inline CUDA_CALLABLE void adj_log(T a, T& adj_a, T adj_ret)\
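The new `frac` builtin is defined as `x - trunc(x)`; because truncation rounds toward zero, the fractional part keeps the sign of its argument. A quick host-side check of that definition:

#include <cmath>
#include <cstdio>

int main()
{
    // frac(x) = x - trunc(x)
    printf("%g\n", 1.75 - std::trunc(1.75));    // 0.75
    printf("%g\n", -1.75 - std::trunc(-1.75));  // -0.75
    return 0;
}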
@@ -903,11 +859,11 @@ inline CUDA_CALLABLE void adj_log10(T a, T& adj_a, T adj_ret)\
     assert(0);\
     })\
 }\
-inline CUDA_CALLABLE void adj_exp(T a, T& adj_a, T adj_ret) { adj_a += exp(a)*adj_ret; }\
-inline CUDA_CALLABLE void adj_pow(T a, T b, T& adj_a, T& adj_b, T adj_ret)\
+inline CUDA_CALLABLE void adj_exp(T a, T ret, T& adj_a, T adj_ret) { adj_a += ret*adj_ret; }\
+inline CUDA_CALLABLE void adj_pow(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
 { \
     adj_a += b*pow(a, b-T(1))*adj_ret;\
-    adj_b += log(a)*pow(a, b)*adj_ret;\
+    adj_b += log(a)*ret*adj_ret;\
     DO_IF_FPCHECK(if (!isfinite(adj_a) || !isfinite(adj_b))\
     {\
     printf("%s:%d - adj_pow(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
@@ -1006,20 +962,28 @@ inline CUDA_CALLABLE void adj_cosh(T x, T& adj_x, T adj_ret)\
 {\
     adj_x += sinh(x)*adj_ret;\
 }\
-inline CUDA_CALLABLE void adj_tanh(T x, T& adj_x, T adj_ret)\
+inline CUDA_CALLABLE void adj_tanh(T x, T ret, T& adj_x, T adj_ret)\
 {\
-    T tanh_x = tanh(x);\
-    adj_x += (T(1) - tanh_x*tanh_x)*adj_ret;\
+    adj_x += (T(1) - ret*ret)*adj_ret;\
 }\
-inline CUDA_CALLABLE void adj_sqrt(T x, T& adj_x, T adj_ret)\
+inline CUDA_CALLABLE void adj_sqrt(T x, T ret, T& adj_x, T adj_ret)\
 {\
-    adj_x += T(0.5)*(T(1)/sqrt(x))*adj_ret;\
+    adj_x += T(0.5)*(T(1)/ret)*adj_ret;\
     DO_IF_FPCHECK(if (!isfinite(adj_x))\
     {\
     printf("%s:%d - adj_sqrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
     assert(0);\
     })\
 }\
+inline CUDA_CALLABLE void adj_cbrt(T x, T ret, T& adj_x, T adj_ret)\
+{\
+    adj_x += (T(1)/T(3))*(T(1)/(ret*ret))*adj_ret;\
+    DO_IF_FPCHECK(if (!isfinite(adj_x))\
+    {\
+    printf("%s:%d - adj_cbrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
+    assert(0);\
+    })\
+}\
 inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
 {\
     adj_x += RAD_TO_DEG * adj_ret;\
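The added `adj_cbrt` follows the same ret-passing pattern as `adj_sqrt`, `adj_exp`, and `adj_tanh` above: for the cube root,

y = x^{1/3} \quad\Rightarrow\quad \frac{dy}{dx} = \frac{1}{3}\,x^{-2/3} = \frac{1}{3y^{2}}

so the backward pass evaluates `1/(3*ret*ret)` without recomputing the cube root.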
@@ -1027,7 +991,13 @@ inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
 inline CUDA_CALLABLE void adj_radians(T x, T& adj_x, T adj_ret)\
 {\
     adj_x += DEG_TO_RAD * adj_ret;\
-}
+}\
+inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
+inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
+inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
+inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
+inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
+inline CUDA_CALLABLE void adj_frac(T x, T& adj_x, T adj_ret){ }
 
 DECLARE_ADJOINTS(float16)
 DECLARE_ADJOINTS(float32)
@@ -1051,17 +1021,31 @@ CUDA_CALLABLE inline void adj_select(const C& cond, const T& a, const T& b, C& a
 }
 
 template <typename T>
-CUDA_CALLABLE inline void copy(T& dest, const T& src)
+CUDA_CALLABLE inline T copy(const T& src)
+{
+    return src;
+}
+
+template <typename T>
+CUDA_CALLABLE inline void adj_copy(const T& src, T& adj_src, T& adj_dest)
+{
+    adj_src = adj_dest;
+    adj_dest = T{};
+}
+
+template <typename T>
+CUDA_CALLABLE inline void assign(T& dest, const T& src)
 {
     dest = src;
 }
 
 template <typename T>
-CUDA_CALLABLE inline void adj_copy(T& dest, const T& src, T& adj_dest, T& adj_src)
+CUDA_CALLABLE inline void adj_assign(T& dest, const T& src, T& adj_dest, T& adj_src)
 {
-    // nop, this is non-differentiable operation since it violates SSA
+    // this is generally a non-differentiable operation since it violates SSA,
+    // except in read-modify-write statements which are reversible through backpropagation
     adj_src = adj_dest;
-    adj_dest = T(0);
+    adj_dest = T{};
 }
 
 
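The old in-place `copy(dest, src)` is split into a value-returning `copy` and an explicit `assign`, each with its own adjoint. A standalone sketch of the reverse-mode semantics of `assign` (plain C++, not Warp's generated code): the gradient that arrived at `dest` is handed back to `src`, and `dest`'s gradient is cleared because its previous value was overwritten.

#include <cstdio>

template <typename T>
void assign(T& dest, const T& src) { dest = src; }

// mirrors the adjoint above: the gradient moves from dest back to src
template <typename T>
void adj_assign(T& /*dest*/, const T& /*src*/, T& adj_dest, T& adj_src)
{
    adj_src = adj_dest;
    adj_dest = T{};
}

int main()
{
    float x = 1.0f, y = 2.0f;
    assign(x, y);  // forward: x <- y

    float adj_x = 1.0f, adj_y = 0.0f;  // upstream gradient arrived at x
    adj_assign(x, y, adj_x, adj_y);
    printf("adj_x=%g adj_y=%g\n", adj_x, adj_y);  // adj_x=0 adj_y=1
    return 0;
}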
@@ -1106,34 +1090,8 @@ struct launch_bounds_t
     size_t size; // total number of threads
 };
 
-#ifdef __CUDACC__
-
-// store launch bounds in shared memory so
-// we can access them from any user func
-// this is to avoid having to explicitly
-// set another piece of __constant__ memory
-// from the host
-__shared__ launch_bounds_t s_launchBounds;
-
-__device__ inline void set_launch_bounds(const launch_bounds_t& b)
-{
-    if (threadIdx.x == 0)
-        s_launchBounds = b;
-
-    __syncthreads();
-}
-
-#else
-
-// for single-threaded CPU we store launch
-// bounds in static memory to share globally
-static launch_bounds_t s_launchBounds;
+#ifndef __CUDACC__
 static size_t s_threadIdx;
-
-inline void set_launch_bounds(const launch_bounds_t& b)
-{
-    s_launchBounds = b;
-}
 #endif
 
 inline CUDA_CALLABLE size_t grid_index()
@@ -1147,10 +1105,8 @@ inline CUDA_CALLABLE size_t grid_index()
 #endif
 }
 
-inline CUDA_CALLABLE int tid()
+inline CUDA_CALLABLE int tid(size_t index)
 {
-    const size_t index = grid_index();
-
     // For the 1-D tid() we need to warn the user if we're about to provide a truncated index
     // Only do this in _DEBUG when called from device to avoid excessive register allocation
 #if defined(_DEBUG) || !defined(__CUDA_ARCH__)
@@ -1161,23 +1117,19 @@ inline CUDA_CALLABLE int tid()
     return static_cast<int>(index);
 }
 
-inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j)
+inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& launch_bounds)
 {
-    const size_t index = grid_index();
-
-    const int n = s_launchBounds.shape[1];
+    const size_t n = launch_bounds.shape[1];
 
     // convert to work item
     i = index/n;
    j = index%n;
 }
 
-inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k)
+inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& launch_bounds)
 {
-    const size_t index = grid_index();
-
-    const int n = s_launchBounds.shape[1];
-    const int o = s_launchBounds.shape[2];
+    const size_t n = launch_bounds.shape[1];
+    const size_t o = launch_bounds.shape[2];
 
     // convert to work item
     i = index/(n*o);
@@ -1185,13 +1137,11 @@ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l)
     k = index%o;
 }
 
-inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l)
+inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& launch_bounds)
 {
-    const size_t index = grid_index();
-
-    const int n = s_launchBounds.shape[1];
-    const int o = s_launchBounds.shape[2];
-    const int p = s_launchBounds.shape[3];
+    const size_t n = launch_bounds.shape[1];
+    const size_t o = launch_bounds.shape[2];
+    const size_t p = launch_bounds.shape[3];
 
     // convert to work item
     i = index/(n*o*p);
1203
1153
  template<typename T>
1204
1154
  inline CUDA_CALLABLE T atomic_add(T* buf, T value)
1205
1155
  {
1206
- #if defined(WP_CPU)
1156
+ #if !defined(__CUDA_ARCH__)
1207
1157
  T old = buf[0];
1208
1158
  buf[0] += value;
1209
1159
  return old;
1210
- #elif defined(WP_CUDA)
1160
+ #else
1211
1161
  return atomicAdd(buf, value);
1212
1162
  #endif
1213
1163
  }
@@ -1215,11 +1165,14 @@ inline CUDA_CALLABLE T atomic_add(T* buf, T value)
1215
1165
  template<>
1216
1166
  inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
1217
1167
  {
1218
- #if defined(WP_CPU)
1168
+ #if !defined(__CUDA_ARCH__)
1219
1169
  float16 old = buf[0];
1220
1170
  buf[0] += value;
1221
1171
  return old;
1222
- #elif defined(WP_CUDA)
1172
+ #elif defined(__clang__) // CUDA compiled by Clang
1173
+ __half r = atomicAdd(reinterpret_cast<__half*>(buf), *reinterpret_cast<__half*>(&value));
1174
+ return *reinterpret_cast<float16*>(&r);
1175
+ #else // CUDA compiled by NVRTC
1223
1176
  //return atomicAdd(buf, value);
1224
1177
 
1225
1178
  /* Define __PTR for atomicAdd prototypes below, undef after done */
@@ -1243,7 +1196,7 @@ inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
1243
1196
 
1244
1197
  #undef __PTR
1245
1198
 
1246
- #endif // defined(WP_CUDA)
1199
+ #endif // CUDA compiled by NVRTC
1247
1200
 
1248
1201
  }
1249
1202
 
@@ -1318,9 +1271,36 @@ inline CUDA_CALLABLE int atomic_min(int* address, int val)
1318
1271
  #endif
1319
1272
  }
1320
1273
 
1274
+ // default behavior for adjoint of atomic min/max operation that accumulates gradients for all elements matching the min/max value
1275
+ template <typename T>
1276
+ CUDA_CALLABLE inline void adj_atomic_minmax(T *addr, T *adj_addr, const T &value, T &adj_value)
1277
+ {
1278
+ if (value == *addr)
1279
+ adj_value += *adj_addr;
1280
+ }
1281
+
1282
+ // for integral types we do not accumulate gradients
1283
+ CUDA_CALLABLE inline void adj_atomic_minmax(int8* buf, int8* adj_buf, const int8 &value, int8 &adj_value) { }
1284
+ CUDA_CALLABLE inline void adj_atomic_minmax(uint8* buf, uint8* adj_buf, const uint8 &value, uint8 &adj_value) { }
1285
+ CUDA_CALLABLE inline void adj_atomic_minmax(int16* buf, int16* adj_buf, const int16 &value, int16 &adj_value) { }
1286
+ CUDA_CALLABLE inline void adj_atomic_minmax(uint16* buf, uint16* adj_buf, const uint16 &value, uint16 &adj_value) { }
1287
+ CUDA_CALLABLE inline void adj_atomic_minmax(int32* buf, int32* adj_buf, const int32 &value, int32 &adj_value) { }
1288
+ CUDA_CALLABLE inline void adj_atomic_minmax(uint32* buf, uint32* adj_buf, const uint32 &value, uint32 &adj_value) { }
1289
+ CUDA_CALLABLE inline void adj_atomic_minmax(int64* buf, int64* adj_buf, const int64 &value, int64 &adj_value) { }
1290
+ CUDA_CALLABLE inline void adj_atomic_minmax(uint64* buf, uint64* adj_buf, const uint64 &value, uint64 &adj_value) { }
1291
+ CUDA_CALLABLE inline void adj_atomic_minmax(bool* buf, bool* adj_buf, const bool &value, bool &adj_value) { }
1292
+
1321
1293
 
1322
1294
  } // namespace wp
1323
1295
 
1296
+
1297
+ // bool and printf are defined outside of the wp namespace in crt.h, hence
1298
+ // their adjoint counterparts are also defined in the global namespace.
1299
+ template <typename T>
1300
+ CUDA_CALLABLE inline void adj_bool(T, T&, bool) {}
1301
+ inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
1302
+
1303
+
1324
1304
  #include "vec.h"
1325
1305
  #include "mat.h"
1326
1306
  #include "quat.h"
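The new `adj_atomic_minmax` default routes gradient only to elements whose value equals the stored min/max result, so all tied elements receive the full upstream gradient (a subgradient convention). A standalone illustration of the rule (plain C++, no Warp headers):

#include <cstdio>

// same rule as the template above: an element is credited with gradient
// only if its value matches the reduced result
template <typename T>
void adj_atomic_minmax(T* addr, T* adj_addr, const T& value, T& adj_value)
{
    if (value == *addr)
        adj_value += *adj_addr;
}

int main()
{
    float result = 2.0f;      // forward: min over {5, 2, 2}
    float adj_result = 1.0f;  // upstream gradient of the min

    float values[3] = {5.0f, 2.0f, 2.0f};
    float grads[3] = {0.0f, 0.0f, 0.0f};
    for (int n = 0; n < 3; ++n)
        adj_atomic_minmax(&result, &adj_result, values[n], grads[n]);

    printf("%g %g %g\n", grads[0], grads[1], grads[2]);  // 0 1 1
    return 0;
}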
@@ -1485,10 +1465,6 @@ inline CUDA_CALLABLE void adj_print(transform_t<Type> t, transform_t<Type>& adj_
 inline CUDA_CALLABLE void adj_print(str t, str& adj_t) {}
 
 
-// printf defined globally in crt.h
-inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
-
-
 template <typename T>
 inline CUDA_CALLABLE void expect_eq(const T& actual, const T& expected)
 {
@@ -1528,7 +1504,7 @@ inline CUDA_CALLABLE void expect_near(const T& actual, const T& expected, const
 {
     if (abs(actual - expected) > tolerance)
     {
-        printf("Error, expect_near() failed with torerance "); print(tolerance);
+        printf("Error, expect_near() failed with tolerance "); print(tolerance);
         printf("\t Expected: "); print(expected);
         printf("\t Actual: "); print(actual);
     }
@@ -1539,7 +1515,7 @@ inline CUDA_CALLABLE void expect_near(const vec3& actual, const vec3& expected,
     const float diff = max(max(abs(actual[0] - expected[0]), abs(actual[1] - expected[1])), abs(actual[2] - expected[2]));
     if (diff > tolerance)
    {
-        printf("Error, expect_near() failed with torerance "); print(tolerance);
+        printf("Error, expect_near() failed with tolerance "); print(tolerance);
         printf("\t Expected: "); print(expected);
         printf("\t Actual: "); print(actual);
     }