warp-lang 1.0.2__py3-none-win_amd64.whl → 1.2.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (356) hide show
  1. warp/__init__.py +108 -97
  2. warp/__init__.pyi +1 -1
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +88 -113
  6. warp/build_dll.py +383 -375
  7. warp/builtins.py +3693 -3354
  8. warp/codegen.py +2925 -2792
  9. warp/config.py +40 -36
  10. warp/constants.py +49 -45
  11. warp/context.py +5409 -5102
  12. warp/dlpack.py +442 -442
  13. warp/examples/__init__.py +16 -16
  14. warp/examples/assets/bear.usd +0 -0
  15. warp/examples/assets/bunny.usd +0 -0
  16. warp/examples/assets/cartpole.urdf +110 -110
  17. warp/examples/assets/crazyflie.usd +0 -0
  18. warp/examples/assets/cube.usd +0 -0
  19. warp/examples/assets/nv_ant.xml +92 -92
  20. warp/examples/assets/nv_humanoid.xml +183 -183
  21. warp/examples/assets/quadruped.urdf +267 -267
  22. warp/examples/assets/rocks.nvdb +0 -0
  23. warp/examples/assets/rocks.usd +0 -0
  24. warp/examples/assets/sphere.usd +0 -0
  25. warp/examples/benchmarks/benchmark_api.py +381 -383
  26. warp/examples/benchmarks/benchmark_cloth.py +278 -277
  27. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
  28. warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
  29. warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
  30. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
  31. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
  32. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
  33. warp/examples/benchmarks/benchmark_cloth_warp.py +145 -146
  34. warp/examples/benchmarks/benchmark_launches.py +293 -295
  35. warp/examples/browse.py +29 -29
  36. warp/examples/core/example_dem.py +232 -219
  37. warp/examples/core/example_fluid.py +291 -267
  38. warp/examples/core/example_graph_capture.py +142 -126
  39. warp/examples/core/example_marching_cubes.py +186 -174
  40. warp/examples/core/example_mesh.py +172 -155
  41. warp/examples/core/example_mesh_intersect.py +203 -193
  42. warp/examples/core/example_nvdb.py +174 -170
  43. warp/examples/core/example_raycast.py +103 -90
  44. warp/examples/core/example_raymarch.py +197 -178
  45. warp/examples/core/example_render_opengl.py +183 -141
  46. warp/examples/core/example_sph.py +403 -387
  47. warp/examples/core/example_torch.py +219 -181
  48. warp/examples/core/example_wave.py +261 -248
  49. warp/examples/fem/bsr_utils.py +378 -380
  50. warp/examples/fem/example_apic_fluid.py +432 -389
  51. warp/examples/fem/example_burgers.py +262 -0
  52. warp/examples/fem/example_convection_diffusion.py +180 -168
  53. warp/examples/fem/example_convection_diffusion_dg.py +217 -209
  54. warp/examples/fem/example_deformed_geometry.py +175 -159
  55. warp/examples/fem/example_diffusion.py +199 -173
  56. warp/examples/fem/example_diffusion_3d.py +178 -152
  57. warp/examples/fem/example_diffusion_mgpu.py +219 -214
  58. warp/examples/fem/example_mixed_elasticity.py +242 -222
  59. warp/examples/fem/example_navier_stokes.py +257 -243
  60. warp/examples/fem/example_stokes.py +218 -192
  61. warp/examples/fem/example_stokes_transfer.py +263 -249
  62. warp/examples/fem/mesh_utils.py +133 -109
  63. warp/examples/fem/plot_utils.py +292 -287
  64. warp/examples/optim/example_bounce.py +258 -246
  65. warp/examples/optim/example_cloth_throw.py +220 -209
  66. warp/examples/optim/example_diffray.py +564 -536
  67. warp/examples/optim/example_drone.py +862 -835
  68. warp/examples/optim/example_inverse_kinematics.py +174 -168
  69. warp/examples/optim/example_inverse_kinematics_torch.py +183 -169
  70. warp/examples/optim/example_spring_cage.py +237 -231
  71. warp/examples/optim/example_trajectory.py +221 -199
  72. warp/examples/optim/example_walker.py +304 -293
  73. warp/examples/sim/example_cartpole.py +137 -129
  74. warp/examples/sim/example_cloth.py +194 -186
  75. warp/examples/sim/example_granular.py +122 -111
  76. warp/examples/sim/example_granular_collision_sdf.py +195 -186
  77. warp/examples/sim/example_jacobian_ik.py +234 -214
  78. warp/examples/sim/example_particle_chain.py +116 -105
  79. warp/examples/sim/example_quadruped.py +191 -180
  80. warp/examples/sim/example_rigid_chain.py +195 -187
  81. warp/examples/sim/example_rigid_contact.py +187 -177
  82. warp/examples/sim/example_rigid_force.py +125 -125
  83. warp/examples/sim/example_rigid_gyroscopic.py +107 -95
  84. warp/examples/sim/example_rigid_soft_contact.py +132 -122
  85. warp/examples/sim/example_soft_body.py +188 -177
  86. warp/fabric.py +337 -335
  87. warp/fem/__init__.py +61 -27
  88. warp/fem/cache.py +403 -388
  89. warp/fem/dirichlet.py +178 -179
  90. warp/fem/domain.py +262 -263
  91. warp/fem/field/__init__.py +100 -101
  92. warp/fem/field/field.py +148 -149
  93. warp/fem/field/nodal_field.py +298 -299
  94. warp/fem/field/restriction.py +22 -21
  95. warp/fem/field/test.py +180 -181
  96. warp/fem/field/trial.py +183 -183
  97. warp/fem/geometry/__init__.py +16 -19
  98. warp/fem/geometry/closest_point.py +69 -70
  99. warp/fem/geometry/deformed_geometry.py +270 -271
  100. warp/fem/geometry/element.py +748 -744
  101. warp/fem/geometry/geometry.py +184 -186
  102. warp/fem/geometry/grid_2d.py +380 -373
  103. warp/fem/geometry/grid_3d.py +437 -435
  104. warp/fem/geometry/hexmesh.py +953 -953
  105. warp/fem/geometry/nanogrid.py +455 -0
  106. warp/fem/geometry/partition.py +374 -376
  107. warp/fem/geometry/quadmesh_2d.py +532 -532
  108. warp/fem/geometry/tetmesh.py +840 -840
  109. warp/fem/geometry/trimesh_2d.py +577 -577
  110. warp/fem/integrate.py +1684 -1615
  111. warp/fem/operator.py +190 -191
  112. warp/fem/polynomial.py +214 -213
  113. warp/fem/quadrature/__init__.py +2 -2
  114. warp/fem/quadrature/pic_quadrature.py +243 -245
  115. warp/fem/quadrature/quadrature.py +295 -294
  116. warp/fem/space/__init__.py +179 -292
  117. warp/fem/space/basis_space.py +522 -489
  118. warp/fem/space/collocated_function_space.py +100 -105
  119. warp/fem/space/dof_mapper.py +236 -236
  120. warp/fem/space/function_space.py +148 -145
  121. warp/fem/space/grid_2d_function_space.py +148 -267
  122. warp/fem/space/grid_3d_function_space.py +167 -306
  123. warp/fem/space/hexmesh_function_space.py +253 -352
  124. warp/fem/space/nanogrid_function_space.py +202 -0
  125. warp/fem/space/partition.py +350 -350
  126. warp/fem/space/quadmesh_2d_function_space.py +261 -369
  127. warp/fem/space/restriction.py +161 -160
  128. warp/fem/space/shape/__init__.py +90 -15
  129. warp/fem/space/shape/cube_shape_function.py +728 -738
  130. warp/fem/space/shape/shape_function.py +102 -103
  131. warp/fem/space/shape/square_shape_function.py +611 -611
  132. warp/fem/space/shape/tet_shape_function.py +565 -567
  133. warp/fem/space/shape/triangle_shape_function.py +429 -429
  134. warp/fem/space/tetmesh_function_space.py +224 -292
  135. warp/fem/space/topology.py +297 -295
  136. warp/fem/space/trimesh_2d_function_space.py +153 -221
  137. warp/fem/types.py +77 -77
  138. warp/fem/utils.py +495 -495
  139. warp/jax.py +166 -141
  140. warp/jax_experimental.py +341 -339
  141. warp/native/array.h +1081 -1025
  142. warp/native/builtin.h +1603 -1560
  143. warp/native/bvh.cpp +402 -398
  144. warp/native/bvh.cu +533 -525
  145. warp/native/bvh.h +430 -429
  146. warp/native/clang/clang.cpp +496 -464
  147. warp/native/crt.cpp +42 -32
  148. warp/native/crt.h +352 -335
  149. warp/native/cuda_crt.h +1049 -1049
  150. warp/native/cuda_util.cpp +549 -540
  151. warp/native/cuda_util.h +288 -203
  152. warp/native/cutlass_gemm.cpp +34 -34
  153. warp/native/cutlass_gemm.cu +372 -372
  154. warp/native/error.cpp +66 -66
  155. warp/native/error.h +27 -27
  156. warp/native/exports.h +187 -0
  157. warp/native/fabric.h +228 -228
  158. warp/native/hashgrid.cpp +301 -278
  159. warp/native/hashgrid.cu +78 -77
  160. warp/native/hashgrid.h +227 -227
  161. warp/native/initializer_array.h +32 -32
  162. warp/native/intersect.h +1204 -1204
  163. warp/native/intersect_adj.h +365 -365
  164. warp/native/intersect_tri.h +322 -322
  165. warp/native/marching.cpp +2 -2
  166. warp/native/marching.cu +497 -497
  167. warp/native/marching.h +2 -2
  168. warp/native/mat.h +1545 -1498
  169. warp/native/matnn.h +333 -333
  170. warp/native/mesh.cpp +203 -203
  171. warp/native/mesh.cu +292 -293
  172. warp/native/mesh.h +1887 -1887
  173. warp/native/nanovdb/GridHandle.h +366 -0
  174. warp/native/nanovdb/HostBuffer.h +590 -0
  175. warp/native/nanovdb/NanoVDB.h +6624 -4782
  176. warp/native/nanovdb/PNanoVDB.h +3390 -2553
  177. warp/native/noise.h +850 -850
  178. warp/native/quat.h +1112 -1085
  179. warp/native/rand.h +303 -299
  180. warp/native/range.h +108 -108
  181. warp/native/reduce.cpp +156 -156
  182. warp/native/reduce.cu +348 -348
  183. warp/native/runlength_encode.cpp +61 -61
  184. warp/native/runlength_encode.cu +46 -46
  185. warp/native/scan.cpp +30 -30
  186. warp/native/scan.cu +36 -36
  187. warp/native/scan.h +7 -7
  188. warp/native/solid_angle.h +442 -442
  189. warp/native/sort.cpp +94 -94
  190. warp/native/sort.cu +97 -97
  191. warp/native/sort.h +14 -14
  192. warp/native/sparse.cpp +337 -337
  193. warp/native/sparse.cu +544 -544
  194. warp/native/spatial.h +630 -630
  195. warp/native/svd.h +562 -562
  196. warp/native/temp_buffer.h +30 -30
  197. warp/native/vec.h +1177 -1133
  198. warp/native/volume.cpp +529 -297
  199. warp/native/volume.cu +58 -32
  200. warp/native/volume.h +960 -538
  201. warp/native/volume_builder.cu +446 -425
  202. warp/native/volume_builder.h +34 -19
  203. warp/native/volume_impl.h +61 -0
  204. warp/native/warp.cpp +1057 -1052
  205. warp/native/warp.cu +2949 -2828
  206. warp/native/warp.h +321 -305
  207. warp/optim/__init__.py +9 -9
  208. warp/optim/adam.py +120 -120
  209. warp/optim/linear.py +1104 -939
  210. warp/optim/sgd.py +104 -92
  211. warp/render/__init__.py +10 -10
  212. warp/render/render_opengl.py +3356 -3204
  213. warp/render/render_usd.py +768 -749
  214. warp/render/utils.py +152 -150
  215. warp/sim/__init__.py +52 -59
  216. warp/sim/articulation.py +685 -685
  217. warp/sim/collide.py +1594 -1590
  218. warp/sim/import_mjcf.py +489 -481
  219. warp/sim/import_snu.py +220 -221
  220. warp/sim/import_urdf.py +536 -516
  221. warp/sim/import_usd.py +887 -881
  222. warp/sim/inertia.py +316 -317
  223. warp/sim/integrator.py +234 -233
  224. warp/sim/integrator_euler.py +1956 -1956
  225. warp/sim/integrator_featherstone.py +1917 -1991
  226. warp/sim/integrator_xpbd.py +3288 -3312
  227. warp/sim/model.py +4473 -4314
  228. warp/sim/particles.py +113 -112
  229. warp/sim/render.py +417 -403
  230. warp/sim/utils.py +413 -410
  231. warp/sparse.py +1289 -1227
  232. warp/stubs.py +2192 -2469
  233. warp/tape.py +1162 -225
  234. warp/tests/__init__.py +1 -1
  235. warp/tests/__main__.py +4 -4
  236. warp/tests/assets/test_index_grid.nvdb +0 -0
  237. warp/tests/assets/torus.usda +105 -105
  238. warp/tests/aux_test_class_kernel.py +26 -26
  239. warp/tests/aux_test_compile_consts_dummy.py +10 -10
  240. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
  241. warp/tests/aux_test_dependent.py +20 -22
  242. warp/tests/aux_test_grad_customs.py +21 -23
  243. warp/tests/aux_test_reference.py +9 -11
  244. warp/tests/aux_test_reference_reference.py +8 -10
  245. warp/tests/aux_test_square.py +15 -17
  246. warp/tests/aux_test_unresolved_func.py +14 -14
  247. warp/tests/aux_test_unresolved_symbol.py +14 -14
  248. warp/tests/disabled_kinematics.py +237 -239
  249. warp/tests/run_coverage_serial.py +31 -31
  250. warp/tests/test_adam.py +155 -157
  251. warp/tests/test_arithmetic.py +1088 -1124
  252. warp/tests/test_array.py +2415 -2326
  253. warp/tests/test_array_reduce.py +148 -150
  254. warp/tests/test_async.py +666 -656
  255. warp/tests/test_atomic.py +139 -141
  256. warp/tests/test_bool.py +212 -149
  257. warp/tests/test_builtins_resolution.py +1290 -1292
  258. warp/tests/test_bvh.py +162 -171
  259. warp/tests/test_closest_point_edge_edge.py +227 -228
  260. warp/tests/test_codegen.py +562 -553
  261. warp/tests/test_compile_consts.py +217 -101
  262. warp/tests/test_conditional.py +244 -246
  263. warp/tests/test_copy.py +230 -215
  264. warp/tests/test_ctypes.py +630 -632
  265. warp/tests/test_dense.py +65 -67
  266. warp/tests/test_devices.py +89 -98
  267. warp/tests/test_dlpack.py +528 -529
  268. warp/tests/test_examples.py +403 -378
  269. warp/tests/test_fabricarray.py +952 -955
  270. warp/tests/test_fast_math.py +60 -54
  271. warp/tests/test_fem.py +1298 -1278
  272. warp/tests/test_fp16.py +128 -130
  273. warp/tests/test_func.py +336 -337
  274. warp/tests/test_generics.py +596 -571
  275. warp/tests/test_grad.py +885 -640
  276. warp/tests/test_grad_customs.py +331 -336
  277. warp/tests/test_hash_grid.py +208 -164
  278. warp/tests/test_import.py +37 -39
  279. warp/tests/test_indexedarray.py +1132 -1134
  280. warp/tests/test_intersect.py +65 -67
  281. warp/tests/test_jax.py +305 -307
  282. warp/tests/test_large.py +169 -164
  283. warp/tests/test_launch.py +352 -354
  284. warp/tests/test_lerp.py +217 -261
  285. warp/tests/test_linear_solvers.py +189 -171
  286. warp/tests/test_lvalue.py +419 -493
  287. warp/tests/test_marching_cubes.py +63 -65
  288. warp/tests/test_mat.py +1799 -1827
  289. warp/tests/test_mat_lite.py +113 -115
  290. warp/tests/test_mat_scalar_ops.py +2905 -2889
  291. warp/tests/test_math.py +124 -193
  292. warp/tests/test_matmul.py +498 -499
  293. warp/tests/test_matmul_lite.py +408 -410
  294. warp/tests/test_mempool.py +186 -190
  295. warp/tests/test_mesh.py +281 -324
  296. warp/tests/test_mesh_query_aabb.py +226 -241
  297. warp/tests/test_mesh_query_point.py +690 -702
  298. warp/tests/test_mesh_query_ray.py +290 -303
  299. warp/tests/test_mlp.py +274 -276
  300. warp/tests/test_model.py +108 -110
  301. warp/tests/test_module_hashing.py +111 -0
  302. warp/tests/test_modules_lite.py +36 -39
  303. warp/tests/test_multigpu.py +161 -163
  304. warp/tests/test_noise.py +244 -248
  305. warp/tests/test_operators.py +248 -250
  306. warp/tests/test_options.py +121 -125
  307. warp/tests/test_peer.py +131 -137
  308. warp/tests/test_pinned.py +76 -78
  309. warp/tests/test_print.py +52 -54
  310. warp/tests/test_quat.py +2084 -2086
  311. warp/tests/test_rand.py +324 -288
  312. warp/tests/test_reload.py +207 -217
  313. warp/tests/test_rounding.py +177 -179
  314. warp/tests/test_runlength_encode.py +188 -190
  315. warp/tests/test_sim_grad.py +241 -0
  316. warp/tests/test_sim_kinematics.py +89 -97
  317. warp/tests/test_smoothstep.py +166 -168
  318. warp/tests/test_snippet.py +303 -266
  319. warp/tests/test_sparse.py +466 -460
  320. warp/tests/test_spatial.py +2146 -2148
  321. warp/tests/test_special_values.py +362 -0
  322. warp/tests/test_streams.py +484 -473
  323. warp/tests/test_struct.py +708 -675
  324. warp/tests/test_tape.py +171 -148
  325. warp/tests/test_torch.py +741 -743
  326. warp/tests/test_transient_module.py +85 -87
  327. warp/tests/test_types.py +554 -659
  328. warp/tests/test_utils.py +488 -499
  329. warp/tests/test_vec.py +1262 -1268
  330. warp/tests/test_vec_lite.py +71 -73
  331. warp/tests/test_vec_scalar_ops.py +2097 -2099
  332. warp/tests/test_verify_fp.py +92 -94
  333. warp/tests/test_volume.py +961 -736
  334. warp/tests/test_volume_write.py +338 -265
  335. warp/tests/unittest_serial.py +38 -37
  336. warp/tests/unittest_suites.py +367 -359
  337. warp/tests/unittest_utils.py +434 -578
  338. warp/tests/unused_test_misc.py +69 -71
  339. warp/tests/walkthrough_debug.py +85 -85
  340. warp/thirdparty/appdirs.py +598 -598
  341. warp/thirdparty/dlpack.py +143 -143
  342. warp/thirdparty/unittest_parallel.py +563 -561
  343. warp/torch.py +321 -295
  344. warp/types.py +4941 -4450
  345. warp/utils.py +1008 -821
  346. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/LICENSE.md +126 -126
  347. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/METADATA +365 -400
  348. warp_lang-1.2.0.dist-info/RECORD +359 -0
  349. warp/examples/assets/cube.usda +0 -42
  350. warp/examples/assets/sphere.usda +0 -56
  351. warp/examples/assets/torus.usda +0 -105
  352. warp/examples/fem/example_convection_diffusion_dg0.py +0 -194
  353. warp/native/nanovdb/PNanoVDBWrite.h +0 -295
  354. warp_lang-1.0.2.dist-info/RECORD +0 -352
  355. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/WHEEL +0 -0
  356. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/top_level.txt +0 -0
warp/native/builtin.h CHANGED
@@ -1,1560 +1,1603 @@
1
- /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
7
- */
8
-
9
- #pragma once
10
-
11
- // All built-in types and functions. To be compatible with runtime NVRTC compilation
12
- // this header must be independently compilable (i.e.: without external SDK headers)
13
- // to achieve this we redefine a subset of CRT functions (printf, pow, sin, cos, etc)
14
-
15
- #include "crt.h"
16
-
17
- #ifdef _WIN32
18
- #define __restrict__ __restrict
19
- #endif
20
-
21
- #if !defined(__CUDACC__)
22
- #define CUDA_CALLABLE
23
- #define CUDA_CALLABLE_DEVICE
24
- #else
25
- #define CUDA_CALLABLE __host__ __device__
26
- #define CUDA_CALLABLE_DEVICE __device__
27
- #endif
28
-
29
- #ifdef WP_VERIFY_FP
30
- #define FP_CHECK 1
31
- #define DO_IF_FPCHECK(X) {X}
32
- #define DO_IF_NO_FPCHECK(X)
33
- #else
34
- #define FP_CHECK 0
35
- #define DO_IF_FPCHECK(X)
36
- #define DO_IF_NO_FPCHECK(X) {X}
37
- #endif
38
-
39
- #define RAD_TO_DEG 57.29577951308232087679
40
- #define DEG_TO_RAD 0.01745329251994329577
41
-
42
- #if defined(__CUDACC__) && !defined(_MSC_VER)
43
- __device__ void __debugbreak() {}
44
- #endif
45
-
46
- namespace wp
47
- {
48
-
49
- // numeric types (used from generated kernels)
50
- typedef float float32;
51
- typedef double float64;
52
-
53
- typedef int8_t int8;
54
- typedef uint8_t uint8;
55
-
56
- typedef int16_t int16;
57
- typedef uint16_t uint16;
58
-
59
- typedef int32_t int32;
60
- typedef uint32_t uint32;
61
-
62
- typedef int64_t int64;
63
- typedef uint64_t uint64;
64
-
65
-
66
- // matches Python string type for constant strings
67
- typedef const char* str;
68
-
69
-
70
-
71
- struct half;
72
-
73
- CUDA_CALLABLE half float_to_half(float x);
74
- CUDA_CALLABLE float half_to_float(half x);
75
-
76
- struct half
77
- {
78
- CUDA_CALLABLE inline half() : u(0) {}
79
-
80
- CUDA_CALLABLE inline half(float f)
81
- {
82
- *this = float_to_half(f);
83
- }
84
-
85
- unsigned short u;
86
-
87
- CUDA_CALLABLE inline bool operator==(const half& h) const { return u == h.u; }
88
- CUDA_CALLABLE inline bool operator!=(const half& h) const { return u != h.u; }
89
- CUDA_CALLABLE inline bool operator>(const half& h) const { return half_to_float(*this) > half_to_float(h); }
90
- CUDA_CALLABLE inline bool operator>=(const half& h) const { return half_to_float(*this) >= half_to_float(h); }
91
- CUDA_CALLABLE inline bool operator<(const half& h) const { return half_to_float(*this) < half_to_float(h); }
92
- CUDA_CALLABLE inline bool operator<=(const half& h) const { return half_to_float(*this) <= half_to_float(h); }
93
-
94
- CUDA_CALLABLE inline bool operator!() const
95
- {
96
- return float32(*this) == 0;
97
- }
98
-
99
- CUDA_CALLABLE inline half operator*=(const half& h)
100
- {
101
- half prod = half(float32(*this) * float32(h));
102
- this->u = prod.u;
103
- return *this;
104
- }
105
-
106
- CUDA_CALLABLE inline half operator/=(const half& h)
107
- {
108
- half quot = half(float32(*this) / float32(h));
109
- this->u = quot.u;
110
- return *this;
111
- }
112
-
113
- CUDA_CALLABLE inline half operator+=(const half& h)
114
- {
115
- half sum = half(float32(*this) + float32(h));
116
- this->u = sum.u;
117
- return *this;
118
- }
119
-
120
- CUDA_CALLABLE inline half operator-=(const half& h)
121
- {
122
- half diff = half(float32(*this) - float32(h));
123
- this->u = diff.u;
124
- return *this;
125
- }
126
-
127
- CUDA_CALLABLE inline operator float32() const { return float32(half_to_float(*this)); }
128
- CUDA_CALLABLE inline operator float64() const { return float64(half_to_float(*this)); }
129
- CUDA_CALLABLE inline operator int8() const { return int8(half_to_float(*this)); }
130
- CUDA_CALLABLE inline operator uint8() const { return uint8(half_to_float(*this)); }
131
- CUDA_CALLABLE inline operator int16() const { return int16(half_to_float(*this)); }
132
- CUDA_CALLABLE inline operator uint16() const { return uint16(half_to_float(*this)); }
133
- CUDA_CALLABLE inline operator int32() const { return int32(half_to_float(*this)); }
134
- CUDA_CALLABLE inline operator uint32() const { return uint32(half_to_float(*this)); }
135
- CUDA_CALLABLE inline operator int64() const { return int64(half_to_float(*this)); }
136
- CUDA_CALLABLE inline operator uint64() const { return uint64(half_to_float(*this)); }
137
- };
138
-
139
- static_assert(sizeof(half) == 2, "Size of half / float16 type must be 2-bytes");
140
-
141
- typedef half float16;
142
-
143
- #if defined(__CUDA_ARCH__)
144
-
145
- CUDA_CALLABLE inline half float_to_half(float x)
146
- {
147
- half h;
148
- asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(h.u) : "f"(x));
149
- return h;
150
- }
151
-
152
- CUDA_CALLABLE inline float half_to_float(half x)
153
- {
154
- float val;
155
- asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(x.u));
156
- return val;
157
- }
158
-
159
- #elif defined(__clang__)
160
-
161
- // _Float16 is Clang's native half-precision floating-point type
162
- inline half float_to_half(float x)
163
- {
164
-
165
- _Float16 f16 = static_cast<_Float16>(x);
166
- return *reinterpret_cast<half*>(&f16);
167
- }
168
-
169
- inline float half_to_float(half h)
170
- {
171
- _Float16 f16 = *reinterpret_cast<_Float16*>(&h);
172
- return static_cast<float>(f16);
173
- }
174
-
175
- #else // Native C++ for Warp builtins outside of kernels
176
-
177
- extern "C" WP_API uint16_t float_to_half_bits(float x);
178
- extern "C" WP_API float half_bits_to_float(uint16_t u);
179
-
180
- inline half float_to_half(float x)
181
- {
182
- half h;
183
- h.u = float_to_half_bits(x);
184
- return h;
185
- }
186
-
187
- inline float half_to_float(half h)
188
- {
189
- return half_bits_to_float(h.u);
190
- }
191
-
192
- #endif
193
-
194
-
195
- // BAD operator implementations for fp16 arithmetic...
196
-
197
- // negation:
198
- inline CUDA_CALLABLE half operator - (half a)
199
- {
200
- return float_to_half( -half_to_float(a) );
201
- }
202
-
203
- inline CUDA_CALLABLE half operator + (half a,half b)
204
- {
205
- return float_to_half( half_to_float(a) + half_to_float(b) );
206
- }
207
-
208
- inline CUDA_CALLABLE half operator - (half a,half b)
209
- {
210
- return float_to_half( half_to_float(a) - half_to_float(b) );
211
- }
212
-
213
- inline CUDA_CALLABLE half operator * (half a,half b)
214
- {
215
- return float_to_half( half_to_float(a) * half_to_float(b) );
216
- }
217
-
218
- inline CUDA_CALLABLE half operator * (half a,double b)
219
- {
220
- return float_to_half( half_to_float(a) * b );
221
- }
222
-
223
- inline CUDA_CALLABLE half operator * (double a,half b)
224
- {
225
- return float_to_half( a * half_to_float(b) );
226
- }
227
-
228
- inline CUDA_CALLABLE half operator / (half a,half b)
229
- {
230
- return float_to_half( half_to_float(a) / half_to_float(b) );
231
- }
232
-
233
-
234
-
235
-
236
-
237
- template <typename T>
238
- CUDA_CALLABLE float cast_float(T x) { return (float)(x); }
239
-
240
- template <typename T>
241
- CUDA_CALLABLE int cast_int(T x) { return (int)(x); }
242
-
243
- template <typename T>
244
- CUDA_CALLABLE void adj_cast_float(T x, T& adj_x, float adj_ret) { adj_x += T(adj_ret); }
245
-
246
- template <typename T>
247
- CUDA_CALLABLE void adj_cast_int(T x, T& adj_x, int adj_ret) { adj_x += adj_ret; }
248
-
249
- template <typename T>
250
- CUDA_CALLABLE inline void adj_int8(T, T&, int8) {}
251
- template <typename T>
252
- CUDA_CALLABLE inline void adj_uint8(T, T&, uint8) {}
253
- template <typename T>
254
- CUDA_CALLABLE inline void adj_int16(T, T&, int16) {}
255
- template <typename T>
256
- CUDA_CALLABLE inline void adj_uint16(T, T&, uint16) {}
257
- template <typename T>
258
- CUDA_CALLABLE inline void adj_int32(T, T&, int32) {}
259
- template <typename T>
260
- CUDA_CALLABLE inline void adj_uint32(T, T&, uint32) {}
261
- template <typename T>
262
- CUDA_CALLABLE inline void adj_int64(T, T&, int64) {}
263
- template <typename T>
264
- CUDA_CALLABLE inline void adj_uint64(T, T&, uint64) {}
265
-
266
-
267
- template <typename T>
268
- CUDA_CALLABLE inline void adj_float16(T x, T& adj_x, float16 adj_ret) { adj_x += T(adj_ret); }
269
- template <typename T>
270
- CUDA_CALLABLE inline void adj_float32(T x, T& adj_x, float32 adj_ret) { adj_x += T(adj_ret); }
271
- template <typename T>
272
- CUDA_CALLABLE inline void adj_float64(T x, T& adj_x, float64 adj_ret) { adj_x += T(adj_ret); }
273
-
274
-
275
- #define kEps 0.0f
276
-
277
- // basic ops for integer types
278
- #define DECLARE_INT_OPS(T) \
279
- inline CUDA_CALLABLE T mul(T a, T b) { return a*b; } \
280
- inline CUDA_CALLABLE T div(T a, T b) { return a/b; } \
281
- inline CUDA_CALLABLE T add(T a, T b) { return a+b; } \
282
- inline CUDA_CALLABLE T sub(T a, T b) { return a-b; } \
283
- inline CUDA_CALLABLE T mod(T a, T b) { return a%b; } \
284
- inline CUDA_CALLABLE T min(T a, T b) { return a<b?a:b; } \
285
- inline CUDA_CALLABLE T max(T a, T b) { return a>b?a:b; } \
286
- inline CUDA_CALLABLE T clamp(T x, T a, T b) { return min(max(a, x), b); } \
287
- inline CUDA_CALLABLE T floordiv(T a, T b) { return a/b; } \
288
- inline CUDA_CALLABLE T nonzero(T x) { return x == T(0) ? T(0) : T(1); } \
289
- inline CUDA_CALLABLE T sqrt(T x) { return 0; } \
290
- inline CUDA_CALLABLE T bit_and(T a, T b) { return a&b; } \
291
- inline CUDA_CALLABLE T bit_or(T a, T b) { return a|b; } \
292
- inline CUDA_CALLABLE T bit_xor(T a, T b) { return a^b; } \
293
- inline CUDA_CALLABLE T lshift(T a, T b) { return a<<b; } \
294
- inline CUDA_CALLABLE T rshift(T a, T b) { return a>>b; } \
295
- inline CUDA_CALLABLE T invert(T x) { return ~x; } \
296
- inline CUDA_CALLABLE bool isfinite(T x) { return true; } \
297
- inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
298
- inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret) { } \
299
- inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
300
- inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
301
- inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
302
- inline CUDA_CALLABLE void adj_min(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
303
- inline CUDA_CALLABLE void adj_max(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
304
- inline CUDA_CALLABLE void adj_abs(T x, T adj_x, T& adj_ret) { } \
305
- inline CUDA_CALLABLE void adj_sign(T x, T adj_x, T& adj_ret) { } \
306
- inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b, T adj_ret) { } \
307
- inline CUDA_CALLABLE void adj_floordiv(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
308
- inline CUDA_CALLABLE void adj_step(T x, T& adj_x, T adj_ret) { } \
309
- inline CUDA_CALLABLE void adj_nonzero(T x, T& adj_x, T adj_ret) { } \
310
- inline CUDA_CALLABLE void adj_sqrt(T x, T adj_x, T& adj_ret) { } \
311
- inline CUDA_CALLABLE void adj_bit_and(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
312
- inline CUDA_CALLABLE void adj_bit_or(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
313
- inline CUDA_CALLABLE void adj_bit_xor(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
314
- inline CUDA_CALLABLE void adj_lshift(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
315
- inline CUDA_CALLABLE void adj_rshift(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
316
- inline CUDA_CALLABLE void adj_invert(T x, T adj_x, T& adj_ret) { }
317
-
318
- inline CUDA_CALLABLE int8 abs(int8 x) { return ::abs(x); }
319
- inline CUDA_CALLABLE int16 abs(int16 x) { return ::abs(x); }
320
- inline CUDA_CALLABLE int32 abs(int32 x) { return ::abs(x); }
321
- inline CUDA_CALLABLE int64 abs(int64 x) { return ::llabs(x); }
322
- inline CUDA_CALLABLE uint8 abs(uint8 x) { return x; }
323
- inline CUDA_CALLABLE uint16 abs(uint16 x) { return x; }
324
- inline CUDA_CALLABLE uint32 abs(uint32 x) { return x; }
325
- inline CUDA_CALLABLE uint64 abs(uint64 x) { return x; }
326
-
327
- DECLARE_INT_OPS(int8)
328
- DECLARE_INT_OPS(int16)
329
- DECLARE_INT_OPS(int32)
330
- DECLARE_INT_OPS(int64)
331
- DECLARE_INT_OPS(uint8)
332
- DECLARE_INT_OPS(uint16)
333
- DECLARE_INT_OPS(uint32)
334
- DECLARE_INT_OPS(uint64)
335
-
336
-
337
- inline CUDA_CALLABLE int8 step(int8 x) { return x < 0 ? 1 : 0; }
338
- inline CUDA_CALLABLE int16 step(int16 x) { return x < 0 ? 1 : 0; }
339
- inline CUDA_CALLABLE int32 step(int32 x) { return x < 0 ? 1 : 0; }
340
- inline CUDA_CALLABLE int64 step(int64 x) { return x < 0 ? 1 : 0; }
341
- inline CUDA_CALLABLE uint8 step(uint8 x) { return 0; }
342
- inline CUDA_CALLABLE uint16 step(uint16 x) { return 0; }
343
- inline CUDA_CALLABLE uint32 step(uint32 x) { return 0; }
344
- inline CUDA_CALLABLE uint64 step(uint64 x) { return 0; }
345
-
346
-
347
- inline CUDA_CALLABLE int8 sign(int8 x) { return x < 0 ? -1 : 1; }
348
- inline CUDA_CALLABLE int8 sign(int16 x) { return x < 0 ? -1 : 1; }
349
- inline CUDA_CALLABLE int8 sign(int32 x) { return x < 0 ? -1 : 1; }
350
- inline CUDA_CALLABLE int8 sign(int64 x) { return x < 0 ? -1 : 1; }
351
- inline CUDA_CALLABLE uint8 sign(uint8 x) { return 1; }
352
- inline CUDA_CALLABLE uint16 sign(uint16 x) { return 1; }
353
- inline CUDA_CALLABLE uint32 sign(uint32 x) { return 1; }
354
- inline CUDA_CALLABLE uint64 sign(uint64 x) { return 1; }
355
-
356
-
357
- // Catch-all for non-float types
358
- template<typename T>
359
- inline bool CUDA_CALLABLE isfinite(const T&)
360
- {
361
- return true;
362
- }
363
-
364
- inline bool CUDA_CALLABLE isfinite(half x)
365
- {
366
- return ::isfinite(float(x));
367
- }
368
- inline bool CUDA_CALLABLE isfinite(float x)
369
- {
370
- return ::isfinite(x);
371
- }
372
- inline bool CUDA_CALLABLE isfinite(double x)
373
- {
374
- return ::isfinite(x);
375
- }
376
-
377
- template<typename T>
378
- inline CUDA_CALLABLE void print(const T&)
379
- {
380
- printf("<type without print implementation>\n");
381
- }
382
-
383
- inline CUDA_CALLABLE void print(float16 f)
384
- {
385
- printf("%g\n", half_to_float(f));
386
- }
387
-
388
- inline CUDA_CALLABLE void print(float f)
389
- {
390
- printf("%g\n", f);
391
- }
392
-
393
- inline CUDA_CALLABLE void print(double f)
394
- {
395
- printf("%g\n", f);
396
- }
397
-
398
-
399
- // basic ops for float types
400
- #define DECLARE_FLOAT_OPS(T) \
401
- inline CUDA_CALLABLE T mul(T a, T b) { return a*b; } \
402
- inline CUDA_CALLABLE T add(T a, T b) { return a+b; } \
403
- inline CUDA_CALLABLE T sub(T a, T b) { return a-b; } \
404
- inline CUDA_CALLABLE T min(T a, T b) { return a<b?a:b; } \
405
- inline CUDA_CALLABLE T max(T a, T b) { return a>b?a:b; } \
406
- inline CUDA_CALLABLE T sign(T x) { return x < T(0) ? -1 : 1; } \
407
- inline CUDA_CALLABLE T step(T x) { return x < T(0) ? T(1) : T(0); }\
408
- inline CUDA_CALLABLE T nonzero(T x) { return x == T(0) ? T(0) : T(1); }\
409
- inline CUDA_CALLABLE T clamp(T x, T a, T b) { return min(max(a, x), b); }\
410
- inline CUDA_CALLABLE void adj_abs(T x, T& adj_x, T adj_ret) \
411
- {\
412
- if (x < T(0))\
413
- adj_x -= adj_ret;\
414
- else\
415
- adj_x += adj_ret;\
416
- }\
417
- inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += b*adj_ret; adj_b += a*adj_ret; } \
418
- inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += adj_ret; adj_b += adj_ret; } \
419
- inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += adj_ret; adj_b -= adj_ret; } \
420
- inline CUDA_CALLABLE void adj_min(T a, T b, T& adj_a, T& adj_b, T adj_ret) \
421
- { \
422
- if (a < b) \
423
- adj_a += adj_ret; \
424
- else \
425
- adj_b += adj_ret; \
426
- } \
427
- inline CUDA_CALLABLE void adj_max(T a, T b, T& adj_a, T& adj_b, T adj_ret) \
428
- { \
429
- if (a > b) \
430
- adj_a += adj_ret; \
431
- else \
432
- adj_b += adj_ret; \
433
- } \
434
- inline CUDA_CALLABLE void adj_floordiv(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
435
- inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret){ adj_a += adj_ret; }\
436
- inline CUDA_CALLABLE void adj_sign(T x, T adj_x, T& adj_ret) { }\
437
- inline CUDA_CALLABLE void adj_step(T x, T& adj_x, T adj_ret) { }\
438
- inline CUDA_CALLABLE void adj_nonzero(T x, T& adj_x, T adj_ret) { }\
439
- inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b, T adj_ret)\
440
- {\
441
- if (x < a)\
442
- adj_a += adj_ret;\
443
- else if (x > b)\
444
- adj_b += adj_ret;\
445
- else\
446
- adj_x += adj_ret;\
447
- }\
448
- inline CUDA_CALLABLE T div(T a, T b)\
449
- {\
450
- DO_IF_FPCHECK(\
451
- if (!isfinite(a) || !isfinite(b) || b == T(0))\
452
- {\
453
- printf("%s:%d div(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));\
454
- assert(0);\
455
- })\
456
- return a/b;\
457
- }\
458
- inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
459
- {\
460
- adj_a += adj_ret/b;\
461
- adj_b -= adj_ret*(ret)/b;\
462
- DO_IF_FPCHECK(\
463
- if (!isfinite(adj_a) || !isfinite(adj_b))\
464
- {\
465
- printf("%s:%d - adj_div(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
466
- assert(0);\
467
- })\
468
- }\
469
-
470
- DECLARE_FLOAT_OPS(float16)
471
- DECLARE_FLOAT_OPS(float32)
472
- DECLARE_FLOAT_OPS(float64)
473
-
474
-
475
-
476
- // basic ops for float types
477
- inline CUDA_CALLABLE float16 mod(float16 a, float16 b)
478
- {
479
- #if FP_CHECK
480
- if (!isfinite(a) || !isfinite(b) || float(b) == 0.0f)
481
- {
482
- printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));
483
- assert(0);
484
- }
485
- #endif
486
- return fmodf(float(a), float(b));
487
- }
488
-
489
- inline CUDA_CALLABLE float32 mod(float32 a, float32 b)
490
- {
491
- #if FP_CHECK
492
- if (!isfinite(a) || !isfinite(b) || b == 0.0f)
493
- {
494
- printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
495
- assert(0);
496
- }
497
- #endif
498
- return fmodf(a, b);
499
- }
500
-
501
- inline CUDA_CALLABLE double mod(double a, double b)
502
- {
503
- #if FP_CHECK
504
- if (!isfinite(a) || !isfinite(b) || b == 0.0f)
505
- {
506
- printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
507
- assert(0);
508
- }
509
- #endif
510
- return fmod(a, b);
511
- }
512
-
513
- inline CUDA_CALLABLE half log(half a)
514
- {
515
- #if FP_CHECK
516
- if (!isfinite(a) || float(a) < 0.0f)
517
- {
518
- printf("%s:%d log(%f)\n", __FILE__, __LINE__, float(a));
519
- assert(0);
520
- }
521
- #endif
522
- return ::logf(a);
523
- }
524
-
525
- inline CUDA_CALLABLE float log(float a)
526
- {
527
- #if FP_CHECK
528
- if (!isfinite(a) || a < 0.0f)
529
- {
530
- printf("%s:%d log(%f)\n", __FILE__, __LINE__, a);
531
- assert(0);
532
- }
533
- #endif
534
- return ::logf(a);
535
- }
536
-
537
- inline CUDA_CALLABLE double log(double a)
538
- {
539
- #if FP_CHECK
540
- if (!isfinite(a) || a < 0.0)
541
- {
542
- printf("%s:%d log(%f)\n", __FILE__, __LINE__, a);
543
- assert(0);
544
- }
545
- #endif
546
- return ::log(a);
547
- }
548
-
549
- inline CUDA_CALLABLE half log2(half a)
550
- {
551
- #if FP_CHECK
552
- if (!isfinite(a) || float(a) < 0.0f)
553
- {
554
- printf("%s:%d log2(%f)\n", __FILE__, __LINE__, float(a));
555
- assert(0);
556
- }
557
- #endif
558
-
559
- return ::log2f(float(a));
560
- }
561
-
562
- inline CUDA_CALLABLE float log2(float a)
563
- {
564
- #if FP_CHECK
565
- if (!isfinite(a) || a < 0.0f)
566
- {
567
- printf("%s:%d log2(%f)\n", __FILE__, __LINE__, a);
568
- assert(0);
569
- }
570
- #endif
571
-
572
- return ::log2f(a);
573
- }
574
-
575
- inline CUDA_CALLABLE double log2(double a)
576
- {
577
- #if FP_CHECK
578
- if (!isfinite(a) || a < 0.0)
579
- {
580
- printf("%s:%d log2(%f)\n", __FILE__, __LINE__, a);
581
- assert(0);
582
- }
583
- #endif
584
-
585
- return ::log2(a);
586
- }
587
-
588
- inline CUDA_CALLABLE half log10(half a)
589
- {
590
- #if FP_CHECK
591
- if (!isfinite(a) || float(a) < 0.0f)
592
- {
593
- printf("%s:%d log10(%f)\n", __FILE__, __LINE__, float(a));
594
- assert(0);
595
- }
596
- #endif
597
-
598
- return ::log10f(float(a));
599
- }
600
-
601
- inline CUDA_CALLABLE float log10(float a)
602
- {
603
- #if FP_CHECK
604
- if (!isfinite(a) || a < 0.0f)
605
- {
606
- printf("%s:%d log10(%f)\n", __FILE__, __LINE__, a);
607
- assert(0);
608
- }
609
- #endif
610
-
611
- return ::log10f(a);
612
- }
613
-
614
- inline CUDA_CALLABLE double log10(double a)
615
- {
616
- #if FP_CHECK
617
- if (!isfinite(a) || a < 0.0)
618
- {
619
- printf("%s:%d log10(%f)\n", __FILE__, __LINE__, a);
620
- assert(0);
621
- }
622
- #endif
623
-
624
- return ::log10(a);
625
- }
626
-
627
- inline CUDA_CALLABLE half exp(half a)
628
- {
629
- half result = ::expf(float(a));
630
- #if FP_CHECK
631
- if (!isfinite(a) || !isfinite(result))
632
- {
633
- printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, float(a), float(result));
634
- assert(0);
635
- }
636
- #endif
637
- return result;
638
- }
639
- inline CUDA_CALLABLE float exp(float a)
640
- {
641
- float result = ::expf(a);
642
- #if FP_CHECK
643
- if (!isfinite(a) || !isfinite(result))
644
- {
645
- printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, a, result);
646
- assert(0);
647
- }
648
- #endif
649
- return result;
650
- }
651
- inline CUDA_CALLABLE double exp(double a)
652
- {
653
- double result = ::exp(a);
654
- #if FP_CHECK
655
- if (!isfinite(a) || !isfinite(result))
656
- {
657
- printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, a, result);
658
- assert(0);
659
- }
660
- #endif
661
- return result;
662
- }
663
-
664
- inline CUDA_CALLABLE half pow(half a, half b)
665
- {
666
- float result = ::powf(float(a), float(b));
667
- #if FP_CHECK
668
- if (!isfinite(float(a)) || !isfinite(float(b)) || !isfinite(result))
669
- {
670
- printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, float(a), float(b), result);
671
- assert(0);
672
- }
673
- #endif
674
- return result;
675
- }
676
-
677
- inline CUDA_CALLABLE float pow(float a, float b)
678
- {
679
- float result = ::powf(a, b);
680
- #if FP_CHECK
681
- if (!isfinite(a) || !isfinite(b) || !isfinite(result))
682
- {
683
- printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, a, b, result);
684
- assert(0);
685
- }
686
- #endif
687
- return result;
688
- }
689
-
690
- inline CUDA_CALLABLE double pow(double a, double b)
691
- {
692
- double result = ::pow(a, b);
693
- #if FP_CHECK
694
- if (!isfinite(a) || !isfinite(b) || !isfinite(result))
695
- {
696
- printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, a, b, result);
697
- assert(0);
698
- }
699
- #endif
700
- return result;
701
- }
702
-
703
- inline CUDA_CALLABLE half floordiv(half a, half b)
704
- {
705
- #if FP_CHECK
706
- if (!isfinite(a) || !isfinite(b) || float(b) == 0.0f)
707
- {
708
- printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));
709
- assert(0);
710
- }
711
- #endif
712
- return floorf(float(a/b));
713
- }
714
- inline CUDA_CALLABLE float floordiv(float a, float b)
715
- {
716
- #if FP_CHECK
717
- if (!isfinite(a) || !isfinite(b) || b == 0.0f)
718
- {
719
- printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
720
- assert(0);
721
- }
722
- #endif
723
- return floorf(a/b);
724
- }
725
- inline CUDA_CALLABLE double floordiv(double a, double b)
726
- {
727
- #if FP_CHECK
728
- if (!isfinite(a) || !isfinite(b) || b == 0.0)
729
- {
730
- printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
731
- assert(0);
732
- }
733
- #endif
734
- return ::floor(a/b);
735
- }
736
-
737
- inline CUDA_CALLABLE float leaky_min(float a, float b, float r) { return min(a, b); }
738
- inline CUDA_CALLABLE float leaky_max(float a, float b, float r) { return max(a, b); }
739
-
740
- inline CUDA_CALLABLE half abs(half x) { return ::fabsf(float(x)); }
741
- inline CUDA_CALLABLE float abs(float x) { return ::fabsf(x); }
742
- inline CUDA_CALLABLE double abs(double x) { return ::fabs(x); }
743
-
744
- inline CUDA_CALLABLE float acos(float x){ return ::acosf(min(max(x, -1.0f), 1.0f)); }
745
- inline CUDA_CALLABLE float asin(float x){ return ::asinf(min(max(x, -1.0f), 1.0f)); }
746
- inline CUDA_CALLABLE float atan(float x) { return ::atanf(x); }
747
- inline CUDA_CALLABLE float atan2(float y, float x) { return ::atan2f(y, x); }
748
- inline CUDA_CALLABLE float sin(float x) { return ::sinf(x); }
749
- inline CUDA_CALLABLE float cos(float x) { return ::cosf(x); }
750
-
751
- inline CUDA_CALLABLE double acos(double x){ return ::acos(min(max(x, -1.0), 1.0)); }
752
- inline CUDA_CALLABLE double asin(double x){ return ::asin(min(max(x, -1.0), 1.0)); }
753
- inline CUDA_CALLABLE double atan(double x) { return ::atan(x); }
754
- inline CUDA_CALLABLE double atan2(double y, double x) { return ::atan2(y, x); }
755
- inline CUDA_CALLABLE double sin(double x) { return ::sin(x); }
756
- inline CUDA_CALLABLE double cos(double x) { return ::cos(x); }
757
-
758
- inline CUDA_CALLABLE half acos(half x){ return ::acosf(min(max(float(x), -1.0f), 1.0f)); }
759
- inline CUDA_CALLABLE half asin(half x){ return ::asinf(min(max(float(x), -1.0f), 1.0f)); }
760
- inline CUDA_CALLABLE half atan(half x) { return ::atanf(float(x)); }
761
- inline CUDA_CALLABLE half atan2(half y, half x) { return ::atan2f(float(y), float(x)); }
762
- inline CUDA_CALLABLE half sin(half x) { return ::sinf(float(x)); }
763
- inline CUDA_CALLABLE half cos(half x) { return ::cosf(float(x)); }
764
-
765
-
766
- inline CUDA_CALLABLE float sqrt(float x)
767
- {
768
- #if FP_CHECK
769
- if (x < 0.0f)
770
- {
771
- printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, x);
772
- assert(0);
773
- }
774
- #endif
775
- return ::sqrtf(x);
776
- }
777
- inline CUDA_CALLABLE double sqrt(double x)
778
- {
779
- #if FP_CHECK
780
- if (x < 0.0)
781
- {
782
- printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, x);
783
- assert(0);
784
- }
785
- #endif
786
- return ::sqrt(x);
787
- }
788
- inline CUDA_CALLABLE half sqrt(half x)
789
- {
790
- #if FP_CHECK
791
- if (float(x) < 0.0f)
792
- {
793
- printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, float(x));
794
- assert(0);
795
- }
796
- #endif
797
- return ::sqrtf(float(x));
798
- }
799
-
800
- inline CUDA_CALLABLE float cbrt(float x) { return ::cbrtf(x); }
801
- inline CUDA_CALLABLE double cbrt(double x) { return ::cbrt(x); }
802
- inline CUDA_CALLABLE half cbrt(half x) { return ::cbrtf(float(x)); }
803
-
804
- inline CUDA_CALLABLE float tan(float x) { return ::tanf(x); }
805
- inline CUDA_CALLABLE float sinh(float x) { return ::sinhf(x);}
806
- inline CUDA_CALLABLE float cosh(float x) { return ::coshf(x);}
807
- inline CUDA_CALLABLE float tanh(float x) { return ::tanhf(x);}
808
- inline CUDA_CALLABLE float degrees(float x) { return x * RAD_TO_DEG;}
809
- inline CUDA_CALLABLE float radians(float x) { return x * DEG_TO_RAD;}
810
-
811
- inline CUDA_CALLABLE double tan(double x) { return ::tan(x); }
812
- inline CUDA_CALLABLE double sinh(double x) { return ::sinh(x);}
813
- inline CUDA_CALLABLE double cosh(double x) { return ::cosh(x);}
814
- inline CUDA_CALLABLE double tanh(double x) { return ::tanh(x);}
815
- inline CUDA_CALLABLE double degrees(double x) { return x * RAD_TO_DEG;}
816
- inline CUDA_CALLABLE double radians(double x) { return x * DEG_TO_RAD;}
817
-
818
- inline CUDA_CALLABLE half tan(half x) { return ::tanf(float(x)); }
819
- inline CUDA_CALLABLE half sinh(half x) { return ::sinhf(float(x));}
820
- inline CUDA_CALLABLE half cosh(half x) { return ::coshf(float(x));}
821
- inline CUDA_CALLABLE half tanh(half x) { return ::tanhf(float(x));}
822
- inline CUDA_CALLABLE half degrees(half x) { return x * RAD_TO_DEG;}
823
- inline CUDA_CALLABLE half radians(half x) { return x * DEG_TO_RAD;}
824
-
825
- inline CUDA_CALLABLE float round(float x) { return ::roundf(x); }
826
- inline CUDA_CALLABLE float rint(float x) { return ::rintf(x); }
827
- inline CUDA_CALLABLE float trunc(float x) { return ::truncf(x); }
828
- inline CUDA_CALLABLE float floor(float x) { return ::floorf(x); }
829
- inline CUDA_CALLABLE float ceil(float x) { return ::ceilf(x); }
830
- inline CUDA_CALLABLE float frac(float x) { return x - trunc(x); }
831
-
832
- inline CUDA_CALLABLE double round(double x) { return ::round(x); }
833
- inline CUDA_CALLABLE double rint(double x) { return ::rint(x); }
834
- inline CUDA_CALLABLE double trunc(double x) { return ::trunc(x); }
835
- inline CUDA_CALLABLE double floor(double x) { return ::floor(x); }
836
- inline CUDA_CALLABLE double ceil(double x) { return ::ceil(x); }
837
- inline CUDA_CALLABLE double frac(double x) { return x - trunc(x); }
838
-
839
- inline CUDA_CALLABLE half round(half x) { return ::roundf(float(x)); }
840
- inline CUDA_CALLABLE half rint(half x) { return ::rintf(float(x)); }
841
- inline CUDA_CALLABLE half trunc(half x) { return ::truncf(float(x)); }
842
- inline CUDA_CALLABLE half floor(half x) { return ::floorf(float(x)); }
843
- inline CUDA_CALLABLE half ceil(half x) { return ::ceilf(float(x)); }
844
- inline CUDA_CALLABLE half frac(half x) { return float(x) - trunc(float(x)); }
845
-
846
- #define DECLARE_ADJOINTS(T)\
847
- inline CUDA_CALLABLE void adj_log(T a, T& adj_a, T adj_ret)\
848
- {\
849
- adj_a += (T(1)/a)*adj_ret;\
850
- DO_IF_FPCHECK(if (!isfinite(adj_a))\
851
- {\
852
- printf("%s:%d - adj_log(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
853
- assert(0);\
854
- })\
855
- }\
856
- inline CUDA_CALLABLE void adj_log2(T a, T& adj_a, T adj_ret)\
857
- { \
858
- adj_a += (T(1)/a)*(T(1)/log(T(2)))*adj_ret; \
859
- DO_IF_FPCHECK(if (!isfinite(adj_a))\
860
- {\
861
- printf("%s:%d - adj_log2(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
862
- assert(0);\
863
- }) \
864
- }\
865
- inline CUDA_CALLABLE void adj_log10(T a, T& adj_a, T adj_ret)\
866
- {\
867
- adj_a += (T(1)/a)*(T(1)/log(T(10)))*adj_ret; \
868
- DO_IF_FPCHECK(if (!isfinite(adj_a))\
869
- {\
870
- printf("%s:%d - adj_log10(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
871
- assert(0);\
872
- })\
873
- }\
874
- inline CUDA_CALLABLE void adj_exp(T a, T ret, T& adj_a, T adj_ret) { adj_a += ret*adj_ret; }\
875
- inline CUDA_CALLABLE void adj_pow(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
876
- { \
877
- adj_a += b*pow(a, b-T(1))*adj_ret;\
878
- adj_b += log(a)*ret*adj_ret;\
879
- DO_IF_FPCHECK(if (!isfinite(adj_a) || !isfinite(adj_b))\
880
- {\
881
- printf("%s:%d - adj_pow(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
882
- assert(0);\
883
- })\
884
- }\
885
- inline CUDA_CALLABLE void adj_leaky_min(T a, T b, T r, T& adj_a, T& adj_b, T& adj_r, T adj_ret)\
886
- {\
887
- if (a < b)\
888
- adj_a += adj_ret;\
889
- else\
890
- {\
891
- adj_a += r*adj_ret;\
892
- adj_b += adj_ret;\
893
- }\
894
- }\
895
- inline CUDA_CALLABLE void adj_leaky_max(T a, T b, T r, T& adj_a, T& adj_b, T& adj_r, T adj_ret)\
896
- {\
897
- if (a > b)\
898
- adj_a += adj_ret;\
899
- else\
900
- {\
901
- adj_a += r*adj_ret;\
902
- adj_b += adj_ret;\
903
- }\
904
- }\
905
- inline CUDA_CALLABLE void adj_acos(T x, T& adj_x, T adj_ret)\
906
- {\
907
- T d = sqrt(T(1)-x*x);\
908
- DO_IF_FPCHECK(adj_x -= (T(1)/d)*adj_ret;\
909
- if (!isfinite(d) || !isfinite(adj_x))\
910
- {\
911
- printf("%s:%d - adj_acos(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret)); \
912
- assert(0);\
913
- })\
914
- DO_IF_NO_FPCHECK(if (d > T(0))\
915
- adj_x -= (T(1)/d)*adj_ret;)\
916
- }\
917
- inline CUDA_CALLABLE void adj_asin(T x, T& adj_x, T adj_ret)\
918
- {\
919
- T d = sqrt(T(1)-x*x);\
920
- DO_IF_FPCHECK(adj_x += (T(1)/d)*adj_ret;\
921
- if (!isfinite(d) || !isfinite(adj_x))\
922
- {\
923
- printf("%s:%d - adj_asin(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret)); \
924
- assert(0);\
925
- })\
926
- DO_IF_NO_FPCHECK(if (d > T(0))\
927
- adj_x += (T(1)/d)*adj_ret;)\
928
- }\
929
- inline CUDA_CALLABLE void adj_tan(T x, T& adj_x, T adj_ret)\
930
- {\
931
- T cos_x = cos(x);\
932
- DO_IF_FPCHECK(adj_x += (T(1)/(cos_x*cos_x))*adj_ret;\
933
- if (!isfinite(adj_x) || cos_x == T(0))\
934
- {\
935
- printf("%s:%d - adj_tan(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
936
- assert(0);\
937
- })\
938
- DO_IF_NO_FPCHECK(if (cos_x != T(0))\
939
- adj_x += (T(1)/(cos_x*cos_x))*adj_ret;)\
940
- }\
941
- inline CUDA_CALLABLE void adj_atan(T x, T& adj_x, T adj_ret)\
942
- {\
943
- adj_x += adj_ret /(x*x + T(1));\
944
- }\
945
- inline CUDA_CALLABLE void adj_atan2(T y, T x, T& adj_y, T& adj_x, T adj_ret)\
946
- {\
947
- T d = x*x + y*y;\
948
- DO_IF_FPCHECK(adj_x -= y/d*adj_ret;\
949
- adj_y += x/d*adj_ret;\
950
- if (!isfinite(adj_x) || !isfinite(adj_y) || d == T(0))\
951
- {\
952
- printf("%s:%d - adj_atan2(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(y), float(x), float(adj_y), float(adj_x), float(adj_ret));\
953
- assert(0);\
954
- })\
955
- DO_IF_NO_FPCHECK(if (d > T(0))\
956
- {\
957
- adj_x -= (y/d)*adj_ret;\
958
- adj_y += (x/d)*adj_ret;\
959
- })\
960
- }\
961
- inline CUDA_CALLABLE void adj_sin(T x, T& adj_x, T adj_ret)\
962
- {\
963
- adj_x += cos(x)*adj_ret;\
964
- }\
965
- inline CUDA_CALLABLE void adj_cos(T x, T& adj_x, T adj_ret)\
966
- {\
967
- adj_x -= sin(x)*adj_ret;\
968
- }\
969
- inline CUDA_CALLABLE void adj_sinh(T x, T& adj_x, T adj_ret)\
970
- {\
971
- adj_x += cosh(x)*adj_ret;\
972
- }\
973
- inline CUDA_CALLABLE void adj_cosh(T x, T& adj_x, T adj_ret)\
974
- {\
975
- adj_x += sinh(x)*adj_ret;\
976
- }\
977
- inline CUDA_CALLABLE void adj_tanh(T x, T ret, T& adj_x, T adj_ret)\
978
- {\
979
- adj_x += (T(1) - ret*ret)*adj_ret;\
980
- }\
981
- inline CUDA_CALLABLE void adj_sqrt(T x, T ret, T& adj_x, T adj_ret)\
982
- {\
983
- adj_x += T(0.5)*(T(1)/ret)*adj_ret;\
984
- DO_IF_FPCHECK(if (!isfinite(adj_x))\
985
- {\
986
- printf("%s:%d - adj_sqrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
987
- assert(0);\
988
- })\
989
- }\
990
- inline CUDA_CALLABLE void adj_cbrt(T x, T ret, T& adj_x, T adj_ret)\
991
- {\
992
- adj_x += (T(1)/T(3))*(T(1)/(ret*ret))*adj_ret;\
993
- DO_IF_FPCHECK(if (!isfinite(adj_x))\
994
- {\
995
- printf("%s:%d - adj_cbrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
996
- assert(0);\
997
- })\
998
- }\
999
- inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
1000
- {\
1001
- adj_x += RAD_TO_DEG * adj_ret;\
1002
- }\
1003
- inline CUDA_CALLABLE void adj_radians(T x, T& adj_x, T adj_ret)\
1004
- {\
1005
- adj_x += DEG_TO_RAD * adj_ret;\
1006
- }\
1007
- inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
1008
- inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
1009
- inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
1010
- inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
1011
- inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
1012
- inline CUDA_CALLABLE void adj_frac(T x, T& adj_x, T adj_ret){ }
1013
-
1014
- DECLARE_ADJOINTS(float16)
1015
- DECLARE_ADJOINTS(float32)
1016
- DECLARE_ADJOINTS(float64)
1017
-
1018
- template <typename C, typename T>
1019
- CUDA_CALLABLE inline T select(const C& cond, const T& a, const T& b)
1020
- {
1021
- // The double NOT operator !! casts to bool without compiler warnings.
1022
- return (!!cond) ? b : a;
1023
- }
1024
-
1025
- template <typename C, typename T>
1026
- CUDA_CALLABLE inline void adj_select(const C& cond, const T& a, const T& b, C& adj_cond, T& adj_a, T& adj_b, const T& adj_ret)
1027
- {
1028
- // The double NOT operator !! casts to bool without compiler warnings.
1029
- if (!!cond)
1030
- adj_b += adj_ret;
1031
- else
1032
- adj_a += adj_ret;
1033
- }
1034
-
1035
- template <typename T>
1036
- CUDA_CALLABLE inline T copy(const T& src)
1037
- {
1038
- return src;
1039
- }
1040
-
1041
- template <typename T>
1042
- CUDA_CALLABLE inline void adj_copy(const T& src, T& adj_src, T& adj_dest)
1043
- {
1044
- adj_src = adj_dest;
1045
- adj_dest = T{};
1046
- }
1047
-
1048
- template <typename T>
1049
- CUDA_CALLABLE inline void assign(T& dest, const T& src)
1050
- {
1051
- dest = src;
1052
- }
1053
-
1054
- template <typename T>
1055
- CUDA_CALLABLE inline void adj_assign(T& dest, const T& src, T& adj_dest, T& adj_src)
1056
- {
1057
- // this is generally a non-differentiable operation since it violates SSA,
1058
- // except in read-modify-write statements which are reversible through backpropagation
1059
- adj_src = adj_dest;
1060
- adj_dest = T{};
1061
- }
1062
-
1063
-
1064
- // some helpful operator overloads (just for C++ use, these are not adjointed)
1065
-
1066
- template <typename T>
1067
- CUDA_CALLABLE inline T& operator += (T& a, const T& b) { a = add(a, b); return a; }
1068
-
1069
- template <typename T>
1070
- CUDA_CALLABLE inline T& operator -= (T& a, const T& b) { a = sub(a, b); return a; }
1071
-
1072
- template <typename T>
1073
- CUDA_CALLABLE inline T operator+(const T& a, const T& b) { return add(a, b); }
1074
-
1075
- template <typename T>
1076
- CUDA_CALLABLE inline T operator-(const T& a, const T& b) { return sub(a, b); }
1077
-
1078
- template <typename T>
1079
- CUDA_CALLABLE inline T pos(const T& x) { return x; }
1080
- template <typename T>
1081
- CUDA_CALLABLE inline void adj_pos(const T& x, T& adj_x, const T& adj_ret) { adj_x += T(adj_ret); }
1082
-
1083
- // unary negation implemented as negative multiply, not sure the fp implications of this
1084
- // may be better as 0.0 - x?
1085
- template <typename T>
1086
- CUDA_CALLABLE inline T neg(const T& x) { return T(0.0) - x; }
1087
- template <typename T>
1088
- CUDA_CALLABLE inline void adj_neg(const T& x, T& adj_x, const T& adj_ret) { adj_x += T(-adj_ret); }
1089
-
1090
- // unary boolean negation
1091
- template <typename T>
1092
- CUDA_CALLABLE inline bool unot(const T& b) { return !b; }
1093
- template <typename T>
1094
- CUDA_CALLABLE inline void adj_unot(const T& b, T& adj_b, const bool& adj_ret) { }
1095
-
1096
- const int LAUNCH_MAX_DIMS = 4; // should match types.py
1097
-
1098
- struct launch_bounds_t
1099
- {
1100
- int shape[LAUNCH_MAX_DIMS]; // size of each dimension
1101
- int ndim; // number of valid dimension
1102
- size_t size; // total number of threads
1103
- };
1104
-
1105
- #ifndef __CUDACC__
1106
- static size_t s_threadIdx;
1107
- #endif
1108
-
1109
- inline CUDA_CALLABLE size_t grid_index()
1110
- {
1111
- #ifdef __CUDACC__
1112
- // Need to cast at least one of the variables being multiplied so that type promotion happens before the multiplication
1113
- size_t grid_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
1114
- return grid_index;
1115
- #else
1116
- return s_threadIdx;
1117
- #endif
1118
- }
1119
-
1120
- inline CUDA_CALLABLE int tid(size_t index)
1121
- {
1122
- // For the 1-D tid() we need to warn the user if we're about to provide a truncated index
1123
- // Only do this in _DEBUG when called from device to avoid excessive register allocation
1124
- #if defined(_DEBUG) || !defined(__CUDA_ARCH__)
1125
- if (index > 2147483647) {
1126
- printf("Warp warning: tid() is returning an overflowed int\n");
1127
- }
1128
- #endif
1129
- return static_cast<int>(index);
1130
- }
1131
-
1132
- inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& launch_bounds)
1133
- {
1134
- const size_t n = launch_bounds.shape[1];
1135
-
1136
- // convert to work item
1137
- i = index/n;
1138
- j = index%n;
1139
- }
1140
-
1141
- inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& launch_bounds)
1142
- {
1143
- const size_t n = launch_bounds.shape[1];
1144
- const size_t o = launch_bounds.shape[2];
1145
-
1146
- // convert to work item
1147
- i = index/(n*o);
1148
- j = index%(n*o)/o;
1149
- k = index%o;
1150
- }
1151
-
1152
- inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& launch_bounds)
1153
- {
1154
- const size_t n = launch_bounds.shape[1];
1155
- const size_t o = launch_bounds.shape[2];
1156
- const size_t p = launch_bounds.shape[3];
1157
-
1158
- // convert to work item
1159
- i = index/(n*o*p);
1160
- j = index%(n*o*p)/(o*p);
1161
- k = index%(o*p)/p;
1162
- l = index%p;
1163
- }
1164
-
1165
- template<typename T>
1166
- inline CUDA_CALLABLE T atomic_add(T* buf, T value)
1167
- {
1168
- #if !defined(__CUDA_ARCH__)
1169
- T old = buf[0];
1170
- buf[0] += value;
1171
- return old;
1172
- #else
1173
- return atomicAdd(buf, value);
1174
- #endif
1175
- }
1176
-
1177
- template<>
1178
- inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
1179
- {
1180
- #if !defined(__CUDA_ARCH__)
1181
- float16 old = buf[0];
1182
- buf[0] += value;
1183
- return old;
1184
- #elif defined(__clang__) // CUDA compiled by Clang
1185
- __half r = atomicAdd(reinterpret_cast<__half*>(buf), *reinterpret_cast<__half*>(&value));
1186
- return *reinterpret_cast<float16*>(&r);
1187
- #else // CUDA compiled by NVRTC
1188
- //return atomicAdd(buf, value);
1189
-
1190
- /* Define __PTR for atomicAdd prototypes below, undef after done */
1191
- #if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
1192
- #define __PTR "l"
1193
- #else
1194
- #define __PTR "r"
1195
- #endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/
1196
-
1197
- half r = 0.0;
1198
-
1199
- #if __CUDA_ARCH__ >= 700
1200
-
1201
- asm volatile ("{ atom.add.noftz.f16 %0,[%1],%2; }\n"
1202
- : "=h"(r.u)
1203
- : __PTR(buf), "h"(value.u)
1204
- : "memory");
1205
- #endif
1206
-
1207
- return r;
1208
-
1209
- #undef __PTR
1210
-
1211
- #endif // CUDA compiled by NVRTC
1212
-
1213
- }
1214
-
1215
- // emulate atomic float max
1216
- inline CUDA_CALLABLE float atomic_max(float* address, float val)
1217
- {
1218
- #if defined(__CUDA_ARCH__)
1219
- int *address_as_int = (int*)address;
1220
- int old = *address_as_int, assumed;
1221
-
1222
- while (val > __int_as_float(old))
1223
- {
1224
- assumed = old;
1225
- old = atomicCAS(address_as_int, assumed,
1226
- __float_as_int(val));
1227
- }
1228
-
1229
- return __int_as_float(old);
1230
-
1231
- #else
1232
- float old = *address;
1233
- *address = max(old, val);
1234
- return old;
1235
- #endif
1236
- }
1237
-
1238
- // emulate atomic float min/max with atomicCAS()
1239
- inline CUDA_CALLABLE float atomic_min(float* address, float val)
1240
- {
1241
- #if defined(__CUDA_ARCH__)
1242
- int *address_as_int = (int*)address;
1243
- int old = *address_as_int, assumed;
1244
-
1245
- while (val < __int_as_float(old))
1246
- {
1247
- assumed = old;
1248
- old = atomicCAS(address_as_int, assumed,
1249
- __float_as_int(val));
1250
- }
1251
-
1252
- return __int_as_float(old);
1253
-
1254
- #else
1255
- float old = *address;
1256
- *address = min(old, val);
1257
- return old;
1258
- #endif
1259
- }
1260
-
1261
- inline CUDA_CALLABLE int atomic_max(int* address, int val)
1262
- {
1263
- #if defined(__CUDA_ARCH__)
1264
- return atomicMax(address, val);
1265
-
1266
- #else
1267
- int old = *address;
1268
- *address = max(old, val);
1269
- return old;
1270
- #endif
1271
- }
1272
-
1273
- // atomic int min
1274
- inline CUDA_CALLABLE int atomic_min(int* address, int val)
1275
- {
1276
- #if defined(__CUDA_ARCH__)
1277
- return atomicMin(address, val);
1278
-
1279
- #else
1280
- int old = *address;
1281
- *address = min(old, val);
1282
- return old;
1283
- #endif
1284
- }
1285
-
1286
- // default behavior for adjoint of atomic min/max operation that accumulates gradients for all elements matching the min/max value
1287
- template <typename T>
1288
- CUDA_CALLABLE inline void adj_atomic_minmax(T *addr, T *adj_addr, const T &value, T &adj_value)
1289
- {
1290
- if (value == *addr)
1291
- adj_value += *adj_addr;
1292
- }
1293
-
1294
- // for integral types we do not accumulate gradients
1295
- CUDA_CALLABLE inline void adj_atomic_minmax(int8* buf, int8* adj_buf, const int8 &value, int8 &adj_value) { }
1296
- CUDA_CALLABLE inline void adj_atomic_minmax(uint8* buf, uint8* adj_buf, const uint8 &value, uint8 &adj_value) { }
1297
- CUDA_CALLABLE inline void adj_atomic_minmax(int16* buf, int16* adj_buf, const int16 &value, int16 &adj_value) { }
1298
- CUDA_CALLABLE inline void adj_atomic_minmax(uint16* buf, uint16* adj_buf, const uint16 &value, uint16 &adj_value) { }
1299
- CUDA_CALLABLE inline void adj_atomic_minmax(int32* buf, int32* adj_buf, const int32 &value, int32 &adj_value) { }
1300
- CUDA_CALLABLE inline void adj_atomic_minmax(uint32* buf, uint32* adj_buf, const uint32 &value, uint32 &adj_value) { }
1301
- CUDA_CALLABLE inline void adj_atomic_minmax(int64* buf, int64* adj_buf, const int64 &value, int64 &adj_value) { }
1302
- CUDA_CALLABLE inline void adj_atomic_minmax(uint64* buf, uint64* adj_buf, const uint64 &value, uint64 &adj_value) { }
1303
- CUDA_CALLABLE inline void adj_atomic_minmax(bool* buf, bool* adj_buf, const bool &value, bool &adj_value) { }
1304
-
1305
-
1306
- } // namespace wp
1307
-
1308
-
1309
- // bool and printf are defined outside of the wp namespace in crt.h, hence
1310
- // their adjoint counterparts are also defined in the global namespace.
1311
- template <typename T>
1312
- CUDA_CALLABLE inline void adj_bool(T, T&, bool) {}
1313
- inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
1314
-
1315
-
1316
- #include "vec.h"
1317
- #include "mat.h"
1318
- #include "quat.h"
1319
- #include "spatial.h"
1320
- #include "intersect.h"
1321
- #include "intersect_adj.h"
1322
-
1323
- //--------------
1324
- namespace wp
1325
- {
1326
-
1327
-
1328
- // dot for scalar types just to make some templates compile for scalar/vector
1329
- inline CUDA_CALLABLE float dot(float a, float b) { return mul(a, b); }
1330
- inline CUDA_CALLABLE void adj_dot(float a, float b, float& adj_a, float& adj_b, float adj_ret) { adj_mul(a, b, adj_a, adj_b, adj_ret); }
1331
- inline CUDA_CALLABLE float tensordot(float a, float b) { return mul(a, b); }
1332
-
1333
-
1334
- #define DECLARE_INTERP_FUNCS(T) \
1335
- CUDA_CALLABLE inline T smoothstep(T edge0, T edge1, T x)\
1336
- {\
1337
- x = clamp((x - edge0) / (edge1 - edge0), T(0), T(1));\
1338
- return x * x * (T(3) - T(2) * x);\
1339
- }\
1340
- CUDA_CALLABLE inline void adj_smoothstep(T edge0, T edge1, T x, T& adj_edge0, T& adj_edge1, T& adj_x, T adj_ret)\
1341
- {\
1342
- T ab = edge0 - edge1;\
1343
- T ax = edge0 - x;\
1344
- T bx = edge1 - x;\
1345
- T xb = x - edge1;\
1346
- \
1347
- if (bx / ab >= T(0) || ax / ab <= T(0))\
1348
- {\
1349
- return;\
1350
- }\
1351
- \
1352
- T ab3 = ab * ab * ab;\
1353
- T ab4 = ab3 * ab;\
1354
- adj_edge0 += adj_ret * ((T(6) * ax * bx * bx) / ab4);\
1355
- adj_edge1 += adj_ret * ((T(6) * ax * ax * xb) / ab4);\
1356
- adj_x += adj_ret * ((T(6) * ax * bx ) / ab3);\
1357
- }\
1358
- CUDA_CALLABLE inline T lerp(const T& a, const T& b, T t)\
1359
- {\
1360
- return a*(T(1)-t) + b*t;\
1361
- }\
1362
- CUDA_CALLABLE inline void adj_lerp(const T& a, const T& b, T t, T& adj_a, T& adj_b, T& adj_t, const T& adj_ret)\
1363
- {\
1364
- adj_a += adj_ret*(T(1)-t);\
1365
- adj_b += adj_ret*t;\
1366
- adj_t += b*adj_ret - a*adj_ret;\
1367
- }
1368
-
1369
- DECLARE_INTERP_FUNCS(float16)
1370
- DECLARE_INTERP_FUNCS(float32)
1371
- DECLARE_INTERP_FUNCS(float64)
1372
-
1373
- inline CUDA_CALLABLE void print(const str s)
1374
- {
1375
- printf("%s\n", s);
1376
- }
1377
-
1378
- inline CUDA_CALLABLE void print(int i)
1379
- {
1380
- printf("%d\n", i);
1381
- }
1382
-
1383
- inline CUDA_CALLABLE void print(short i)
1384
- {
1385
- printf("%hd\n", i);
1386
- }
1387
-
1388
- inline CUDA_CALLABLE void print(long i)
1389
- {
1390
- printf("%ld\n", i);
1391
- }
1392
-
1393
- inline CUDA_CALLABLE void print(long long i)
1394
- {
1395
- printf("%lld\n", i);
1396
- }
1397
-
1398
- inline CUDA_CALLABLE void print(unsigned i)
1399
- {
1400
- printf("%u\n", i);
1401
- }
1402
-
1403
- inline CUDA_CALLABLE void print(unsigned short i)
1404
- {
1405
- printf("%hu\n", i);
1406
- }
1407
-
1408
- inline CUDA_CALLABLE void print(unsigned long i)
1409
- {
1410
- printf("%lu\n", i);
1411
- }
1412
-
1413
- inline CUDA_CALLABLE void print(unsigned long long i)
1414
- {
1415
- printf("%llu\n", i);
1416
- }
1417
-
1418
- template<unsigned Length, typename Type>
1419
- inline CUDA_CALLABLE void print(vec_t<Length, Type> v)
1420
- {
1421
- for( unsigned i=0; i < Length; ++i )
1422
- {
1423
- printf("%g ", float(v[i]));
1424
- }
1425
- printf("\n");
1426
- }
1427
-
1428
- template<typename Type>
1429
- inline CUDA_CALLABLE void print(quat_t<Type> i)
1430
- {
1431
- printf("%g %g %g %g\n", float(i.x), float(i.y), float(i.z), float(i.w));
1432
- }
1433
-
1434
- template<unsigned Rows,unsigned Cols,typename Type>
1435
- inline CUDA_CALLABLE void print(const mat_t<Rows,Cols,Type> &m)
1436
- {
1437
- for( unsigned i=0; i< Rows; ++i )
1438
- {
1439
- for( unsigned j=0; j< Cols; ++j )
1440
- {
1441
- printf("%g ",float(m.data[i][j]));
1442
- }
1443
- printf("\n");
1444
- }
1445
- }
1446
-
1447
- template<typename Type>
1448
- inline CUDA_CALLABLE void print(transform_t<Type> t)
1449
- {
1450
- printf("(%g %g %g) (%g %g %g %g)\n", float(t.p[0]), float(t.p[1]), float(t.p[2]), float(t.q.x), float(t.q.y), float(t.q.z), float(t.q.w));
1451
- }
1452
-
1453
- inline CUDA_CALLABLE void adj_print(int i, int adj_i) { printf("%d adj: %d\n", i, adj_i); }
1454
- inline CUDA_CALLABLE void adj_print(float f, float adj_f) { printf("%g adj: %g\n", f, adj_f); }
1455
- inline CUDA_CALLABLE void adj_print(short f, short adj_f) { printf("%hd adj: %hd\n", f, adj_f); }
1456
- inline CUDA_CALLABLE void adj_print(long f, long adj_f) { printf("%ld adj: %ld\n", f, adj_f); }
1457
- inline CUDA_CALLABLE void adj_print(long long f, long long adj_f) { printf("%lld adj: %lld\n", f, adj_f); }
1458
- inline CUDA_CALLABLE void adj_print(unsigned f, unsigned adj_f) { printf("%u adj: %u\n", f, adj_f); }
1459
- inline CUDA_CALLABLE void adj_print(unsigned short f, unsigned short adj_f) { printf("%hu adj: %hu\n", f, adj_f); }
1460
- inline CUDA_CALLABLE void adj_print(unsigned long f, unsigned long adj_f) { printf("%lu adj: %lu\n", f, adj_f); }
1461
- inline CUDA_CALLABLE void adj_print(unsigned long long f, unsigned long long adj_f) { printf("%llu adj: %llu\n", f, adj_f); }
1462
- inline CUDA_CALLABLE void adj_print(half h, half adj_h) { printf("%g adj: %g\n", half_to_float(h), half_to_float(adj_h)); }
1463
- inline CUDA_CALLABLE void adj_print(double f, double adj_f) { printf("%g adj: %g\n", f, adj_f); }
1464
-
1465
- template<unsigned Length, typename Type>
1466
- inline CUDA_CALLABLE void adj_print(vec_t<Length, Type> v, vec_t<Length, Type>& adj_v) { printf("%g %g adj: %g %g \n", v[0], v[1], adj_v[0], adj_v[1]); }
1467
-
1468
- template<unsigned Rows, unsigned Cols, typename Type>
1469
- inline CUDA_CALLABLE void adj_print(mat_t<Rows, Cols, Type> m, mat_t<Rows, Cols, Type>& adj_m) { }
1470
-
1471
- template<typename Type>
1472
- inline CUDA_CALLABLE void adj_print(quat_t<Type> q, quat_t<Type>& adj_q) { printf("%g %g %g %g adj: %g %g %g %g\n", q.x, q.y, q.z, q.w, adj_q.x, adj_q.y, adj_q.z, adj_q.w); }
1473
-
1474
- template<typename Type>
1475
- inline CUDA_CALLABLE void adj_print(transform_t<Type> t, transform_t<Type>& adj_t) {}
1476
-
1477
- inline CUDA_CALLABLE void adj_print(str t, str& adj_t) {}
1478
-
1479
-
1480
- template <typename T>
1481
- inline CUDA_CALLABLE void expect_eq(const T& actual, const T& expected)
1482
- {
1483
- if (!(actual == expected))
1484
- {
1485
- printf("Error, expect_eq() failed:\n");
1486
- printf("\t Expected: "); print(expected);
1487
- printf("\t Actual: "); print(actual);
1488
- }
1489
- }
1490
-
1491
- template <typename T>
1492
- inline CUDA_CALLABLE void adj_expect_eq(const T& a, const T& b, T& adj_a, T& adj_b)
1493
- {
1494
- // nop
1495
- }
1496
-
1497
- template <typename T>
1498
- inline CUDA_CALLABLE void expect_neq(const T& actual, const T& expected)
1499
- {
1500
- if (actual == expected)
1501
- {
1502
- printf("Error, expect_neq() failed:\n");
1503
- printf("\t Expected: "); print(expected);
1504
- printf("\t Actual: "); print(actual);
1505
- }
1506
- }
1507
-
1508
- template <typename T>
1509
- inline CUDA_CALLABLE void adj_expect_neq(const T& a, const T& b, T& adj_a, T& adj_b)
1510
- {
1511
- // nop
1512
- }
1513
-
1514
- template <typename T>
1515
- inline CUDA_CALLABLE void expect_near(const T& actual, const T& expected, const T& tolerance)
1516
- {
1517
- if (abs(actual - expected) > tolerance)
1518
- {
1519
- printf("Error, expect_near() failed with tolerance "); print(tolerance);
1520
- printf("\t Expected: "); print(expected);
1521
- printf("\t Actual: "); print(actual);
1522
- }
1523
- }
1524
-
1525
- inline CUDA_CALLABLE void expect_near(const vec3& actual, const vec3& expected, const float& tolerance)
1526
- {
1527
- const float diff = max(max(abs(actual[0] - expected[0]), abs(actual[1] - expected[1])), abs(actual[2] - expected[2]));
1528
- if (diff > tolerance)
1529
- {
1530
- printf("Error, expect_near() failed with tolerance "); print(tolerance);
1531
- printf("\t Expected: "); print(expected);
1532
- printf("\t Actual: "); print(actual);
1533
- }
1534
- }
1535
-
1536
- template <typename T>
1537
- inline CUDA_CALLABLE void adj_expect_near(const T& actual, const T& expected, const T& tolerance, T& adj_actual, T& adj_expected, T& adj_tolerance)
1538
- {
1539
- // nop
1540
- }
1541
-
1542
- inline CUDA_CALLABLE void adj_expect_near(const vec3& actual, const vec3& expected, float tolerance, vec3& adj_actual, vec3& adj_expected, float adj_tolerance)
1543
- {
1544
- // nop
1545
- }
1546
-
1547
-
1548
- } // namespace wp
1549
-
1550
- // include array.h so we have the print, isfinite functions for the inner array types defined
1551
- #include "array.h"
1552
- #include "mesh.h"
1553
- #include "bvh.h"
1554
- #include "svd.h"
1555
- #include "hashgrid.h"
1556
- #include "volume.h"
1557
- #include "range.h"
1558
- #include "rand.h"
1559
- #include "noise.h"
1560
- #include "matnn.h"
1
+ /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ * NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ * and proprietary rights in and to this software, related documentation
4
+ * and any modifications thereto. Any use, reproduction, disclosure or
5
+ * distribution of this software and related documentation without an express
6
+ * license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+ */
8
+
9
+ #pragma once
10
+
11
+ // All built-in types and functions. To be compatible with runtime NVRTC compilation
12
+ // this header must be independently compilable (i.e.: without external SDK headers)
13
+ // to achieve this we redefine a subset of CRT functions (printf, pow, sin, cos, etc)
14
+
15
+ #include "crt.h"
16
+
17
+ #ifdef _WIN32
18
+ #define __restrict__ __restrict
19
+ #endif
20
+
21
+ #if !defined(__CUDACC__)
22
+ #define CUDA_CALLABLE
23
+ #define CUDA_CALLABLE_DEVICE
24
+ #else
25
+ #define CUDA_CALLABLE __host__ __device__
26
+ #define CUDA_CALLABLE_DEVICE __device__
27
+ #endif
28
+
29
+ #ifdef WP_VERIFY_FP
30
+ #define FP_CHECK 1
31
+ #define DO_IF_FPCHECK(X) {X}
32
+ #define DO_IF_NO_FPCHECK(X)
33
+ #else
34
+ #define FP_CHECK 0
35
+ #define DO_IF_FPCHECK(X)
36
+ #define DO_IF_NO_FPCHECK(X) {X}
37
+ #endif
38
+
39
+ #define RAD_TO_DEG 57.29577951308232087679
40
+ #define DEG_TO_RAD 0.01745329251994329577
41
+
42
+ #if defined(__CUDACC__) && !defined(_MSC_VER)
43
+ __device__ void __debugbreak() {}
44
+ #endif
45
+
46
+ namespace wp
47
+ {
48
+
49
+ // numeric types (used from generated kernels)
50
+ typedef float float32;
51
+ typedef double float64;
52
+
53
+ typedef int8_t int8;
54
+ typedef uint8_t uint8;
55
+
56
+ typedef int16_t int16;
57
+ typedef uint16_t uint16;
58
+
59
+ typedef int32_t int32;
60
+ typedef uint32_t uint32;
61
+
62
+ typedef int64_t int64;
63
+ typedef uint64_t uint64;
64
+
65
+
66
+ // matches Python string type for constant strings
67
+ typedef const char* str;
68
+
69
+
70
+
71
+ struct half;
72
+
73
+ CUDA_CALLABLE half float_to_half(float x);
74
+ CUDA_CALLABLE float half_to_float(half x);
75
+
76
+ struct half
77
+ {
78
+ CUDA_CALLABLE inline half() : u(0) {}
79
+
80
+ CUDA_CALLABLE inline half(float f)
81
+ {
82
+ *this = float_to_half(f);
83
+ }
84
+
85
+ unsigned short u;
86
+
87
+ CUDA_CALLABLE inline bool operator==(const half& h) const
88
+ {
89
+ // Use float32 to get IEEE 754 behavior in case of a NaN
90
+ return float32(h) == float32(*this);
91
+ }
92
+
93
+ CUDA_CALLABLE inline bool operator!=(const half& h) const
94
+ {
95
+ // Use float32 to get IEEE 754 behavior in case of a NaN
96
+ return float32(h) != float32(*this);
97
+ }
98
+ CUDA_CALLABLE inline bool operator>(const half& h) const { return half_to_float(*this) > half_to_float(h); }
99
+ CUDA_CALLABLE inline bool operator>=(const half& h) const { return half_to_float(*this) >= half_to_float(h); }
100
+ CUDA_CALLABLE inline bool operator<(const half& h) const { return half_to_float(*this) < half_to_float(h); }
101
+ CUDA_CALLABLE inline bool operator<=(const half& h) const { return half_to_float(*this) <= half_to_float(h); }
102
+
103
+ CUDA_CALLABLE inline bool operator!() const
104
+ {
105
+ return float32(*this) == 0;
106
+ }
107
+
108
+ CUDA_CALLABLE inline half operator*=(const half& h)
109
+ {
110
+ half prod = half(float32(*this) * float32(h));
111
+ this->u = prod.u;
112
+ return *this;
113
+ }
114
+
115
+ CUDA_CALLABLE inline half operator/=(const half& h)
116
+ {
117
+ half quot = half(float32(*this) / float32(h));
118
+ this->u = quot.u;
119
+ return *this;
120
+ }
121
+
122
+ CUDA_CALLABLE inline half operator+=(const half& h)
123
+ {
124
+ half sum = half(float32(*this) + float32(h));
125
+ this->u = sum.u;
126
+ return *this;
127
+ }
128
+
129
+ CUDA_CALLABLE inline half operator-=(const half& h)
130
+ {
131
+ half diff = half(float32(*this) - float32(h));
132
+ this->u = diff.u;
133
+ return *this;
134
+ }
135
+
136
+ CUDA_CALLABLE inline operator float32() const { return float32(half_to_float(*this)); }
137
+ CUDA_CALLABLE inline operator float64() const { return float64(half_to_float(*this)); }
138
+ CUDA_CALLABLE inline operator int8() const { return int8(half_to_float(*this)); }
139
+ CUDA_CALLABLE inline operator uint8() const { return uint8(half_to_float(*this)); }
140
+ CUDA_CALLABLE inline operator int16() const { return int16(half_to_float(*this)); }
141
+ CUDA_CALLABLE inline operator uint16() const { return uint16(half_to_float(*this)); }
142
+ CUDA_CALLABLE inline operator int32() const { return int32(half_to_float(*this)); }
143
+ CUDA_CALLABLE inline operator uint32() const { return uint32(half_to_float(*this)); }
144
+ CUDA_CALLABLE inline operator int64() const { return int64(half_to_float(*this)); }
145
+ CUDA_CALLABLE inline operator uint64() const { return uint64(half_to_float(*this)); }
146
+ };
147
+
148
+ static_assert(sizeof(half) == 2, "Size of half / float16 type must be 2-bytes");
149
+
150
+ typedef half float16;
151
+
152
+ #if defined(__CUDA_ARCH__)
153
+
154
+ CUDA_CALLABLE inline half float_to_half(float x)
155
+ {
156
+ half h;
157
+ asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(h.u) : "f"(x));
158
+ return h;
159
+ }
160
+
161
+ CUDA_CALLABLE inline float half_to_float(half x)
162
+ {
163
+ float val;
164
+ asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(x.u));
165
+ return val;
166
+ }
167
+
168
+ #elif defined(__clang__)
169
+
170
+ // _Float16 is Clang's native half-precision floating-point type
171
+ inline half float_to_half(float x)
172
+ {
173
+
174
+ _Float16 f16 = static_cast<_Float16>(x);
175
+ return *reinterpret_cast<half*>(&f16);
176
+ }
177
+
178
+ inline float half_to_float(half h)
179
+ {
180
+ _Float16 f16 = *reinterpret_cast<_Float16*>(&h);
181
+ return static_cast<float>(f16);
182
+ }
183
+
184
+ #else // Native C++ for Warp builtins outside of kernels
185
+
186
+ extern "C" WP_API uint16_t float_to_half_bits(float x);
187
+ extern "C" WP_API float half_bits_to_float(uint16_t u);
188
+
189
+ inline half float_to_half(float x)
190
+ {
191
+ half h;
192
+ h.u = float_to_half_bits(x);
193
+ return h;
194
+ }
195
+
196
+ inline float half_to_float(half h)
197
+ {
198
+ return half_bits_to_float(h.u);
199
+ }
200
+
201
+ #endif
202
+
203
+
204
+ // BAD operator implementations for fp16 arithmetic...
205
+
206
+ // negation:
207
+ inline CUDA_CALLABLE half operator - (half a)
208
+ {
209
+ return float_to_half( -half_to_float(a) );
210
+ }
211
+
212
+ inline CUDA_CALLABLE half operator + (half a,half b)
213
+ {
214
+ return float_to_half( half_to_float(a) + half_to_float(b) );
215
+ }
216
+
217
+ inline CUDA_CALLABLE half operator - (half a,half b)
218
+ {
219
+ return float_to_half( half_to_float(a) - half_to_float(b) );
220
+ }
221
+
222
+ inline CUDA_CALLABLE half operator * (half a,half b)
223
+ {
224
+ return float_to_half( half_to_float(a) * half_to_float(b) );
225
+ }
226
+
227
+ inline CUDA_CALLABLE half operator * (half a,double b)
228
+ {
229
+ return float_to_half( half_to_float(a) * b );
230
+ }
231
+
232
+ inline CUDA_CALLABLE half operator * (double a,half b)
233
+ {
234
+ return float_to_half( a * half_to_float(b) );
235
+ }
236
+
237
+ inline CUDA_CALLABLE half operator / (half a,half b)
238
+ {
239
+ return float_to_half( half_to_float(a) / half_to_float(b) );
240
+ }
241
+
242
+
243
+
244
+
245
+
246
+ template <typename T>
247
+ CUDA_CALLABLE float cast_float(T x) { return (float)(x); }
248
+
249
+ template <typename T>
250
+ CUDA_CALLABLE int cast_int(T x) { return (int)(x); }
251
+
252
+ template <typename T>
253
+ CUDA_CALLABLE void adj_cast_float(T x, T& adj_x, float adj_ret) { adj_x += T(adj_ret); }
254
+
255
+ template <typename T>
256
+ CUDA_CALLABLE void adj_cast_int(T x, T& adj_x, int adj_ret) { adj_x += adj_ret; }
257
+
258
+ template <typename T>
259
+ CUDA_CALLABLE inline void adj_int8(T, T&, int8) {}
260
+ template <typename T>
261
+ CUDA_CALLABLE inline void adj_uint8(T, T&, uint8) {}
262
+ template <typename T>
263
+ CUDA_CALLABLE inline void adj_int16(T, T&, int16) {}
264
+ template <typename T>
265
+ CUDA_CALLABLE inline void adj_uint16(T, T&, uint16) {}
266
+ template <typename T>
267
+ CUDA_CALLABLE inline void adj_int32(T, T&, int32) {}
268
+ template <typename T>
269
+ CUDA_CALLABLE inline void adj_uint32(T, T&, uint32) {}
270
+ template <typename T>
271
+ CUDA_CALLABLE inline void adj_int64(T, T&, int64) {}
272
+ template <typename T>
273
+ CUDA_CALLABLE inline void adj_uint64(T, T&, uint64) {}
274
+
275
+
276
+ template <typename T>
277
+ CUDA_CALLABLE inline void adj_float16(T x, T& adj_x, float16 adj_ret) { adj_x += T(adj_ret); }
278
+ template <typename T>
279
+ CUDA_CALLABLE inline void adj_float32(T x, T& adj_x, float32 adj_ret) { adj_x += T(adj_ret); }
280
+ template <typename T>
281
+ CUDA_CALLABLE inline void adj_float64(T x, T& adj_x, float64 adj_ret) { adj_x += T(adj_ret); }
282
+
283
+
284
+ #define kEps 0.0f
285
+
286
+ // basic ops for integer types
287
+ #define DECLARE_INT_OPS(T) \
288
+ inline CUDA_CALLABLE T mul(T a, T b) { return a*b; } \
289
+ inline CUDA_CALLABLE T div(T a, T b) { return a/b; } \
290
+ inline CUDA_CALLABLE T add(T a, T b) { return a+b; } \
291
+ inline CUDA_CALLABLE T sub(T a, T b) { return a-b; } \
292
+ inline CUDA_CALLABLE T mod(T a, T b) { return a%b; } \
293
+ inline CUDA_CALLABLE T min(T a, T b) { return a<b?a:b; } \
294
+ inline CUDA_CALLABLE T max(T a, T b) { return a>b?a:b; } \
295
+ inline CUDA_CALLABLE T clamp(T x, T a, T b) { return min(max(a, x), b); } \
296
+ inline CUDA_CALLABLE T floordiv(T a, T b) { return a/b; } \
297
+ inline CUDA_CALLABLE T nonzero(T x) { return x == T(0) ? T(0) : T(1); } \
298
+ inline CUDA_CALLABLE T sqrt(T x) { return 0; } \
299
+ inline CUDA_CALLABLE T bit_and(T a, T b) { return a&b; } \
300
+ inline CUDA_CALLABLE T bit_or(T a, T b) { return a|b; } \
301
+ inline CUDA_CALLABLE T bit_xor(T a, T b) { return a^b; } \
302
+ inline CUDA_CALLABLE T lshift(T a, T b) { return a<<b; } \
303
+ inline CUDA_CALLABLE T rshift(T a, T b) { return a>>b; } \
304
+ inline CUDA_CALLABLE T invert(T x) { return ~x; } \
305
+ inline CUDA_CALLABLE bool isfinite(T x) { return ::isfinite(double(x)); } \
306
+ inline CUDA_CALLABLE bool isnan(T x) { return ::isnan(double(x)); } \
307
+ inline CUDA_CALLABLE bool isinf(T x) { return ::isinf(double(x)); } \
308
+ inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
309
+ inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret) { } \
310
+ inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
311
+ inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
312
+ inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
313
+ inline CUDA_CALLABLE void adj_min(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
314
+ inline CUDA_CALLABLE void adj_max(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
315
+ inline CUDA_CALLABLE void adj_abs(T x, T adj_x, T& adj_ret) { } \
316
+ inline CUDA_CALLABLE void adj_sign(T x, T adj_x, T& adj_ret) { } \
317
+ inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b, T adj_ret) { } \
318
+ inline CUDA_CALLABLE void adj_floordiv(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
319
+ inline CUDA_CALLABLE void adj_step(T x, T& adj_x, T adj_ret) { } \
320
+ inline CUDA_CALLABLE void adj_nonzero(T x, T& adj_x, T adj_ret) { } \
321
+ inline CUDA_CALLABLE void adj_sqrt(T x, T adj_x, T& adj_ret) { } \
322
+ inline CUDA_CALLABLE void adj_bit_and(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
323
+ inline CUDA_CALLABLE void adj_bit_or(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
324
+ inline CUDA_CALLABLE void adj_bit_xor(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
325
+ inline CUDA_CALLABLE void adj_lshift(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
326
+ inline CUDA_CALLABLE void adj_rshift(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
327
+ inline CUDA_CALLABLE void adj_invert(T x, T adj_x, T& adj_ret) { } \
328
+ inline CUDA_CALLABLE void adj_isnan(const T&, T&, bool) { } \
329
+ inline CUDA_CALLABLE void adj_isinf(const T&, T&, bool) { } \
330
+ inline CUDA_CALLABLE void adj_isfinite(const T&, T&, bool) { }
331
+
332
+ inline CUDA_CALLABLE int8 abs(int8 x) { return ::abs(x); }
333
+ inline CUDA_CALLABLE int16 abs(int16 x) { return ::abs(x); }
334
+ inline CUDA_CALLABLE int32 abs(int32 x) { return ::abs(x); }
335
+ inline CUDA_CALLABLE int64 abs(int64 x) { return ::llabs(x); }
336
+ inline CUDA_CALLABLE uint8 abs(uint8 x) { return x; }
337
+ inline CUDA_CALLABLE uint16 abs(uint16 x) { return x; }
338
+ inline CUDA_CALLABLE uint32 abs(uint32 x) { return x; }
339
+ inline CUDA_CALLABLE uint64 abs(uint64 x) { return x; }
340
+
341
+ DECLARE_INT_OPS(int8)
342
+ DECLARE_INT_OPS(int16)
343
+ DECLARE_INT_OPS(int32)
344
+ DECLARE_INT_OPS(int64)
345
+ DECLARE_INT_OPS(uint8)
346
+ DECLARE_INT_OPS(uint16)
347
+ DECLARE_INT_OPS(uint32)
348
+ DECLARE_INT_OPS(uint64)
349
+
350
+
351
+ inline CUDA_CALLABLE int8 step(int8 x) { return x < 0 ? 1 : 0; }
352
+ inline CUDA_CALLABLE int16 step(int16 x) { return x < 0 ? 1 : 0; }
353
+ inline CUDA_CALLABLE int32 step(int32 x) { return x < 0 ? 1 : 0; }
354
+ inline CUDA_CALLABLE int64 step(int64 x) { return x < 0 ? 1 : 0; }
355
+ inline CUDA_CALLABLE uint8 step(uint8 x) { return 0; }
356
+ inline CUDA_CALLABLE uint16 step(uint16 x) { return 0; }
357
+ inline CUDA_CALLABLE uint32 step(uint32 x) { return 0; }
358
+ inline CUDA_CALLABLE uint64 step(uint64 x) { return 0; }
359
+
360
+
361
+ inline CUDA_CALLABLE int8 sign(int8 x) { return x < 0 ? -1 : 1; }
362
+ inline CUDA_CALLABLE int8 sign(int16 x) { return x < 0 ? -1 : 1; }
363
+ inline CUDA_CALLABLE int8 sign(int32 x) { return x < 0 ? -1 : 1; }
364
+ inline CUDA_CALLABLE int8 sign(int64 x) { return x < 0 ? -1 : 1; }
365
+ inline CUDA_CALLABLE uint8 sign(uint8 x) { return 1; }
366
+ inline CUDA_CALLABLE uint16 sign(uint16 x) { return 1; }
367
+ inline CUDA_CALLABLE uint32 sign(uint32 x) { return 1; }
368
+ inline CUDA_CALLABLE uint64 sign(uint64 x) { return 1; }
369
+
370
+
371
+ // Catch-all for non-float, non-integer types
372
+ template<typename T>
373
+ inline bool CUDA_CALLABLE isfinite(const T&)
374
+ {
375
+ return true;
376
+ }
377
+
378
+ inline bool CUDA_CALLABLE isfinite(half x)
379
+ {
380
+ return ::isfinite(float(x));
381
+ }
382
+ inline bool CUDA_CALLABLE isfinite(float x)
383
+ {
384
+ return ::isfinite(x);
385
+ }
386
+ inline bool CUDA_CALLABLE isfinite(double x)
387
+ {
388
+ return ::isfinite(x);
389
+ }
390
+
391
+ inline bool CUDA_CALLABLE isnan(half x)
392
+ {
393
+ return ::isnan(float(x));
394
+ }
395
+ inline bool CUDA_CALLABLE isnan(float x)
396
+ {
397
+ return ::isnan(x);
398
+ }
399
+ inline bool CUDA_CALLABLE isnan(double x)
400
+ {
401
+ return ::isnan(x);
402
+ }
403
+
404
+ inline bool CUDA_CALLABLE isinf(half x)
405
+ {
406
+ return ::isinf(float(x));
407
+ }
408
+ inline bool CUDA_CALLABLE isinf(float x)
409
+ {
410
+ return ::isinf(x);
411
+ }
412
+ inline bool CUDA_CALLABLE isinf(double x)
413
+ {
414
+ return ::isinf(x);
415
+ }
416
+
417
+ template<typename T>
418
+ inline CUDA_CALLABLE void print(const T&)
419
+ {
420
+ printf("<type without print implementation>\n");
421
+ }
422
+
423
+ inline CUDA_CALLABLE void print(float16 f)
424
+ {
425
+ printf("%g\n", half_to_float(f));
426
+ }
427
+
428
+ inline CUDA_CALLABLE void print(float f)
429
+ {
430
+ printf("%g\n", f);
431
+ }
432
+
433
+ inline CUDA_CALLABLE void print(double f)
434
+ {
435
+ printf("%g\n", f);
436
+ }
437
+
438
+
439
+ // basic ops for float types
440
+ #define DECLARE_FLOAT_OPS(T) \
441
+ inline CUDA_CALLABLE T mul(T a, T b) { return a*b; } \
442
+ inline CUDA_CALLABLE T add(T a, T b) { return a+b; } \
443
+ inline CUDA_CALLABLE T sub(T a, T b) { return a-b; } \
444
+ inline CUDA_CALLABLE T min(T a, T b) { return a<b?a:b; } \
445
+ inline CUDA_CALLABLE T max(T a, T b) { return a>b?a:b; } \
446
+ inline CUDA_CALLABLE T sign(T x) { return x < T(0) ? -1 : 1; } \
447
+ inline CUDA_CALLABLE T step(T x) { return x < T(0) ? T(1) : T(0); }\
448
+ inline CUDA_CALLABLE T nonzero(T x) { return x == T(0) ? T(0) : T(1); }\
449
+ inline CUDA_CALLABLE T clamp(T x, T a, T b) { return min(max(a, x), b); }\
450
+ inline CUDA_CALLABLE void adj_abs(T x, T& adj_x, T adj_ret) \
451
+ {\
452
+ if (x < T(0))\
453
+ adj_x -= adj_ret;\
454
+ else\
455
+ adj_x += adj_ret;\
456
+ }\
457
+ inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += b*adj_ret; adj_b += a*adj_ret; } \
458
+ inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += adj_ret; adj_b += adj_ret; } \
459
+ inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += adj_ret; adj_b -= adj_ret; } \
460
+ inline CUDA_CALLABLE void adj_min(T a, T b, T& adj_a, T& adj_b, T adj_ret) \
461
+ { \
462
+ if (a < b) \
463
+ adj_a += adj_ret; \
464
+ else \
465
+ adj_b += adj_ret; \
466
+ } \
467
+ inline CUDA_CALLABLE void adj_max(T a, T b, T& adj_a, T& adj_b, T adj_ret) \
468
+ { \
469
+ if (a > b) \
470
+ adj_a += adj_ret; \
471
+ else \
472
+ adj_b += adj_ret; \
473
+ } \
474
+ inline CUDA_CALLABLE void adj_floordiv(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
475
+ inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret){ adj_a += adj_ret; }\
476
+ inline CUDA_CALLABLE void adj_sign(T x, T adj_x, T& adj_ret) { }\
477
+ inline CUDA_CALLABLE void adj_step(T x, T& adj_x, T adj_ret) { }\
478
+ inline CUDA_CALLABLE void adj_nonzero(T x, T& adj_x, T adj_ret) { }\
479
+ inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b, T adj_ret)\
480
+ {\
481
+ if (x < a)\
482
+ adj_a += adj_ret;\
483
+ else if (x > b)\
484
+ adj_b += adj_ret;\
485
+ else\
486
+ adj_x += adj_ret;\
487
+ }\
488
+ inline CUDA_CALLABLE T div(T a, T b)\
489
+ {\
490
+ DO_IF_FPCHECK(\
491
+ if (!isfinite(a) || !isfinite(b) || b == T(0))\
492
+ {\
493
+ printf("%s:%d div(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));\
494
+ assert(0);\
495
+ })\
496
+ return a/b;\
497
+ }\
498
+ inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
499
+ {\
500
+ adj_a += adj_ret/b;\
501
+ adj_b -= adj_ret*(ret)/b;\
502
+ DO_IF_FPCHECK(\
503
+ if (!isfinite(adj_a) || !isfinite(adj_b))\
504
+ {\
505
+ printf("%s:%d - adj_div(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
506
+ assert(0);\
507
+ })\
508
+ }\
509
+ inline CUDA_CALLABLE void adj_isnan(const T&, T&, bool) { }\
510
+ inline CUDA_CALLABLE void adj_isinf(const T&, T&, bool) { }\
511
+ inline CUDA_CALLABLE void adj_isfinite(const T&, T&, bool) { }
512
+
513
+ DECLARE_FLOAT_OPS(float16)
514
+ DECLARE_FLOAT_OPS(float32)
515
+ DECLARE_FLOAT_OPS(float64)
516
+
517
+
518
+
519
+ // basic ops for float types
520
+ inline CUDA_CALLABLE float16 mod(float16 a, float16 b)
521
+ {
522
+ #if FP_CHECK
523
+ if (!isfinite(a) || !isfinite(b) || float(b) == 0.0f)
524
+ {
525
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));
526
+ assert(0);
527
+ }
528
+ #endif
529
+ return fmodf(float(a), float(b));
530
+ }
531
+
532
+ inline CUDA_CALLABLE float32 mod(float32 a, float32 b)
533
+ {
534
+ #if FP_CHECK
535
+ if (!isfinite(a) || !isfinite(b) || b == 0.0f)
536
+ {
537
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
538
+ assert(0);
539
+ }
540
+ #endif
541
+ return fmodf(a, b);
542
+ }
543
+
544
+ inline CUDA_CALLABLE double mod(double a, double b)
545
+ {
546
+ #if FP_CHECK
547
+ if (!isfinite(a) || !isfinite(b) || b == 0.0f)
548
+ {
549
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
550
+ assert(0);
551
+ }
552
+ #endif
553
+ return fmod(a, b);
554
+ }
555
+
556
+ inline CUDA_CALLABLE half log(half a)
557
+ {
558
+ #if FP_CHECK
559
+ if (!isfinite(a) || float(a) < 0.0f)
560
+ {
561
+ printf("%s:%d log(%f)\n", __FILE__, __LINE__, float(a));
562
+ assert(0);
563
+ }
564
+ #endif
565
+ return ::logf(a);
566
+ }
567
+
568
+ inline CUDA_CALLABLE float log(float a)
569
+ {
570
+ #if FP_CHECK
571
+ if (!isfinite(a) || a < 0.0f)
572
+ {
573
+ printf("%s:%d log(%f)\n", __FILE__, __LINE__, a);
574
+ assert(0);
575
+ }
576
+ #endif
577
+ return ::logf(a);
578
+ }
579
+
580
+ inline CUDA_CALLABLE double log(double a)
581
+ {
582
+ #if FP_CHECK
583
+ if (!isfinite(a) || a < 0.0)
584
+ {
585
+ printf("%s:%d log(%f)\n", __FILE__, __LINE__, a);
586
+ assert(0);
587
+ }
588
+ #endif
589
+ return ::log(a);
590
+ }
591
+
592
+ inline CUDA_CALLABLE half log2(half a)
593
+ {
594
+ #if FP_CHECK
595
+ if (!isfinite(a) || float(a) < 0.0f)
596
+ {
597
+ printf("%s:%d log2(%f)\n", __FILE__, __LINE__, float(a));
598
+ assert(0);
599
+ }
600
+ #endif
601
+
602
+ return ::log2f(float(a));
603
+ }
604
+
605
+ inline CUDA_CALLABLE float log2(float a)
606
+ {
607
+ #if FP_CHECK
608
+ if (!isfinite(a) || a < 0.0f)
609
+ {
610
+ printf("%s:%d log2(%f)\n", __FILE__, __LINE__, a);
611
+ assert(0);
612
+ }
613
+ #endif
614
+
615
+ return ::log2f(a);
616
+ }
617
+
618
+ inline CUDA_CALLABLE double log2(double a)
619
+ {
620
+ #if FP_CHECK
621
+ if (!isfinite(a) || a < 0.0)
622
+ {
623
+ printf("%s:%d log2(%f)\n", __FILE__, __LINE__, a);
624
+ assert(0);
625
+ }
626
+ #endif
627
+
628
+ return ::log2(a);
629
+ }
630
+
631
+ inline CUDA_CALLABLE half log10(half a)
632
+ {
633
+ #if FP_CHECK
634
+ if (!isfinite(a) || float(a) < 0.0f)
635
+ {
636
+ printf("%s:%d log10(%f)\n", __FILE__, __LINE__, float(a));
637
+ assert(0);
638
+ }
639
+ #endif
640
+
641
+ return ::log10f(float(a));
642
+ }
643
+
644
+ inline CUDA_CALLABLE float log10(float a)
645
+ {
646
+ #if FP_CHECK
647
+ if (!isfinite(a) || a < 0.0f)
648
+ {
649
+ printf("%s:%d log10(%f)\n", __FILE__, __LINE__, a);
650
+ assert(0);
651
+ }
652
+ #endif
653
+
654
+ return ::log10f(a);
655
+ }
656
+
657
+ inline CUDA_CALLABLE double log10(double a)
658
+ {
659
+ #if FP_CHECK
660
+ if (!isfinite(a) || a < 0.0)
661
+ {
662
+ printf("%s:%d log10(%f)\n", __FILE__, __LINE__, a);
663
+ assert(0);
664
+ }
665
+ #endif
666
+
667
+ return ::log10(a);
668
+ }
669
+
670
+ inline CUDA_CALLABLE half exp(half a)
671
+ {
672
+ half result = ::expf(float(a));
673
+ #if FP_CHECK
674
+ if (!isfinite(a) || !isfinite(result))
675
+ {
676
+ printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, float(a), float(result));
677
+ assert(0);
678
+ }
679
+ #endif
680
+ return result;
681
+ }
682
+ inline CUDA_CALLABLE float exp(float a)
683
+ {
684
+ float result = ::expf(a);
685
+ #if FP_CHECK
686
+ if (!isfinite(a) || !isfinite(result))
687
+ {
688
+ printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, a, result);
689
+ assert(0);
690
+ }
691
+ #endif
692
+ return result;
693
+ }
694
+ inline CUDA_CALLABLE double exp(double a)
695
+ {
696
+ double result = ::exp(a);
697
+ #if FP_CHECK
698
+ if (!isfinite(a) || !isfinite(result))
699
+ {
700
+ printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, a, result);
701
+ assert(0);
702
+ }
703
+ #endif
704
+ return result;
705
+ }
706
+
707
+ inline CUDA_CALLABLE half pow(half a, half b)
708
+ {
709
+ float result = ::powf(float(a), float(b));
710
+ #if FP_CHECK
711
+ if (!isfinite(float(a)) || !isfinite(float(b)) || !isfinite(result))
712
+ {
713
+ printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, float(a), float(b), result);
714
+ assert(0);
715
+ }
716
+ #endif
717
+ return result;
718
+ }
719
+
720
+ inline CUDA_CALLABLE float pow(float a, float b)
721
+ {
722
+ float result = ::powf(a, b);
723
+ #if FP_CHECK
724
+ if (!isfinite(a) || !isfinite(b) || !isfinite(result))
725
+ {
726
+ printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, a, b, result);
727
+ assert(0);
728
+ }
729
+ #endif
730
+ return result;
731
+ }
732
+
733
+ inline CUDA_CALLABLE double pow(double a, double b)
734
+ {
735
+ double result = ::pow(a, b);
736
+ #if FP_CHECK
737
+ if (!isfinite(a) || !isfinite(b) || !isfinite(result))
738
+ {
739
+ printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, a, b, result);
740
+ assert(0);
741
+ }
742
+ #endif
743
+ return result;
744
+ }
745
+
746
+ inline CUDA_CALLABLE half floordiv(half a, half b)
747
+ {
748
+ #if FP_CHECK
749
+ if (!isfinite(a) || !isfinite(b) || float(b) == 0.0f)
750
+ {
751
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));
752
+ assert(0);
753
+ }
754
+ #endif
755
+ return floorf(float(a/b));
756
+ }
757
+ inline CUDA_CALLABLE float floordiv(float a, float b)
758
+ {
759
+ #if FP_CHECK
760
+ if (!isfinite(a) || !isfinite(b) || b == 0.0f)
761
+ {
762
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
763
+ assert(0);
764
+ }
765
+ #endif
766
+ return floorf(a/b);
767
+ }
768
+ inline CUDA_CALLABLE double floordiv(double a, double b)
769
+ {
770
+ #if FP_CHECK
771
+ if (!isfinite(a) || !isfinite(b) || b == 0.0)
772
+ {
773
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
774
+ assert(0);
775
+ }
776
+ #endif
777
+ return ::floor(a/b);
778
+ }
779
+
780
// "Leaky" min/max: forward pass is a plain min/max; the leak rate `r` is
// unused here but consumed by the adjoints (adj_leaky_min/adj_leaky_max),
// which leak gradient through the non-selected branch.
inline CUDA_CALLABLE float leaky_min(float a, float b, float r) { return min(a, b); }
inline CUDA_CALLABLE float leaky_max(float a, float b, float r) { return max(a, b); }

// Absolute value overloads; the half version computes in float and narrows.
inline CUDA_CALLABLE half abs(half x) { return ::fabsf(float(x)); }
inline CUDA_CALLABLE float abs(float x) { return ::fabsf(x); }
inline CUDA_CALLABLE double abs(double x) { return ::fabs(x); }

// Inverse/forward trig overloads. acos/asin clamp the argument to [-1, 1]
// so values just outside the domain (e.g. from rounding) do not produce NaN.
inline CUDA_CALLABLE float acos(float x){ return ::acosf(min(max(x, -1.0f), 1.0f)); }
inline CUDA_CALLABLE float asin(float x){ return ::asinf(min(max(x, -1.0f), 1.0f)); }
inline CUDA_CALLABLE float atan(float x) { return ::atanf(x); }
inline CUDA_CALLABLE float atan2(float y, float x) { return ::atan2f(y, x); }
inline CUDA_CALLABLE float sin(float x) { return ::sinf(x); }
inline CUDA_CALLABLE float cos(float x) { return ::cosf(x); }

inline CUDA_CALLABLE double acos(double x){ return ::acos(min(max(x, -1.0), 1.0)); }
inline CUDA_CALLABLE double asin(double x){ return ::asin(min(max(x, -1.0), 1.0)); }
inline CUDA_CALLABLE double atan(double x) { return ::atan(x); }
inline CUDA_CALLABLE double atan2(double y, double x) { return ::atan2(y, x); }
inline CUDA_CALLABLE double sin(double x) { return ::sin(x); }
inline CUDA_CALLABLE double cos(double x) { return ::cos(x); }

// half overloads compute in single precision and narrow on return.
inline CUDA_CALLABLE half acos(half x){ return ::acosf(min(max(float(x), -1.0f), 1.0f)); }
inline CUDA_CALLABLE half asin(half x){ return ::asinf(min(max(float(x), -1.0f), 1.0f)); }
inline CUDA_CALLABLE half atan(half x) { return ::atanf(float(x)); }
inline CUDA_CALLABLE half atan2(half y, half x) { return ::atan2f(float(y), float(x)); }
inline CUDA_CALLABLE half sin(half x) { return ::sinf(float(x)); }
inline CUDA_CALLABLE half cos(half x) { return ::cosf(float(x)); }
807
+
808
+
809
// Square-root overloads. With FP_CHECK enabled, traps on negative input
// (which would yield NaN); otherwise the NaN from the C library propagates.
inline CUDA_CALLABLE float sqrt(float x)
{
#if FP_CHECK
    if (x < 0.0f)
    {
        printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, x);
        assert(0);
    }
#endif
    return ::sqrtf(x);
}
inline CUDA_CALLABLE double sqrt(double x)
{
#if FP_CHECK
    if (x < 0.0)
    {
        printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, x);
        assert(0);
    }
#endif
    return ::sqrt(x);
}
inline CUDA_CALLABLE half sqrt(half x)
{
#if FP_CHECK
    if (float(x) < 0.0f)
    {
        printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, float(x));
        assert(0);
    }
#endif
    return ::sqrtf(float(x));
}

// Cube root; defined for negative inputs as well (unlike sqrt).
inline CUDA_CALLABLE float cbrt(float x) { return ::cbrtf(x); }
inline CUDA_CALLABLE double cbrt(double x) { return ::cbrt(x); }
inline CUDA_CALLABLE half cbrt(half x) { return ::cbrtf(float(x)); }
846
+
847
// Remaining trig/hyperbolic functions plus degree/radian conversion.
// RAD_TO_DEG and DEG_TO_RAD are constants defined elsewhere in this header.
inline CUDA_CALLABLE float tan(float x) { return ::tanf(x); }
inline CUDA_CALLABLE float sinh(float x) { return ::sinhf(x);}
inline CUDA_CALLABLE float cosh(float x) { return ::coshf(x);}
inline CUDA_CALLABLE float tanh(float x) { return ::tanhf(x);}
inline CUDA_CALLABLE float degrees(float x) { return x * RAD_TO_DEG;}
inline CUDA_CALLABLE float radians(float x) { return x * DEG_TO_RAD;}

inline CUDA_CALLABLE double tan(double x) { return ::tan(x); }
inline CUDA_CALLABLE double sinh(double x) { return ::sinh(x);}
inline CUDA_CALLABLE double cosh(double x) { return ::cosh(x);}
inline CUDA_CALLABLE double tanh(double x) { return ::tanh(x);}
inline CUDA_CALLABLE double degrees(double x) { return x * RAD_TO_DEG;}
inline CUDA_CALLABLE double radians(double x) { return x * DEG_TO_RAD;}

// half overloads compute in single precision and narrow on return.
inline CUDA_CALLABLE half tan(half x) { return ::tanf(float(x)); }
inline CUDA_CALLABLE half sinh(half x) { return ::sinhf(float(x));}
inline CUDA_CALLABLE half cosh(half x) { return ::coshf(float(x));}
inline CUDA_CALLABLE half tanh(half x) { return ::tanhf(float(x));}
inline CUDA_CALLABLE half degrees(half x) { return x * RAD_TO_DEG;}
inline CUDA_CALLABLE half radians(half x) { return x * DEG_TO_RAD;}

// Rounding family. round: half away from zero; rint: to nearest even
// (current rounding mode); trunc: toward zero; frac: fractional part,
// defined as x - trunc(x) so it keeps the sign of x.
inline CUDA_CALLABLE float round(float x) { return ::roundf(x); }
inline CUDA_CALLABLE float rint(float x) { return ::rintf(x); }
inline CUDA_CALLABLE float trunc(float x) { return ::truncf(x); }
inline CUDA_CALLABLE float floor(float x) { return ::floorf(x); }
inline CUDA_CALLABLE float ceil(float x) { return ::ceilf(x); }
inline CUDA_CALLABLE float frac(float x) { return x - trunc(x); }

inline CUDA_CALLABLE double round(double x) { return ::round(x); }
inline CUDA_CALLABLE double rint(double x) { return ::rint(x); }
inline CUDA_CALLABLE double trunc(double x) { return ::trunc(x); }
inline CUDA_CALLABLE double floor(double x) { return ::floor(x); }
inline CUDA_CALLABLE double ceil(double x) { return ::ceil(x); }
inline CUDA_CALLABLE double frac(double x) { return x - trunc(x); }

inline CUDA_CALLABLE half round(half x) { return ::roundf(float(x)); }
inline CUDA_CALLABLE half rint(half x) { return ::rintf(float(x)); }
inline CUDA_CALLABLE half trunc(half x) { return ::truncf(float(x)); }
inline CUDA_CALLABLE half floor(half x) { return ::floorf(float(x)); }
inline CUDA_CALLABLE half ceil(half x) { return ::ceilf(float(x)); }
inline CUDA_CALLABLE half frac(half x) { return float(x) - trunc(float(x)); }
888
+
889
// DECLARE_ADJOINTS(T): stamps out the backward-mode (adjoint) functions for
// the scalar math ops above, for one floating-point type T. Each adj_* takes
// the primal inputs (and for exp/pow/tanh/sqrt/cbrt also the primal result
// `ret`), accumulates into the adj_* output references, and leaves inputs
// untouched. DO_IF_FPCHECK/DO_IF_NO_FPCHECK select between a trapping
// variant (print + assert on non-finite gradients) and a guarded variant
// that simply skips the update at singular points (e.g. |x| == 1 for
// acos/asin, cos(x) == 0 for tan, x == y == 0 for atan2).
// Rounding ops (round/rint/trunc/floor/ceil/frac) have zero gradient almost
// everywhere, so their adjoints are intentionally empty.
// NOTE: comments cannot be placed inside the macro body itself because every
// line ends in a '\' continuation.
#define DECLARE_ADJOINTS(T)\
inline CUDA_CALLABLE void adj_log(T a, T& adj_a, T adj_ret)\
{\
    adj_a += (T(1)/a)*adj_ret;\
    DO_IF_FPCHECK(if (!isfinite(adj_a))\
    {\
        printf("%s:%d - adj_log(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
        assert(0);\
    })\
}\
inline CUDA_CALLABLE void adj_log2(T a, T& adj_a, T adj_ret)\
{ \
    adj_a += (T(1)/a)*(T(1)/log(T(2)))*adj_ret; \
    DO_IF_FPCHECK(if (!isfinite(adj_a))\
    {\
        printf("%s:%d - adj_log2(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
        assert(0);\
    }) \
}\
inline CUDA_CALLABLE void adj_log10(T a, T& adj_a, T adj_ret)\
{\
    adj_a += (T(1)/a)*(T(1)/log(T(10)))*adj_ret; \
    DO_IF_FPCHECK(if (!isfinite(adj_a))\
    {\
        printf("%s:%d - adj_log10(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
        assert(0);\
    })\
}\
inline CUDA_CALLABLE void adj_exp(T a, T ret, T& adj_a, T adj_ret) { adj_a += ret*adj_ret; }\
inline CUDA_CALLABLE void adj_pow(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
{ \
    adj_a += b*pow(a, b-T(1))*adj_ret;\
    adj_b += log(a)*ret*adj_ret;\
    DO_IF_FPCHECK(if (!isfinite(adj_a) || !isfinite(adj_b))\
    {\
        printf("%s:%d - adj_pow(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
        assert(0);\
    })\
}\
inline CUDA_CALLABLE void adj_leaky_min(T a, T b, T r, T& adj_a, T& adj_b, T& adj_r, T adj_ret)\
{\
    if (a < b)\
        adj_a += adj_ret;\
    else\
    {\
        adj_a += r*adj_ret;\
        adj_b += adj_ret;\
    }\
}\
inline CUDA_CALLABLE void adj_leaky_max(T a, T b, T r, T& adj_a, T& adj_b, T& adj_r, T adj_ret)\
{\
    if (a > b)\
        adj_a += adj_ret;\
    else\
    {\
        adj_a += r*adj_ret;\
        adj_b += adj_ret;\
    }\
}\
inline CUDA_CALLABLE void adj_acos(T x, T& adj_x, T adj_ret)\
{\
    T d = sqrt(T(1)-x*x);\
    DO_IF_FPCHECK(adj_x -= (T(1)/d)*adj_ret;\
    if (!isfinite(d) || !isfinite(adj_x))\
    {\
        printf("%s:%d - adj_acos(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret)); \
        assert(0);\
    })\
    DO_IF_NO_FPCHECK(if (d > T(0))\
        adj_x -= (T(1)/d)*adj_ret;)\
}\
inline CUDA_CALLABLE void adj_asin(T x, T& adj_x, T adj_ret)\
{\
    T d = sqrt(T(1)-x*x);\
    DO_IF_FPCHECK(adj_x += (T(1)/d)*adj_ret;\
    if (!isfinite(d) || !isfinite(adj_x))\
    {\
        printf("%s:%d - adj_asin(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret)); \
        assert(0);\
    })\
    DO_IF_NO_FPCHECK(if (d > T(0))\
        adj_x += (T(1)/d)*adj_ret;)\
}\
inline CUDA_CALLABLE void adj_tan(T x, T& adj_x, T adj_ret)\
{\
    T cos_x = cos(x);\
    DO_IF_FPCHECK(adj_x += (T(1)/(cos_x*cos_x))*adj_ret;\
    if (!isfinite(adj_x) || cos_x == T(0))\
    {\
        printf("%s:%d - adj_tan(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
        assert(0);\
    })\
    DO_IF_NO_FPCHECK(if (cos_x != T(0))\
        adj_x += (T(1)/(cos_x*cos_x))*adj_ret;)\
}\
inline CUDA_CALLABLE void adj_atan(T x, T& adj_x, T adj_ret)\
{\
    adj_x += adj_ret /(x*x + T(1));\
}\
inline CUDA_CALLABLE void adj_atan2(T y, T x, T& adj_y, T& adj_x, T adj_ret)\
{\
    T d = x*x + y*y;\
    DO_IF_FPCHECK(adj_x -= y/d*adj_ret;\
    adj_y += x/d*adj_ret;\
    if (!isfinite(adj_x) || !isfinite(adj_y) || d == T(0))\
    {\
        printf("%s:%d - adj_atan2(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(y), float(x), float(adj_y), float(adj_x), float(adj_ret));\
        assert(0);\
    })\
    DO_IF_NO_FPCHECK(if (d > T(0))\
    {\
        adj_x -= (y/d)*adj_ret;\
        adj_y += (x/d)*adj_ret;\
    })\
}\
inline CUDA_CALLABLE void adj_sin(T x, T& adj_x, T adj_ret)\
{\
    adj_x += cos(x)*adj_ret;\
}\
inline CUDA_CALLABLE void adj_cos(T x, T& adj_x, T adj_ret)\
{\
    adj_x -= sin(x)*adj_ret;\
}\
inline CUDA_CALLABLE void adj_sinh(T x, T& adj_x, T adj_ret)\
{\
    adj_x += cosh(x)*adj_ret;\
}\
inline CUDA_CALLABLE void adj_cosh(T x, T& adj_x, T adj_ret)\
{\
    adj_x += sinh(x)*adj_ret;\
}\
inline CUDA_CALLABLE void adj_tanh(T x, T ret, T& adj_x, T adj_ret)\
{\
    adj_x += (T(1) - ret*ret)*adj_ret;\
}\
inline CUDA_CALLABLE void adj_sqrt(T x, T ret, T& adj_x, T adj_ret)\
{\
    adj_x += T(0.5)*(T(1)/ret)*adj_ret;\
    DO_IF_FPCHECK(if (!isfinite(adj_x))\
    {\
        printf("%s:%d - adj_sqrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
        assert(0);\
    })\
}\
inline CUDA_CALLABLE void adj_cbrt(T x, T ret, T& adj_x, T adj_ret)\
{\
    adj_x += (T(1)/T(3))*(T(1)/(ret*ret))*adj_ret;\
    DO_IF_FPCHECK(if (!isfinite(adj_x))\
    {\
        printf("%s:%d - adj_cbrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
        assert(0);\
    })\
}\
inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
{\
    adj_x += RAD_TO_DEG * adj_ret;\
}\
inline CUDA_CALLABLE void adj_radians(T x, T& adj_x, T adj_ret)\
{\
    adj_x += DEG_TO_RAD * adj_ret;\
}\
inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
inline CUDA_CALLABLE void adj_frac(T x, T& adj_x, T adj_ret){ }

// Instantiate the adjoints for all supported floating-point widths.
DECLARE_ADJOINTS(float16)
DECLARE_ADJOINTS(float32)
DECLARE_ADJOINTS(float64)
1060
+
1061
+ template <typename C, typename T>
1062
+ CUDA_CALLABLE inline T select(const C& cond, const T& a, const T& b)
1063
+ {
1064
+ // The double NOT operator !! casts to bool without compiler warnings.
1065
+ return (!!cond) ? b : a;
1066
+ }
1067
+
1068
+ template <typename C, typename T>
1069
+ CUDA_CALLABLE inline void adj_select(const C& cond, const T& a, const T& b, C& adj_cond, T& adj_a, T& adj_b, const T& adj_ret)
1070
+ {
1071
+ // The double NOT operator !! casts to bool without compiler warnings.
1072
+ if (!!cond)
1073
+ adj_b += adj_ret;
1074
+ else
1075
+ adj_a += adj_ret;
1076
+ }
1077
+
1078
+ template <typename T>
1079
+ CUDA_CALLABLE inline T copy(const T& src)
1080
+ {
1081
+ return src;
1082
+ }
1083
+
1084
+ template <typename T>
1085
+ CUDA_CALLABLE inline void adj_copy(const T& src, T& adj_src, T& adj_dest)
1086
+ {
1087
+ adj_src += adj_dest;
1088
+ adj_dest = T{};
1089
+ }
1090
+
1091
+ template <typename T>
1092
+ CUDA_CALLABLE inline void assign(T& dest, const T& src)
1093
+ {
1094
+ dest = src;
1095
+ }
1096
+
1097
+ template <typename T>
1098
+ CUDA_CALLABLE inline void adj_assign(T& dest, const T& src, T& adj_dest, T& adj_src)
1099
+ {
1100
+ // this is generally a non-differentiable operation since it violates SSA,
1101
+ // except in read-modify-write statements which are reversible through backpropagation
1102
+ adj_src = adj_dest;
1103
+ adj_dest = T{};
1104
+ }
1105
+
1106
+
1107
// some helpful operator overloads (just for C++ use, these are not adjointed)
// They forward to the add()/sub() free functions defined for each type.

template <typename T>
CUDA_CALLABLE inline T& operator += (T& a, const T& b) { a = add(a, b); return a; }

template <typename T>
CUDA_CALLABLE inline T& operator -= (T& a, const T& b) { a = sub(a, b); return a; }

template <typename T>
CUDA_CALLABLE inline T operator+(const T& a, const T& b) { return add(a, b); }

template <typename T>
CUDA_CALLABLE inline T operator-(const T& a, const T& b) { return sub(a, b); }

// Unary plus: identity; its adjoint just forwards the gradient.
template <typename T>
CUDA_CALLABLE inline T pos(const T& x) { return x; }
template <typename T>
CUDA_CALLABLE inline void adj_pos(const T& x, T& adj_x, const T& adj_ret) { adj_x += T(adj_ret); }

// unary negation implemented as negative multiply, not sure the fp implications of this
// may be better as 0.0 - x?
template <typename T>
CUDA_CALLABLE inline T neg(const T& x) { return T(0.0) - x; }
template <typename T>
CUDA_CALLABLE inline void adj_neg(const T& x, T& adj_x, const T& adj_ret) { adj_x += T(-adj_ret); }

// unary boolean negation
template <typename T>
CUDA_CALLABLE inline bool unot(const T& b) { return !b; }
template <typename T>
CUDA_CALLABLE inline void adj_unot(const T& b, T& adj_b, const bool& adj_ret) { }
1138
+
1139
const int LAUNCH_MAX_DIMS = 4; // should match types.py

// Describes the logical launch grid of a kernel: up to LAUNCH_MAX_DIMS
// dimensions flattened into a single linear thread index.
struct launch_bounds_t
{
    int shape[LAUNCH_MAX_DIMS]; // size of each dimension
    int ndim;                   // number of valid dimension
    size_t size;                // total number of threads
};

#ifndef __CUDACC__
// On the CPU path there is no hardware thread index; the host runner sets
// this per "thread" before invoking the kernel body.
static size_t s_threadIdx;
#endif
1151
+
1152
// Returns the linear index of the current thread within the launch grid:
// blockDim.x*blockIdx.x + threadIdx.x on CUDA, or the host-side emulation
// counter s_threadIdx on CPU.
inline CUDA_CALLABLE size_t grid_index()
{
#ifdef __CUDACC__
    // Need to cast at least one of the variables being multiplied so that type promotion happens before the multiplication
    size_t grid_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
    return grid_index;
#else
    return s_threadIdx;
#endif
}

// 1-D thread id: narrows the linear size_t index to int.
inline CUDA_CALLABLE int tid(size_t index)
{
    // For the 1-D tid() we need to warn the user if we're about to provide a truncated index
    // Only do this in _DEBUG when called from device to avoid excessive register allocation
#if defined(_DEBUG) || !defined(__CUDA_ARCH__)
    // 2147483647 == INT_MAX; anything larger overflows the int return value
    if (index > 2147483647) {
        printf("Warp warning: tid() is returning an overflowed int\n");
    }
#endif
    return static_cast<int>(index);
}
1174
+
1175
// 2-D thread id: unflattens the linear index into (i, j) using row-major
// order, where j is the fastest-varying coordinate.
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& launch_bounds)
{
    const size_t n = launch_bounds.shape[1];

    // convert to work item
    i = index/n;
    j = index%n;
}

// 3-D thread id: row-major unflatten into (i, j, k); k varies fastest.
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& launch_bounds)
{
    const size_t n = launch_bounds.shape[1];
    const size_t o = launch_bounds.shape[2];

    // convert to work item
    i = index/(n*o);
    j = index%(n*o)/o;
    k = index%o;
}

// 4-D thread id: row-major unflatten into (i, j, k, l); l varies fastest.
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& launch_bounds)
{
    const size_t n = launch_bounds.shape[1];
    const size_t o = launch_bounds.shape[2];
    const size_t p = launch_bounds.shape[3];

    // convert to work item
    i = index/(n*o*p);
    j = index%(n*o*p)/(o*p);
    k = index%(o*p)/p;
    l = index%p;
}
1207
+
1208
// Atomic add: returns the value stored at buf[0] *before* the addition.
// The CPU path is a plain read-modify-write (not actually atomic; CPU
// kernels are assumed single-threaded per work item here — NOTE(review):
// confirm against the CPU launcher).
template<typename T>
inline CUDA_CALLABLE T atomic_add(T* buf, T value)
{
#if !defined(__CUDA_ARCH__)
    T old = buf[0];
    buf[0] += value;
    return old;
#else
    return atomicAdd(buf, value);
#endif
}

// float16 specialization. Three paths:
//  - CPU: plain read-modify-write.
//  - CUDA via Clang: use the __half atomicAdd intrinsic through bit-casts.
//  - CUDA via NVRTC: hand-written PTX `atom.add.noftz.f16`, which requires
//    sm_70+; on older architectures the asm is compiled out and the
//    function returns 0 without performing the add.
template<>
inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
{
#if !defined(__CUDA_ARCH__)
    float16 old = buf[0];
    buf[0] += value;
    return old;
#elif defined(__clang__)  // CUDA compiled by Clang
    __half r = atomicAdd(reinterpret_cast<__half*>(buf), *reinterpret_cast<__half*>(&value));
    return *reinterpret_cast<float16*>(&r);
#else  // CUDA compiled by NVRTC
    //return atomicAdd(buf, value);

    /* Define __PTR for atomicAdd prototypes below, undef after done */
#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
#define __PTR "l"
#else
#define __PTR "r"
#endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/

    half r = 0.0;

#if __CUDA_ARCH__ >= 700

    asm volatile ("{ atom.add.noftz.f16 %0,[%1],%2; }\n"
                  : "=h"(r.u)
                  : __PTR(buf), "h"(value.u)
                  : "memory");
#endif

    return r;

#undef __PTR

#endif // CUDA compiled by NVRTC

}
1257
+
1258
+ // emulate atomic float max
1259
+ inline CUDA_CALLABLE float atomic_max(float* address, float val)
1260
+ {
1261
+ #if defined(__CUDA_ARCH__)
1262
+ int *address_as_int = (int*)address;
1263
+ int old = *address_as_int, assumed;
1264
+
1265
+ while (val > __int_as_float(old))
1266
+ {
1267
+ assumed = old;
1268
+ old = atomicCAS(address_as_int, assumed,
1269
+ __float_as_int(val));
1270
+ }
1271
+
1272
+ return __int_as_float(old);
1273
+
1274
+ #else
1275
+ float old = *address;
1276
+ *address = max(old, val);
1277
+ return old;
1278
+ #endif
1279
+ }
1280
+
1281
+ // emulate atomic float min/max with atomicCAS()
1282
+ inline CUDA_CALLABLE float atomic_min(float* address, float val)
1283
+ {
1284
+ #if defined(__CUDA_ARCH__)
1285
+ int *address_as_int = (int*)address;
1286
+ int old = *address_as_int, assumed;
1287
+
1288
+ while (val < __int_as_float(old))
1289
+ {
1290
+ assumed = old;
1291
+ old = atomicCAS(address_as_int, assumed,
1292
+ __float_as_int(val));
1293
+ }
1294
+
1295
+ return __int_as_float(old);
1296
+
1297
+ #else
1298
+ float old = *address;
1299
+ *address = min(old, val);
1300
+ return old;
1301
+ #endif
1302
+ }
1303
+
1304
+ inline CUDA_CALLABLE int atomic_max(int* address, int val)
1305
+ {
1306
+ #if defined(__CUDA_ARCH__)
1307
+ return atomicMax(address, val);
1308
+
1309
+ #else
1310
+ int old = *address;
1311
+ *address = max(old, val);
1312
+ return old;
1313
+ #endif
1314
+ }
1315
+
1316
+ // atomic int min
1317
+ inline CUDA_CALLABLE int atomic_min(int* address, int val)
1318
+ {
1319
+ #if defined(__CUDA_ARCH__)
1320
+ return atomicMin(address, val);
1321
+
1322
+ #else
1323
+ int old = *address;
1324
+ *address = min(old, val);
1325
+ return old;
1326
+ #endif
1327
+ }
1328
+
1329
// default behavior for adjoint of atomic min/max operation that accumulates gradients for all elements matching the min/max value
template <typename T>
CUDA_CALLABLE inline void adj_atomic_minmax(T *addr, T *adj_addr, const T &value, T &adj_value)
{
    // every contributor equal to the winning value receives the full gradient
    if (value == *addr)
        adj_value += *adj_addr;
}

// for integral types we do not accumulate gradients
CUDA_CALLABLE inline void adj_atomic_minmax(int8* buf, int8* adj_buf, const int8 &value, int8 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(uint8* buf, uint8* adj_buf, const uint8 &value, uint8 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(int16* buf, int16* adj_buf, const int16 &value, int16 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(uint16* buf, uint16* adj_buf, const uint16 &value, uint16 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(int32* buf, int32* adj_buf, const int32 &value, int32 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(uint32* buf, uint32* adj_buf, const uint32 &value, uint32 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(int64* buf, int64* adj_buf, const int64 &value, int64 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(uint64* buf, uint64* adj_buf, const uint64 &value, uint64 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(bool* buf, bool* adj_buf, const bool &value, bool &adj_value) { }
1347
+
1348
+
1349
+ } // namespace wp
1350
+
1351
+
1352
// bool and printf are defined outside of the wp namespace in crt.h, hence
// their adjoint counterparts are also defined in the global namespace.
// Both are no-ops: casts to bool and printing carry no gradient.
template <typename T>
CUDA_CALLABLE inline void adj_bool(T, T&, bool) {}
inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
1357
+
1358
+
1359
+ #include "vec.h"
1360
+ #include "mat.h"
1361
+ #include "quat.h"
1362
+ #include "spatial.h"
1363
+ #include "intersect.h"
1364
+ #include "intersect_adj.h"
1365
+
1366
+ //--------------
1367
+ namespace wp
1368
+ {
1369
+
1370
+
1371
// dot for scalar types just to make some templates compile for scalar/vector
// (scalar "dot product" degenerates to multiplication; adjoint forwards to adj_mul)
inline CUDA_CALLABLE float dot(float a, float b) { return mul(a, b); }
inline CUDA_CALLABLE void adj_dot(float a, float b, float& adj_a, float& adj_b, float adj_ret) { adj_mul(a, b, adj_a, adj_b, adj_ret); }
inline CUDA_CALLABLE float tensordot(float a, float b) { return mul(a, b); }
1375
+
1376
+
1377
// DECLARE_INTERP_FUNCS(T): stamps out interpolation helpers for one
// floating-point type T.
//  - smoothstep(edge0, edge1, x): cubic Hermite interpolation, clamped to
//    [0, 1] outside the [edge0, edge1] interval.
//  - adj_smoothstep: its adjoint; returns early (zero gradient) in the
//    clamped regions where the forward result is constant.
//  - lerp(a, b, t): linear interpolation a*(1-t) + b*t, with adjoint.
// NOTE: comments cannot be placed inside the macro body itself because every
// line ends in a '\' continuation.
#define DECLARE_INTERP_FUNCS(T) \
CUDA_CALLABLE inline T smoothstep(T edge0, T edge1, T x)\
{\
    x = clamp((x - edge0) / (edge1 - edge0), T(0), T(1));\
    return x * x * (T(3) - T(2) * x);\
}\
CUDA_CALLABLE inline void adj_smoothstep(T edge0, T edge1, T x, T& adj_edge0, T& adj_edge1, T& adj_x, T adj_ret)\
{\
    T ab = edge0 - edge1;\
    T ax = edge0 - x;\
    T bx = edge1 - x;\
    T xb = x - edge1;\
    \
    if (bx / ab >= T(0) || ax / ab <= T(0))\
    {\
        return;\
    }\
    \
    T ab3 = ab * ab * ab;\
    T ab4 = ab3 * ab;\
    adj_edge0 += adj_ret * ((T(6) * ax * bx * bx) / ab4);\
    adj_edge1 += adj_ret * ((T(6) * ax * ax * xb) / ab4);\
    adj_x += adj_ret * ((T(6) * ax * bx ) / ab3);\
}\
CUDA_CALLABLE inline T lerp(const T& a, const T& b, T t)\
{\
    return a*(T(1)-t) + b*t;\
}\
CUDA_CALLABLE inline void adj_lerp(const T& a, const T& b, T t, T& adj_a, T& adj_b, T& adj_t, const T& adj_ret)\
{\
    adj_a += adj_ret*(T(1)-t);\
    adj_b += adj_ret*t;\
    adj_t += b*adj_ret - a*adj_ret;\
}

// Instantiate the interpolation helpers for all supported float widths.
DECLARE_INTERP_FUNCS(float16)
DECLARE_INTERP_FUNCS(float32)
DECLARE_INTERP_FUNCS(float64)
1415
+
1416
// print() overloads for all kernel-visible types. Each writes one value
// (plus a trailing newline) via printf; component types are cast to float
// before the %g conversion so half components promote correctly through
// the varargs call.
inline CUDA_CALLABLE void print(const str s)
{
    printf("%s\n", s);
}

inline CUDA_CALLABLE void print(int i)
{
    printf("%d\n", i);
}

inline CUDA_CALLABLE void print(short i)
{
    printf("%hd\n", i);
}

inline CUDA_CALLABLE void print(long i)
{
    printf("%ld\n", i);
}

inline CUDA_CALLABLE void print(long long i)
{
    printf("%lld\n", i);
}

inline CUDA_CALLABLE void print(unsigned i)
{
    printf("%u\n", i);
}

inline CUDA_CALLABLE void print(unsigned short i)
{
    printf("%hu\n", i);
}

inline CUDA_CALLABLE void print(unsigned long i)
{
    printf("%lu\n", i);
}

inline CUDA_CALLABLE void print(unsigned long long i)
{
    printf("%llu\n", i);
}

// Vector: components space-separated on one line.
template<unsigned Length, typename Type>
inline CUDA_CALLABLE void print(vec_t<Length, Type> v)
{
    for( unsigned i=0; i < Length; ++i )
    {
        printf("%g ", float(v[i]));
    }
    printf("\n");
}

// Quaternion: x y z w on one line.
template<typename Type>
inline CUDA_CALLABLE void print(quat_t<Type> i)
{
    printf("%g %g %g %g\n", float(i.x), float(i.y), float(i.z), float(i.w));
}

// Matrix: one row per line.
template<unsigned Rows,unsigned Cols,typename Type>
inline CUDA_CALLABLE void print(const mat_t<Rows,Cols,Type> &m)
{
    for( unsigned i=0; i< Rows; ++i )
    {
        for( unsigned j=0; j< Cols; ++j )
        {
            printf("%g ",float(m.data[i][j]));
        }
        printf("\n");
    }
}

// Transform: position triple then quaternion, each parenthesised.
template<typename Type>
inline CUDA_CALLABLE void print(transform_t<Type> t)
{
    printf("(%g %g %g) (%g %g %g %g)\n", float(t.p[0]), float(t.p[1]), float(t.p[2]), float(t.q.x), float(t.q.y), float(t.q.z), float(t.q.w));
}
1495
+
1496
// adj_print() overloads for scalar types: print the primal value followed
// by its adjoint. Used when wp.print() is invoked inside a differentiable
// kernel's backward pass.
inline CUDA_CALLABLE void adj_print(int i, int adj_i) { printf("%d adj: %d\n", i, adj_i); }
inline CUDA_CALLABLE void adj_print(float f, float adj_f) { printf("%g adj: %g\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(short f, short adj_f) { printf("%hd adj: %hd\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(long f, long adj_f) { printf("%ld adj: %ld\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(long long f, long long adj_f) { printf("%lld adj: %lld\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(unsigned f, unsigned adj_f) { printf("%u adj: %u\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(unsigned short f, unsigned short adj_f) { printf("%hu adj: %hu\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(unsigned long f, unsigned long adj_f) { printf("%lu adj: %lu\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(unsigned long long f, unsigned long long adj_f) { printf("%llu adj: %llu\n", f, adj_f); }
// half is converted explicitly so the varargs call receives a double
inline CUDA_CALLABLE void adj_print(half h, half adj_h) { printf("%g adj: %g\n", half_to_float(h), half_to_float(adj_h)); }
inline CUDA_CALLABLE void adj_print(double f, double adj_f) { printf("%g adj: %g\n", f, adj_f); }
1507
+
1508
+ template<unsigned Length, typename Type>
1509
+ inline CUDA_CALLABLE void adj_print(vec_t<Length, Type> v, vec_t<Length, Type>& adj_v) { printf("%g %g adj: %g %g \n", v[0], v[1], adj_v[0], adj_v[1]); }
1510
+
1511
+ template<unsigned Rows, unsigned Cols, typename Type>
1512
+ inline CUDA_CALLABLE void adj_print(mat_t<Rows, Cols, Type> m, mat_t<Rows, Cols, Type>& adj_m) { }
1513
+
1514
+ template<typename Type>
1515
+ inline CUDA_CALLABLE void adj_print(quat_t<Type> q, quat_t<Type>& adj_q) { printf("%g %g %g %g adj: %g %g %g %g\n", q.x, q.y, q.z, q.w, adj_q.x, adj_q.y, adj_q.z, adj_q.w); }
1516
+
1517
+ template<typename Type>
1518
+ inline CUDA_CALLABLE void adj_print(transform_t<Type> t, transform_t<Type>& adj_t) {}
1519
+
1520
+ inline CUDA_CALLABLE void adj_print(str t, str& adj_t) {}
1521
+
1522
+
1523
+ template <typename T>
1524
+ inline CUDA_CALLABLE void expect_eq(const T& actual, const T& expected)
1525
+ {
1526
+ if (!(actual == expected))
1527
+ {
1528
+ printf("Error, expect_eq() failed:\n");
1529
+ printf("\t Expected: "); print(expected);
1530
+ printf("\t Actual: "); print(actual);
1531
+ }
1532
+ }
1533
+
1534
+ template <typename T>
1535
+ inline CUDA_CALLABLE void adj_expect_eq(const T& a, const T& b, T& adj_a, T& adj_b)
1536
+ {
1537
+ // nop
1538
+ }
1539
+
1540
+ template <typename T>
1541
+ inline CUDA_CALLABLE void expect_neq(const T& actual, const T& expected)
1542
+ {
1543
+ if (actual == expected)
1544
+ {
1545
+ printf("Error, expect_neq() failed:\n");
1546
+ printf("\t Expected: "); print(expected);
1547
+ printf("\t Actual: "); print(actual);
1548
+ }
1549
+ }
1550
+
1551
+ template <typename T>
1552
+ inline CUDA_CALLABLE void adj_expect_neq(const T& a, const T& b, T& adj_a, T& adj_b)
1553
+ {
1554
+ // nop
1555
+ }
1556
+
1557
// Test helper: prints a diagnostic when |actual - expected| > tolerance.
template <typename T>
inline CUDA_CALLABLE void expect_near(const T& actual, const T& expected, const T& tolerance)
{
    if (abs(actual - expected) > tolerance)
    {
        printf("Error, expect_near() failed with tolerance "); print(tolerance);
        printf("\t Expected: "); print(expected);
        printf("\t Actual: "); print(actual);
    }
}

// vec3 overload: compares against the max componentwise absolute difference
// (Chebyshev distance), not the Euclidean norm.
inline CUDA_CALLABLE void expect_near(const vec3& actual, const vec3& expected, const float& tolerance)
{
    const float diff = max(max(abs(actual[0] - expected[0]), abs(actual[1] - expected[1])), abs(actual[2] - expected[2]));
    if (diff > tolerance)
    {
        printf("Error, expect_near() failed with tolerance "); print(tolerance);
        printf("\t Expected: "); print(expected);
        printf("\t Actual: "); print(actual);
    }
}

// Adjoints of expect_near(): assertions carry no gradient.
template <typename T>
inline CUDA_CALLABLE void adj_expect_near(const T& actual, const T& expected, const T& tolerance, T& adj_actual, T& adj_expected, T& adj_tolerance)
{
    // nop
}

inline CUDA_CALLABLE void adj_expect_near(const vec3& actual, const vec3& expected, float tolerance, vec3& adj_actual, vec3& adj_expected, float adj_tolerance)
{
    // nop
}
1589
+
1590
+
1591
+ } // namespace wp
1592
+
1593
+ // include array.h so we have the print, isfinite functions for the inner array types defined
1594
+ #include "array.h"
1595
+ #include "mesh.h"
1596
+ #include "bvh.h"
1597
+ #include "svd.h"
1598
+ #include "hashgrid.h"
1599
+ #include "volume.h"
1600
+ #include "range.h"
1601
+ #include "rand.h"
1602
+ #include "noise.h"
1603
+ #include "matnn.h"