warp-lang 1.0.1__py3-none-manylinux2014_x86_64.whl → 1.1.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the release details for more information.

Files changed (346)
  1. warp/__init__.py +108 -97
  2. warp/__init__.pyi +1 -1
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +115 -113
  6. warp/build_dll.py +383 -375
  7. warp/builtins.py +3425 -3354
  8. warp/codegen.py +2878 -2792
  9. warp/config.py +40 -36
  10. warp/constants.py +45 -45
  11. warp/context.py +5194 -5102
  12. warp/dlpack.py +442 -442
  13. warp/examples/__init__.py +16 -16
  14. warp/examples/assets/bear.usd +0 -0
  15. warp/examples/assets/bunny.usd +0 -0
  16. warp/examples/assets/cartpole.urdf +110 -110
  17. warp/examples/assets/crazyflie.usd +0 -0
  18. warp/examples/assets/cube.usd +0 -0
  19. warp/examples/assets/nv_ant.xml +92 -92
  20. warp/examples/assets/nv_humanoid.xml +183 -183
  21. warp/examples/assets/quadruped.urdf +267 -267
  22. warp/examples/assets/rocks.nvdb +0 -0
  23. warp/examples/assets/rocks.usd +0 -0
  24. warp/examples/assets/sphere.usd +0 -0
  25. warp/examples/benchmarks/benchmark_api.py +383 -383
  26. warp/examples/benchmarks/benchmark_cloth.py +278 -279
  27. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
  28. warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
  29. warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
  30. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
  31. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
  32. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
  33. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
  34. warp/examples/benchmarks/benchmark_launches.py +295 -295
  35. warp/examples/browse.py +29 -28
  36. warp/examples/core/example_dem.py +234 -221
  37. warp/examples/core/example_fluid.py +293 -267
  38. warp/examples/core/example_graph_capture.py +144 -129
  39. warp/examples/core/example_marching_cubes.py +188 -176
  40. warp/examples/core/example_mesh.py +174 -154
  41. warp/examples/core/example_mesh_intersect.py +205 -193
  42. warp/examples/core/example_nvdb.py +176 -169
  43. warp/examples/core/example_raycast.py +105 -89
  44. warp/examples/core/example_raymarch.py +199 -178
  45. warp/examples/core/example_render_opengl.py +185 -141
  46. warp/examples/core/example_sph.py +405 -389
  47. warp/examples/core/example_torch.py +222 -181
  48. warp/examples/core/example_wave.py +263 -249
  49. warp/examples/fem/bsr_utils.py +378 -380
  50. warp/examples/fem/example_apic_fluid.py +407 -391
  51. warp/examples/fem/example_convection_diffusion.py +182 -168
  52. warp/examples/fem/example_convection_diffusion_dg.py +219 -209
  53. warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
  54. warp/examples/fem/example_deformed_geometry.py +177 -159
  55. warp/examples/fem/example_diffusion.py +201 -173
  56. warp/examples/fem/example_diffusion_3d.py +177 -152
  57. warp/examples/fem/example_diffusion_mgpu.py +221 -214
  58. warp/examples/fem/example_mixed_elasticity.py +244 -222
  59. warp/examples/fem/example_navier_stokes.py +259 -243
  60. warp/examples/fem/example_stokes.py +220 -192
  61. warp/examples/fem/example_stokes_transfer.py +265 -249
  62. warp/examples/fem/mesh_utils.py +133 -109
  63. warp/examples/fem/plot_utils.py +292 -287
  64. warp/examples/optim/example_bounce.py +260 -248
  65. warp/examples/optim/example_cloth_throw.py +222 -210
  66. warp/examples/optim/example_diffray.py +566 -535
  67. warp/examples/optim/example_drone.py +864 -835
  68. warp/examples/optim/example_inverse_kinematics.py +176 -169
  69. warp/examples/optim/example_inverse_kinematics_torch.py +185 -170
  70. warp/examples/optim/example_spring_cage.py +239 -234
  71. warp/examples/optim/example_trajectory.py +223 -201
  72. warp/examples/optim/example_walker.py +306 -292
  73. warp/examples/sim/example_cartpole.py +139 -128
  74. warp/examples/sim/example_cloth.py +196 -184
  75. warp/examples/sim/example_granular.py +124 -113
  76. warp/examples/sim/example_granular_collision_sdf.py +197 -185
  77. warp/examples/sim/example_jacobian_ik.py +236 -213
  78. warp/examples/sim/example_particle_chain.py +118 -106
  79. warp/examples/sim/example_quadruped.py +193 -179
  80. warp/examples/sim/example_rigid_chain.py +197 -189
  81. warp/examples/sim/example_rigid_contact.py +189 -176
  82. warp/examples/sim/example_rigid_force.py +127 -126
  83. warp/examples/sim/example_rigid_gyroscopic.py +109 -97
  84. warp/examples/sim/example_rigid_soft_contact.py +134 -124
  85. warp/examples/sim/example_soft_body.py +190 -178
  86. warp/fabric.py +337 -335
  87. warp/fem/__init__.py +60 -27
  88. warp/fem/cache.py +401 -388
  89. warp/fem/dirichlet.py +178 -179
  90. warp/fem/domain.py +262 -263
  91. warp/fem/field/__init__.py +100 -101
  92. warp/fem/field/field.py +148 -149
  93. warp/fem/field/nodal_field.py +298 -299
  94. warp/fem/field/restriction.py +22 -21
  95. warp/fem/field/test.py +180 -181
  96. warp/fem/field/trial.py +183 -183
  97. warp/fem/geometry/__init__.py +15 -19
  98. warp/fem/geometry/closest_point.py +69 -70
  99. warp/fem/geometry/deformed_geometry.py +270 -271
  100. warp/fem/geometry/element.py +744 -744
  101. warp/fem/geometry/geometry.py +184 -186
  102. warp/fem/geometry/grid_2d.py +380 -373
  103. warp/fem/geometry/grid_3d.py +441 -435
  104. warp/fem/geometry/hexmesh.py +953 -953
  105. warp/fem/geometry/partition.py +374 -376
  106. warp/fem/geometry/quadmesh_2d.py +532 -532
  107. warp/fem/geometry/tetmesh.py +840 -840
  108. warp/fem/geometry/trimesh_2d.py +577 -577
  109. warp/fem/integrate.py +1630 -1615
  110. warp/fem/operator.py +190 -191
  111. warp/fem/polynomial.py +214 -213
  112. warp/fem/quadrature/__init__.py +2 -2
  113. warp/fem/quadrature/pic_quadrature.py +243 -245
  114. warp/fem/quadrature/quadrature.py +295 -294
  115. warp/fem/space/__init__.py +294 -292
  116. warp/fem/space/basis_space.py +488 -489
  117. warp/fem/space/collocated_function_space.py +100 -105
  118. warp/fem/space/dof_mapper.py +236 -236
  119. warp/fem/space/function_space.py +148 -145
  120. warp/fem/space/grid_2d_function_space.py +267 -267
  121. warp/fem/space/grid_3d_function_space.py +305 -306
  122. warp/fem/space/hexmesh_function_space.py +350 -352
  123. warp/fem/space/partition.py +350 -350
  124. warp/fem/space/quadmesh_2d_function_space.py +368 -369
  125. warp/fem/space/restriction.py +158 -160
  126. warp/fem/space/shape/__init__.py +13 -15
  127. warp/fem/space/shape/cube_shape_function.py +738 -738
  128. warp/fem/space/shape/shape_function.py +102 -103
  129. warp/fem/space/shape/square_shape_function.py +611 -611
  130. warp/fem/space/shape/tet_shape_function.py +565 -567
  131. warp/fem/space/shape/triangle_shape_function.py +429 -429
  132. warp/fem/space/tetmesh_function_space.py +294 -292
  133. warp/fem/space/topology.py +297 -295
  134. warp/fem/space/trimesh_2d_function_space.py +223 -221
  135. warp/fem/types.py +77 -77
  136. warp/fem/utils.py +495 -495
  137. warp/jax.py +166 -141
  138. warp/jax_experimental.py +341 -339
  139. warp/native/array.h +1072 -1025
  140. warp/native/builtin.h +1560 -1560
  141. warp/native/bvh.cpp +398 -398
  142. warp/native/bvh.cu +525 -525
  143. warp/native/bvh.h +429 -429
  144. warp/native/clang/clang.cpp +495 -464
  145. warp/native/crt.cpp +31 -31
  146. warp/native/crt.h +334 -334
  147. warp/native/cuda_crt.h +1049 -1049
  148. warp/native/cuda_util.cpp +549 -540
  149. warp/native/cuda_util.h +288 -203
  150. warp/native/cutlass_gemm.cpp +34 -34
  151. warp/native/cutlass_gemm.cu +372 -372
  152. warp/native/error.cpp +66 -66
  153. warp/native/error.h +27 -27
  154. warp/native/fabric.h +228 -228
  155. warp/native/hashgrid.cpp +301 -278
  156. warp/native/hashgrid.cu +78 -77
  157. warp/native/hashgrid.h +227 -227
  158. warp/native/initializer_array.h +32 -32
  159. warp/native/intersect.h +1204 -1204
  160. warp/native/intersect_adj.h +365 -365
  161. warp/native/intersect_tri.h +322 -322
  162. warp/native/marching.cpp +2 -2
  163. warp/native/marching.cu +497 -497
  164. warp/native/marching.h +2 -2
  165. warp/native/mat.h +1498 -1498
  166. warp/native/matnn.h +333 -333
  167. warp/native/mesh.cpp +203 -203
  168. warp/native/mesh.cu +293 -293
  169. warp/native/mesh.h +1887 -1887
  170. warp/native/nanovdb/NanoVDB.h +4782 -4782
  171. warp/native/nanovdb/PNanoVDB.h +2553 -2553
  172. warp/native/nanovdb/PNanoVDBWrite.h +294 -294
  173. warp/native/noise.h +850 -850
  174. warp/native/quat.h +1084 -1084
  175. warp/native/rand.h +299 -299
  176. warp/native/range.h +108 -108
  177. warp/native/reduce.cpp +156 -156
  178. warp/native/reduce.cu +348 -348
  179. warp/native/runlength_encode.cpp +61 -61
  180. warp/native/runlength_encode.cu +46 -46
  181. warp/native/scan.cpp +30 -30
  182. warp/native/scan.cu +36 -36
  183. warp/native/scan.h +7 -7
  184. warp/native/solid_angle.h +442 -442
  185. warp/native/sort.cpp +94 -94
  186. warp/native/sort.cu +97 -97
  187. warp/native/sort.h +14 -14
  188. warp/native/sparse.cpp +337 -337
  189. warp/native/sparse.cu +544 -544
  190. warp/native/spatial.h +630 -630
  191. warp/native/svd.h +562 -562
  192. warp/native/temp_buffer.h +30 -30
  193. warp/native/vec.h +1132 -1132
  194. warp/native/volume.cpp +297 -297
  195. warp/native/volume.cu +32 -32
  196. warp/native/volume.h +538 -538
  197. warp/native/volume_builder.cu +425 -425
  198. warp/native/volume_builder.h +19 -19
  199. warp/native/warp.cpp +1057 -1052
  200. warp/native/warp.cu +2943 -2828
  201. warp/native/warp.h +313 -305
  202. warp/optim/__init__.py +9 -9
  203. warp/optim/adam.py +120 -120
  204. warp/optim/linear.py +1104 -939
  205. warp/optim/sgd.py +104 -92
  206. warp/render/__init__.py +10 -10
  207. warp/render/render_opengl.py +3217 -3204
  208. warp/render/render_usd.py +768 -749
  209. warp/render/utils.py +152 -150
  210. warp/sim/__init__.py +52 -59
  211. warp/sim/articulation.py +685 -685
  212. warp/sim/collide.py +1594 -1590
  213. warp/sim/import_mjcf.py +489 -481
  214. warp/sim/import_snu.py +220 -221
  215. warp/sim/import_urdf.py +536 -516
  216. warp/sim/import_usd.py +887 -881
  217. warp/sim/inertia.py +316 -317
  218. warp/sim/integrator.py +234 -233
  219. warp/sim/integrator_euler.py +1956 -1956
  220. warp/sim/integrator_featherstone.py +1910 -1991
  221. warp/sim/integrator_xpbd.py +3294 -3312
  222. warp/sim/model.py +4473 -4314
  223. warp/sim/particles.py +113 -112
  224. warp/sim/render.py +417 -403
  225. warp/sim/utils.py +413 -410
  226. warp/sparse.py +1227 -1227
  227. warp/stubs.py +2109 -2469
  228. warp/tape.py +1162 -225
  229. warp/tests/__init__.py +1 -1
  230. warp/tests/__main__.py +4 -4
  231. warp/tests/assets/torus.usda +105 -105
  232. warp/tests/aux_test_class_kernel.py +26 -26
  233. warp/tests/aux_test_compile_consts_dummy.py +10 -10
  234. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
  235. warp/tests/aux_test_dependent.py +22 -22
  236. warp/tests/aux_test_grad_customs.py +23 -23
  237. warp/tests/aux_test_reference.py +11 -11
  238. warp/tests/aux_test_reference_reference.py +10 -10
  239. warp/tests/aux_test_square.py +17 -17
  240. warp/tests/aux_test_unresolved_func.py +14 -14
  241. warp/tests/aux_test_unresolved_symbol.py +14 -14
  242. warp/tests/disabled_kinematics.py +239 -239
  243. warp/tests/run_coverage_serial.py +31 -31
  244. warp/tests/test_adam.py +157 -157
  245. warp/tests/test_arithmetic.py +1124 -1124
  246. warp/tests/test_array.py +2417 -2326
  247. warp/tests/test_array_reduce.py +150 -150
  248. warp/tests/test_async.py +668 -656
  249. warp/tests/test_atomic.py +141 -141
  250. warp/tests/test_bool.py +204 -149
  251. warp/tests/test_builtins_resolution.py +1292 -1292
  252. warp/tests/test_bvh.py +164 -171
  253. warp/tests/test_closest_point_edge_edge.py +228 -228
  254. warp/tests/test_codegen.py +566 -553
  255. warp/tests/test_compile_consts.py +97 -101
  256. warp/tests/test_conditional.py +246 -246
  257. warp/tests/test_copy.py +232 -215
  258. warp/tests/test_ctypes.py +632 -632
  259. warp/tests/test_dense.py +67 -67
  260. warp/tests/test_devices.py +91 -98
  261. warp/tests/test_dlpack.py +530 -529
  262. warp/tests/test_examples.py +400 -378
  263. warp/tests/test_fabricarray.py +955 -955
  264. warp/tests/test_fast_math.py +62 -54
  265. warp/tests/test_fem.py +1277 -1278
  266. warp/tests/test_fp16.py +130 -130
  267. warp/tests/test_func.py +338 -337
  268. warp/tests/test_generics.py +571 -571
  269. warp/tests/test_grad.py +746 -640
  270. warp/tests/test_grad_customs.py +333 -336
  271. warp/tests/test_hash_grid.py +210 -164
  272. warp/tests/test_import.py +39 -39
  273. warp/tests/test_indexedarray.py +1134 -1134
  274. warp/tests/test_intersect.py +67 -67
  275. warp/tests/test_jax.py +307 -307
  276. warp/tests/test_large.py +167 -164
  277. warp/tests/test_launch.py +354 -354
  278. warp/tests/test_lerp.py +261 -261
  279. warp/tests/test_linear_solvers.py +191 -171
  280. warp/tests/test_lvalue.py +421 -493
  281. warp/tests/test_marching_cubes.py +65 -65
  282. warp/tests/test_mat.py +1801 -1827
  283. warp/tests/test_mat_lite.py +115 -115
  284. warp/tests/test_mat_scalar_ops.py +2907 -2889
  285. warp/tests/test_math.py +126 -193
  286. warp/tests/test_matmul.py +500 -499
  287. warp/tests/test_matmul_lite.py +410 -410
  288. warp/tests/test_mempool.py +188 -190
  289. warp/tests/test_mesh.py +284 -324
  290. warp/tests/test_mesh_query_aabb.py +228 -241
  291. warp/tests/test_mesh_query_point.py +692 -702
  292. warp/tests/test_mesh_query_ray.py +292 -303
  293. warp/tests/test_mlp.py +276 -276
  294. warp/tests/test_model.py +110 -110
  295. warp/tests/test_modules_lite.py +39 -39
  296. warp/tests/test_multigpu.py +163 -163
  297. warp/tests/test_noise.py +248 -248
  298. warp/tests/test_operators.py +250 -250
  299. warp/tests/test_options.py +123 -125
  300. warp/tests/test_peer.py +133 -137
  301. warp/tests/test_pinned.py +78 -78
  302. warp/tests/test_print.py +54 -54
  303. warp/tests/test_quat.py +2086 -2086
  304. warp/tests/test_rand.py +288 -288
  305. warp/tests/test_reload.py +217 -217
  306. warp/tests/test_rounding.py +179 -179
  307. warp/tests/test_runlength_encode.py +190 -190
  308. warp/tests/test_sim_grad.py +243 -0
  309. warp/tests/test_sim_kinematics.py +91 -97
  310. warp/tests/test_smoothstep.py +168 -168
  311. warp/tests/test_snippet.py +305 -266
  312. warp/tests/test_sparse.py +468 -460
  313. warp/tests/test_spatial.py +2148 -2148
  314. warp/tests/test_streams.py +486 -473
  315. warp/tests/test_struct.py +710 -675
  316. warp/tests/test_tape.py +173 -148
  317. warp/tests/test_torch.py +743 -743
  318. warp/tests/test_transient_module.py +87 -87
  319. warp/tests/test_types.py +556 -659
  320. warp/tests/test_utils.py +490 -499
  321. warp/tests/test_vec.py +1264 -1268
  322. warp/tests/test_vec_lite.py +73 -73
  323. warp/tests/test_vec_scalar_ops.py +2099 -2099
  324. warp/tests/test_verify_fp.py +94 -94
  325. warp/tests/test_volume.py +737 -736
  326. warp/tests/test_volume_write.py +255 -265
  327. warp/tests/unittest_serial.py +37 -37
  328. warp/tests/unittest_suites.py +363 -359
  329. warp/tests/unittest_utils.py +603 -578
  330. warp/tests/unused_test_misc.py +71 -71
  331. warp/tests/walkthrough_debug.py +85 -85
  332. warp/thirdparty/appdirs.py +598 -598
  333. warp/thirdparty/dlpack.py +143 -143
  334. warp/thirdparty/unittest_parallel.py +566 -561
  335. warp/torch.py +321 -295
  336. warp/types.py +4504 -4450
  337. warp/utils.py +1008 -821
  338. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
  339. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
  340. warp_lang-1.1.0.dist-info/RECORD +352 -0
  341. warp/examples/assets/cube.usda +0 -42
  342. warp/examples/assets/sphere.usda +0 -56
  343. warp/examples/assets/torus.usda +0 -105
  344. warp_lang-1.0.1.dist-info/RECORD +0 -352
  345. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
  346. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/native/builtin.h CHANGED
@@ -1,1560 +1,1560 @@
1
- /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
7
- */
8
-
9
- #pragma once
10
-
11
- // All built-in types and functions. To be compatible with runtime NVRTC compilation
12
- // this header must be independently compilable (i.e.: without external SDK headers)
13
- // to achieve this we redefine a subset of CRT functions (printf, pow, sin, cos, etc)
14
-
15
- #include "crt.h"
16
-
17
- #ifdef _WIN32
18
- #define __restrict__ __restrict
19
- #endif
20
-
21
- #if !defined(__CUDACC__)
22
- #define CUDA_CALLABLE
23
- #define CUDA_CALLABLE_DEVICE
24
- #else
25
- #define CUDA_CALLABLE __host__ __device__
26
- #define CUDA_CALLABLE_DEVICE __device__
27
- #endif
28
-
29
- #ifdef WP_VERIFY_FP
30
- #define FP_CHECK 1
31
- #define DO_IF_FPCHECK(X) {X}
32
- #define DO_IF_NO_FPCHECK(X)
33
- #else
34
- #define FP_CHECK 0
35
- #define DO_IF_FPCHECK(X)
36
- #define DO_IF_NO_FPCHECK(X) {X}
37
- #endif
38
-
39
- #define RAD_TO_DEG 57.29577951308232087679
40
- #define DEG_TO_RAD 0.01745329251994329577
41
-
42
- #if defined(__CUDACC__) && !defined(_MSC_VER)
43
- __device__ void __debugbreak() {}
44
- #endif
45
-
46
- namespace wp
47
- {
48
-
49
- // numeric types (used from generated kernels)
50
- typedef float float32;
51
- typedef double float64;
52
-
53
- typedef int8_t int8;
54
- typedef uint8_t uint8;
55
-
56
- typedef int16_t int16;
57
- typedef uint16_t uint16;
58
-
59
- typedef int32_t int32;
60
- typedef uint32_t uint32;
61
-
62
- typedef int64_t int64;
63
- typedef uint64_t uint64;
64
-
65
-
66
- // matches Python string type for constant strings
67
- typedef const char* str;
68
-
69
-
70
-
71
- struct half;
72
-
73
- CUDA_CALLABLE half float_to_half(float x);
74
- CUDA_CALLABLE float half_to_float(half x);
75
-
76
- struct half
77
- {
78
- CUDA_CALLABLE inline half() : u(0) {}
79
-
80
- CUDA_CALLABLE inline half(float f)
81
- {
82
- *this = float_to_half(f);
83
- }
84
-
85
- unsigned short u;
86
-
87
- CUDA_CALLABLE inline bool operator==(const half& h) const { return u == h.u; }
88
- CUDA_CALLABLE inline bool operator!=(const half& h) const { return u != h.u; }
89
- CUDA_CALLABLE inline bool operator>(const half& h) const { return half_to_float(*this) > half_to_float(h); }
90
- CUDA_CALLABLE inline bool operator>=(const half& h) const { return half_to_float(*this) >= half_to_float(h); }
91
- CUDA_CALLABLE inline bool operator<(const half& h) const { return half_to_float(*this) < half_to_float(h); }
92
- CUDA_CALLABLE inline bool operator<=(const half& h) const { return half_to_float(*this) <= half_to_float(h); }
93
-
94
- CUDA_CALLABLE inline bool operator!() const
95
- {
96
- return float32(*this) == 0;
97
- }
98
-
99
- CUDA_CALLABLE inline half operator*=(const half& h)
100
- {
101
- half prod = half(float32(*this) * float32(h));
102
- this->u = prod.u;
103
- return *this;
104
- }
105
-
106
- CUDA_CALLABLE inline half operator/=(const half& h)
107
- {
108
- half quot = half(float32(*this) / float32(h));
109
- this->u = quot.u;
110
- return *this;
111
- }
112
-
113
- CUDA_CALLABLE inline half operator+=(const half& h)
114
- {
115
- half sum = half(float32(*this) + float32(h));
116
- this->u = sum.u;
117
- return *this;
118
- }
119
-
120
- CUDA_CALLABLE inline half operator-=(const half& h)
121
- {
122
- half diff = half(float32(*this) - float32(h));
123
- this->u = diff.u;
124
- return *this;
125
- }
126
-
127
- CUDA_CALLABLE inline operator float32() const { return float32(half_to_float(*this)); }
128
- CUDA_CALLABLE inline operator float64() const { return float64(half_to_float(*this)); }
129
- CUDA_CALLABLE inline operator int8() const { return int8(half_to_float(*this)); }
130
- CUDA_CALLABLE inline operator uint8() const { return uint8(half_to_float(*this)); }
131
- CUDA_CALLABLE inline operator int16() const { return int16(half_to_float(*this)); }
132
- CUDA_CALLABLE inline operator uint16() const { return uint16(half_to_float(*this)); }
133
- CUDA_CALLABLE inline operator int32() const { return int32(half_to_float(*this)); }
134
- CUDA_CALLABLE inline operator uint32() const { return uint32(half_to_float(*this)); }
135
- CUDA_CALLABLE inline operator int64() const { return int64(half_to_float(*this)); }
136
- CUDA_CALLABLE inline operator uint64() const { return uint64(half_to_float(*this)); }
137
- };
138
-
139
- static_assert(sizeof(half) == 2, "Size of half / float16 type must be 2-bytes");
140
-
141
- typedef half float16;
142
-
143
- #if defined(__CUDA_ARCH__)
144
-
145
- CUDA_CALLABLE inline half float_to_half(float x)
146
- {
147
- half h;
148
- asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(h.u) : "f"(x));
149
- return h;
150
- }
151
-
152
- CUDA_CALLABLE inline float half_to_float(half x)
153
- {
154
- float val;
155
- asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(x.u));
156
- return val;
157
- }
158
-
159
- #elif defined(__clang__)
160
-
161
- // _Float16 is Clang's native half-precision floating-point type
162
- inline half float_to_half(float x)
163
- {
164
-
165
- _Float16 f16 = static_cast<_Float16>(x);
166
- return *reinterpret_cast<half*>(&f16);
167
- }
168
-
169
- inline float half_to_float(half h)
170
- {
171
- _Float16 f16 = *reinterpret_cast<_Float16*>(&h);
172
- return static_cast<float>(f16);
173
- }
174
-
175
- #else // Native C++ for Warp builtins outside of kernels
176
-
177
- extern "C" WP_API uint16_t float_to_half_bits(float x);
178
- extern "C" WP_API float half_bits_to_float(uint16_t u);
179
-
180
- inline half float_to_half(float x)
181
- {
182
- half h;
183
- h.u = float_to_half_bits(x);
184
- return h;
185
- }
186
-
187
- inline float half_to_float(half h)
188
- {
189
- return half_bits_to_float(h.u);
190
- }
191
-
192
- #endif
193
-
194
-
195
- // BAD operator implementations for fp16 arithmetic...
196
-
197
- // negation:
198
- inline CUDA_CALLABLE half operator - (half a)
199
- {
200
- return float_to_half( -half_to_float(a) );
201
- }
202
-
203
- inline CUDA_CALLABLE half operator + (half a,half b)
204
- {
205
- return float_to_half( half_to_float(a) + half_to_float(b) );
206
- }
207
-
208
- inline CUDA_CALLABLE half operator - (half a,half b)
209
- {
210
- return float_to_half( half_to_float(a) - half_to_float(b) );
211
- }
212
-
213
- inline CUDA_CALLABLE half operator * (half a,half b)
214
- {
215
- return float_to_half( half_to_float(a) * half_to_float(b) );
216
- }
217
-
218
- inline CUDA_CALLABLE half operator * (half a,double b)
219
- {
220
- return float_to_half( half_to_float(a) * b );
221
- }
222
-
223
- inline CUDA_CALLABLE half operator * (double a,half b)
224
- {
225
- return float_to_half( a * half_to_float(b) );
226
- }
227
-
228
- inline CUDA_CALLABLE half operator / (half a,half b)
229
- {
230
- return float_to_half( half_to_float(a) / half_to_float(b) );
231
- }
232
-
233
-
234
-
235
-
236
-
237
- template <typename T>
238
- CUDA_CALLABLE float cast_float(T x) { return (float)(x); }
239
-
240
- template <typename T>
241
- CUDA_CALLABLE int cast_int(T x) { return (int)(x); }
242
-
243
- template <typename T>
244
- CUDA_CALLABLE void adj_cast_float(T x, T& adj_x, float adj_ret) { adj_x += T(adj_ret); }
245
-
246
- template <typename T>
247
- CUDA_CALLABLE void adj_cast_int(T x, T& adj_x, int adj_ret) { adj_x += adj_ret; }
248
-
249
- template <typename T>
250
- CUDA_CALLABLE inline void adj_int8(T, T&, int8) {}
251
- template <typename T>
252
- CUDA_CALLABLE inline void adj_uint8(T, T&, uint8) {}
253
- template <typename T>
254
- CUDA_CALLABLE inline void adj_int16(T, T&, int16) {}
255
- template <typename T>
256
- CUDA_CALLABLE inline void adj_uint16(T, T&, uint16) {}
257
- template <typename T>
258
- CUDA_CALLABLE inline void adj_int32(T, T&, int32) {}
259
- template <typename T>
260
- CUDA_CALLABLE inline void adj_uint32(T, T&, uint32) {}
261
- template <typename T>
262
- CUDA_CALLABLE inline void adj_int64(T, T&, int64) {}
263
- template <typename T>
264
- CUDA_CALLABLE inline void adj_uint64(T, T&, uint64) {}
265
-
266
-
267
- template <typename T>
268
- CUDA_CALLABLE inline void adj_float16(T x, T& adj_x, float16 adj_ret) { adj_x += T(adj_ret); }
269
- template <typename T>
270
- CUDA_CALLABLE inline void adj_float32(T x, T& adj_x, float32 adj_ret) { adj_x += T(adj_ret); }
271
- template <typename T>
272
- CUDA_CALLABLE inline void adj_float64(T x, T& adj_x, float64 adj_ret) { adj_x += T(adj_ret); }
273
-
274
-
275
- #define kEps 0.0f
276
-
277
- // basic ops for integer types
278
- #define DECLARE_INT_OPS(T) \
279
- inline CUDA_CALLABLE T mul(T a, T b) { return a*b; } \
280
- inline CUDA_CALLABLE T div(T a, T b) { return a/b; } \
281
- inline CUDA_CALLABLE T add(T a, T b) { return a+b; } \
282
- inline CUDA_CALLABLE T sub(T a, T b) { return a-b; } \
283
- inline CUDA_CALLABLE T mod(T a, T b) { return a%b; } \
284
- inline CUDA_CALLABLE T min(T a, T b) { return a<b?a:b; } \
285
- inline CUDA_CALLABLE T max(T a, T b) { return a>b?a:b; } \
286
- inline CUDA_CALLABLE T clamp(T x, T a, T b) { return min(max(a, x), b); } \
287
- inline CUDA_CALLABLE T floordiv(T a, T b) { return a/b; } \
288
- inline CUDA_CALLABLE T nonzero(T x) { return x == T(0) ? T(0) : T(1); } \
289
- inline CUDA_CALLABLE T sqrt(T x) { return 0; } \
290
- inline CUDA_CALLABLE T bit_and(T a, T b) { return a&b; } \
291
- inline CUDA_CALLABLE T bit_or(T a, T b) { return a|b; } \
292
- inline CUDA_CALLABLE T bit_xor(T a, T b) { return a^b; } \
293
- inline CUDA_CALLABLE T lshift(T a, T b) { return a<<b; } \
294
- inline CUDA_CALLABLE T rshift(T a, T b) { return a>>b; } \
295
- inline CUDA_CALLABLE T invert(T x) { return ~x; } \
296
- inline CUDA_CALLABLE bool isfinite(T x) { return true; } \
297
- inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
298
- inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret) { } \
299
- inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
300
- inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
301
- inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
302
- inline CUDA_CALLABLE void adj_min(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
303
- inline CUDA_CALLABLE void adj_max(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
304
- inline CUDA_CALLABLE void adj_abs(T x, T adj_x, T& adj_ret) { } \
305
- inline CUDA_CALLABLE void adj_sign(T x, T adj_x, T& adj_ret) { } \
306
- inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b, T adj_ret) { } \
307
- inline CUDA_CALLABLE void adj_floordiv(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
308
- inline CUDA_CALLABLE void adj_step(T x, T& adj_x, T adj_ret) { } \
309
- inline CUDA_CALLABLE void adj_nonzero(T x, T& adj_x, T adj_ret) { } \
310
- inline CUDA_CALLABLE void adj_sqrt(T x, T adj_x, T& adj_ret) { } \
311
- inline CUDA_CALLABLE void adj_bit_and(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
312
- inline CUDA_CALLABLE void adj_bit_or(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
313
- inline CUDA_CALLABLE void adj_bit_xor(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
314
- inline CUDA_CALLABLE void adj_lshift(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
315
- inline CUDA_CALLABLE void adj_rshift(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
316
- inline CUDA_CALLABLE void adj_invert(T x, T adj_x, T& adj_ret) { }
317
-
318
- inline CUDA_CALLABLE int8 abs(int8 x) { return ::abs(x); }
319
- inline CUDA_CALLABLE int16 abs(int16 x) { return ::abs(x); }
320
- inline CUDA_CALLABLE int32 abs(int32 x) { return ::abs(x); }
321
- inline CUDA_CALLABLE int64 abs(int64 x) { return ::llabs(x); }
322
- inline CUDA_CALLABLE uint8 abs(uint8 x) { return x; }
323
- inline CUDA_CALLABLE uint16 abs(uint16 x) { return x; }
324
- inline CUDA_CALLABLE uint32 abs(uint32 x) { return x; }
325
- inline CUDA_CALLABLE uint64 abs(uint64 x) { return x; }
326
-
327
- DECLARE_INT_OPS(int8)
328
- DECLARE_INT_OPS(int16)
329
- DECLARE_INT_OPS(int32)
330
- DECLARE_INT_OPS(int64)
331
- DECLARE_INT_OPS(uint8)
332
- DECLARE_INT_OPS(uint16)
333
- DECLARE_INT_OPS(uint32)
334
- DECLARE_INT_OPS(uint64)
335
-
336
-
337
- inline CUDA_CALLABLE int8 step(int8 x) { return x < 0 ? 1 : 0; }
338
- inline CUDA_CALLABLE int16 step(int16 x) { return x < 0 ? 1 : 0; }
339
- inline CUDA_CALLABLE int32 step(int32 x) { return x < 0 ? 1 : 0; }
340
- inline CUDA_CALLABLE int64 step(int64 x) { return x < 0 ? 1 : 0; }
341
- inline CUDA_CALLABLE uint8 step(uint8 x) { return 0; }
342
- inline CUDA_CALLABLE uint16 step(uint16 x) { return 0; }
343
- inline CUDA_CALLABLE uint32 step(uint32 x) { return 0; }
344
- inline CUDA_CALLABLE uint64 step(uint64 x) { return 0; }
345
-
346
-
347
- inline CUDA_CALLABLE int8 sign(int8 x) { return x < 0 ? -1 : 1; }
348
- inline CUDA_CALLABLE int8 sign(int16 x) { return x < 0 ? -1 : 1; }
349
- inline CUDA_CALLABLE int8 sign(int32 x) { return x < 0 ? -1 : 1; }
350
- inline CUDA_CALLABLE int8 sign(int64 x) { return x < 0 ? -1 : 1; }
351
- inline CUDA_CALLABLE uint8 sign(uint8 x) { return 1; }
352
- inline CUDA_CALLABLE uint16 sign(uint16 x) { return 1; }
353
- inline CUDA_CALLABLE uint32 sign(uint32 x) { return 1; }
354
- inline CUDA_CALLABLE uint64 sign(uint64 x) { return 1; }
355
-
356
-
357
- // Catch-all for non-float types
358
- template<typename T>
359
- inline bool CUDA_CALLABLE isfinite(const T&)
360
- {
361
- return true;
362
- }
363
-
364
- inline bool CUDA_CALLABLE isfinite(half x)
365
- {
366
- return ::isfinite(float(x));
367
- }
368
- inline bool CUDA_CALLABLE isfinite(float x)
369
- {
370
- return ::isfinite(x);
371
- }
372
- inline bool CUDA_CALLABLE isfinite(double x)
373
- {
374
- return ::isfinite(x);
375
- }
376
-
377
- template<typename T>
378
- inline CUDA_CALLABLE void print(const T&)
379
- {
380
- printf("<type without print implementation>\n");
381
- }
382
-
383
- inline CUDA_CALLABLE void print(float16 f)
384
- {
385
- printf("%g\n", half_to_float(f));
386
- }
387
-
388
- inline CUDA_CALLABLE void print(float f)
389
- {
390
- printf("%g\n", f);
391
- }
392
-
393
- inline CUDA_CALLABLE void print(double f)
394
- {
395
- printf("%g\n", f);
396
- }
397
-
398
-
399
- // basic ops for float types
400
- #define DECLARE_FLOAT_OPS(T) \
401
- inline CUDA_CALLABLE T mul(T a, T b) { return a*b; } \
402
- inline CUDA_CALLABLE T add(T a, T b) { return a+b; } \
403
- inline CUDA_CALLABLE T sub(T a, T b) { return a-b; } \
404
- inline CUDA_CALLABLE T min(T a, T b) { return a<b?a:b; } \
405
- inline CUDA_CALLABLE T max(T a, T b) { return a>b?a:b; } \
406
- inline CUDA_CALLABLE T sign(T x) { return x < T(0) ? -1 : 1; } \
407
- inline CUDA_CALLABLE T step(T x) { return x < T(0) ? T(1) : T(0); }\
408
- inline CUDA_CALLABLE T nonzero(T x) { return x == T(0) ? T(0) : T(1); }\
409
- inline CUDA_CALLABLE T clamp(T x, T a, T b) { return min(max(a, x), b); }\
410
- inline CUDA_CALLABLE void adj_abs(T x, T& adj_x, T adj_ret) \
411
- {\
412
- if (x < T(0))\
413
- adj_x -= adj_ret;\
414
- else\
415
- adj_x += adj_ret;\
416
- }\
417
- inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += b*adj_ret; adj_b += a*adj_ret; } \
418
- inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += adj_ret; adj_b += adj_ret; } \
419
- inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += adj_ret; adj_b -= adj_ret; } \
420
- inline CUDA_CALLABLE void adj_min(T a, T b, T& adj_a, T& adj_b, T adj_ret) \
421
- { \
422
- if (a < b) \
423
- adj_a += adj_ret; \
424
- else \
425
- adj_b += adj_ret; \
426
- } \
427
- inline CUDA_CALLABLE void adj_max(T a, T b, T& adj_a, T& adj_b, T adj_ret) \
428
- { \
429
- if (a > b) \
430
- adj_a += adj_ret; \
431
- else \
432
- adj_b += adj_ret; \
433
- } \
434
- inline CUDA_CALLABLE void adj_floordiv(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
435
- inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret){ adj_a += adj_ret; }\
436
- inline CUDA_CALLABLE void adj_sign(T x, T adj_x, T& adj_ret) { }\
437
- inline CUDA_CALLABLE void adj_step(T x, T& adj_x, T adj_ret) { }\
438
- inline CUDA_CALLABLE void adj_nonzero(T x, T& adj_x, T adj_ret) { }\
439
- inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b, T adj_ret)\
440
- {\
441
- if (x < a)\
442
- adj_a += adj_ret;\
443
- else if (x > b)\
444
- adj_b += adj_ret;\
445
- else\
446
- adj_x += adj_ret;\
447
- }\
448
- inline CUDA_CALLABLE T div(T a, T b)\
449
- {\
450
- DO_IF_FPCHECK(\
451
- if (!isfinite(a) || !isfinite(b) || b == T(0))\
452
- {\
453
- printf("%s:%d div(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));\
454
- assert(0);\
455
- })\
456
- return a/b;\
457
- }\
458
- inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
459
- {\
460
- adj_a += adj_ret/b;\
461
- adj_b -= adj_ret*(ret)/b;\
462
- DO_IF_FPCHECK(\
463
- if (!isfinite(adj_a) || !isfinite(adj_b))\
464
- {\
465
- printf("%s:%d - adj_div(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
466
- assert(0);\
467
- })\
468
- }\
469
-
470
- DECLARE_FLOAT_OPS(float16)
471
- DECLARE_FLOAT_OPS(float32)
472
- DECLARE_FLOAT_OPS(float64)
473
-
474
-
475
-
476
- // basic ops for float types
477
- inline CUDA_CALLABLE float16 mod(float16 a, float16 b)
478
- {
479
- #if FP_CHECK
480
- if (!isfinite(a) || !isfinite(b) || float(b) == 0.0f)
481
- {
482
- printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));
483
- assert(0);
484
- }
485
- #endif
486
- return fmodf(float(a), float(b));
487
- }
488
-
489
- inline CUDA_CALLABLE float32 mod(float32 a, float32 b)
490
- {
491
- #if FP_CHECK
492
- if (!isfinite(a) || !isfinite(b) || b == 0.0f)
493
- {
494
- printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
495
- assert(0);
496
- }
497
- #endif
498
- return fmodf(a, b);
499
- }
500
-
501
- inline CUDA_CALLABLE double mod(double a, double b)
502
- {
503
- #if FP_CHECK
504
- if (!isfinite(a) || !isfinite(b) || b == 0.0f)
505
- {
506
- printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
507
- assert(0);
508
- }
509
- #endif
510
- return fmod(a, b);
511
- }
512
-
513
- inline CUDA_CALLABLE half log(half a)
514
- {
515
- #if FP_CHECK
516
- if (!isfinite(a) || float(a) < 0.0f)
517
- {
518
- printf("%s:%d log(%f)\n", __FILE__, __LINE__, float(a));
519
- assert(0);
520
- }
521
- #endif
522
- return ::logf(a);
523
- }
524
-
525
- inline CUDA_CALLABLE float log(float a)
526
- {
527
- #if FP_CHECK
528
- if (!isfinite(a) || a < 0.0f)
529
- {
530
- printf("%s:%d log(%f)\n", __FILE__, __LINE__, a);
531
- assert(0);
532
- }
533
- #endif
534
- return ::logf(a);
535
- }
536
-
537
- inline CUDA_CALLABLE double log(double a)
538
- {
539
- #if FP_CHECK
540
- if (!isfinite(a) || a < 0.0)
541
- {
542
- printf("%s:%d log(%f)\n", __FILE__, __LINE__, a);
543
- assert(0);
544
- }
545
- #endif
546
- return ::log(a);
547
- }
548
-
549
- inline CUDA_CALLABLE half log2(half a)
550
- {
551
- #if FP_CHECK
552
- if (!isfinite(a) || float(a) < 0.0f)
553
- {
554
- printf("%s:%d log2(%f)\n", __FILE__, __LINE__, float(a));
555
- assert(0);
556
- }
557
- #endif
558
-
559
- return ::log2f(float(a));
560
- }
561
-
562
- inline CUDA_CALLABLE float log2(float a)
563
- {
564
- #if FP_CHECK
565
- if (!isfinite(a) || a < 0.0f)
566
- {
567
- printf("%s:%d log2(%f)\n", __FILE__, __LINE__, a);
568
- assert(0);
569
- }
570
- #endif
571
-
572
- return ::log2f(a);
573
- }
574
-
575
- inline CUDA_CALLABLE double log2(double a)
576
- {
577
- #if FP_CHECK
578
- if (!isfinite(a) || a < 0.0)
579
- {
580
- printf("%s:%d log2(%f)\n", __FILE__, __LINE__, a);
581
- assert(0);
582
- }
583
- #endif
584
-
585
- return ::log2(a);
586
- }
587
-
588
- inline CUDA_CALLABLE half log10(half a)
589
- {
590
- #if FP_CHECK
591
- if (!isfinite(a) || float(a) < 0.0f)
592
- {
593
- printf("%s:%d log10(%f)\n", __FILE__, __LINE__, float(a));
594
- assert(0);
595
- }
596
- #endif
597
-
598
- return ::log10f(float(a));
599
- }
600
-
601
- inline CUDA_CALLABLE float log10(float a)
602
- {
603
- #if FP_CHECK
604
- if (!isfinite(a) || a < 0.0f)
605
- {
606
- printf("%s:%d log10(%f)\n", __FILE__, __LINE__, a);
607
- assert(0);
608
- }
609
- #endif
610
-
611
- return ::log10f(a);
612
- }
613
-
614
- inline CUDA_CALLABLE double log10(double a)
615
- {
616
- #if FP_CHECK
617
- if (!isfinite(a) || a < 0.0)
618
- {
619
- printf("%s:%d log10(%f)\n", __FILE__, __LINE__, a);
620
- assert(0);
621
- }
622
- #endif
623
-
624
- return ::log10(a);
625
- }
626
-
627
- inline CUDA_CALLABLE half exp(half a)
628
- {
629
- half result = ::expf(float(a));
630
- #if FP_CHECK
631
- if (!isfinite(a) || !isfinite(result))
632
- {
633
- printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, float(a), float(result));
634
- assert(0);
635
- }
636
- #endif
637
- return result;
638
- }
639
- inline CUDA_CALLABLE float exp(float a)
640
- {
641
- float result = ::expf(a);
642
- #if FP_CHECK
643
- if (!isfinite(a) || !isfinite(result))
644
- {
645
- printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, a, result);
646
- assert(0);
647
- }
648
- #endif
649
- return result;
650
- }
651
- inline CUDA_CALLABLE double exp(double a)
652
- {
653
- double result = ::exp(a);
654
- #if FP_CHECK
655
- if (!isfinite(a) || !isfinite(result))
656
- {
657
- printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, a, result);
658
- assert(0);
659
- }
660
- #endif
661
- return result;
662
- }
663
-
664
- inline CUDA_CALLABLE half pow(half a, half b)
665
- {
666
- float result = ::powf(float(a), float(b));
667
- #if FP_CHECK
668
- if (!isfinite(float(a)) || !isfinite(float(b)) || !isfinite(result))
669
- {
670
- printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, float(a), float(b), result);
671
- assert(0);
672
- }
673
- #endif
674
- return result;
675
- }
676
-
677
- inline CUDA_CALLABLE float pow(float a, float b)
678
- {
679
- float result = ::powf(a, b);
680
- #if FP_CHECK
681
- if (!isfinite(a) || !isfinite(b) || !isfinite(result))
682
- {
683
- printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, a, b, result);
684
- assert(0);
685
- }
686
- #endif
687
- return result;
688
- }
689
-
690
- inline CUDA_CALLABLE double pow(double a, double b)
691
- {
692
- double result = ::pow(a, b);
693
- #if FP_CHECK
694
- if (!isfinite(a) || !isfinite(b) || !isfinite(result))
695
- {
696
- printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, a, b, result);
697
- assert(0);
698
- }
699
- #endif
700
- return result;
701
- }
702
-
703
- inline CUDA_CALLABLE half floordiv(half a, half b)
704
- {
705
- #if FP_CHECK
706
- if (!isfinite(a) || !isfinite(b) || float(b) == 0.0f)
707
- {
708
- printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));
709
- assert(0);
710
- }
711
- #endif
712
- return floorf(float(a/b));
713
- }
714
- inline CUDA_CALLABLE float floordiv(float a, float b)
715
- {
716
- #if FP_CHECK
717
- if (!isfinite(a) || !isfinite(b) || b == 0.0f)
718
- {
719
- printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
720
- assert(0);
721
- }
722
- #endif
723
- return floorf(a/b);
724
- }
725
- inline CUDA_CALLABLE double floordiv(double a, double b)
726
- {
727
- #if FP_CHECK
728
- if (!isfinite(a) || !isfinite(b) || b == 0.0)
729
- {
730
- printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
731
- assert(0);
732
- }
733
- #endif
734
- return ::floor(a/b);
735
- }
736
-
737
- inline CUDA_CALLABLE float leaky_min(float a, float b, float r) { return min(a, b); }
738
- inline CUDA_CALLABLE float leaky_max(float a, float b, float r) { return max(a, b); }
739
-
740
- inline CUDA_CALLABLE half abs(half x) { return ::fabsf(float(x)); }
741
- inline CUDA_CALLABLE float abs(float x) { return ::fabsf(x); }
742
- inline CUDA_CALLABLE double abs(double x) { return ::fabs(x); }
743
-
744
- inline CUDA_CALLABLE float acos(float x){ return ::acosf(min(max(x, -1.0f), 1.0f)); }
745
- inline CUDA_CALLABLE float asin(float x){ return ::asinf(min(max(x, -1.0f), 1.0f)); }
746
- inline CUDA_CALLABLE float atan(float x) { return ::atanf(x); }
747
- inline CUDA_CALLABLE float atan2(float y, float x) { return ::atan2f(y, x); }
748
- inline CUDA_CALLABLE float sin(float x) { return ::sinf(x); }
749
- inline CUDA_CALLABLE float cos(float x) { return ::cosf(x); }
750
-
751
- inline CUDA_CALLABLE double acos(double x){ return ::acos(min(max(x, -1.0), 1.0)); }
752
- inline CUDA_CALLABLE double asin(double x){ return ::asin(min(max(x, -1.0), 1.0)); }
753
- inline CUDA_CALLABLE double atan(double x) { return ::atan(x); }
754
- inline CUDA_CALLABLE double atan2(double y, double x) { return ::atan2(y, x); }
755
- inline CUDA_CALLABLE double sin(double x) { return ::sin(x); }
756
- inline CUDA_CALLABLE double cos(double x) { return ::cos(x); }
757
-
758
- inline CUDA_CALLABLE half acos(half x){ return ::acosf(min(max(float(x), -1.0f), 1.0f)); }
759
- inline CUDA_CALLABLE half asin(half x){ return ::asinf(min(max(float(x), -1.0f), 1.0f)); }
760
- inline CUDA_CALLABLE half atan(half x) { return ::atanf(float(x)); }
761
- inline CUDA_CALLABLE half atan2(half y, half x) { return ::atan2f(float(y), float(x)); }
762
- inline CUDA_CALLABLE half sin(half x) { return ::sinf(float(x)); }
763
- inline CUDA_CALLABLE half cos(half x) { return ::cosf(float(x)); }
764
-
765
-
766
- inline CUDA_CALLABLE float sqrt(float x)
767
- {
768
- #if FP_CHECK
769
- if (x < 0.0f)
770
- {
771
- printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, x);
772
- assert(0);
773
- }
774
- #endif
775
- return ::sqrtf(x);
776
- }
777
- inline CUDA_CALLABLE double sqrt(double x)
778
- {
779
- #if FP_CHECK
780
- if (x < 0.0)
781
- {
782
- printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, x);
783
- assert(0);
784
- }
785
- #endif
786
- return ::sqrt(x);
787
- }
788
- inline CUDA_CALLABLE half sqrt(half x)
789
- {
790
- #if FP_CHECK
791
- if (float(x) < 0.0f)
792
- {
793
- printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, float(x));
794
- assert(0);
795
- }
796
- #endif
797
- return ::sqrtf(float(x));
798
- }
799
-
800
- inline CUDA_CALLABLE float cbrt(float x) { return ::cbrtf(x); }
801
- inline CUDA_CALLABLE double cbrt(double x) { return ::cbrt(x); }
802
- inline CUDA_CALLABLE half cbrt(half x) { return ::cbrtf(float(x)); }
803
-
804
- inline CUDA_CALLABLE float tan(float x) { return ::tanf(x); }
805
- inline CUDA_CALLABLE float sinh(float x) { return ::sinhf(x);}
806
- inline CUDA_CALLABLE float cosh(float x) { return ::coshf(x);}
807
- inline CUDA_CALLABLE float tanh(float x) { return ::tanhf(x);}
808
- inline CUDA_CALLABLE float degrees(float x) { return x * RAD_TO_DEG;}
809
- inline CUDA_CALLABLE float radians(float x) { return x * DEG_TO_RAD;}
810
-
811
- inline CUDA_CALLABLE double tan(double x) { return ::tan(x); }
812
- inline CUDA_CALLABLE double sinh(double x) { return ::sinh(x);}
813
- inline CUDA_CALLABLE double cosh(double x) { return ::cosh(x);}
814
- inline CUDA_CALLABLE double tanh(double x) { return ::tanh(x);}
815
- inline CUDA_CALLABLE double degrees(double x) { return x * RAD_TO_DEG;}
816
- inline CUDA_CALLABLE double radians(double x) { return x * DEG_TO_RAD;}
817
-
818
- inline CUDA_CALLABLE half tan(half x) { return ::tanf(float(x)); }
819
- inline CUDA_CALLABLE half sinh(half x) { return ::sinhf(float(x));}
820
- inline CUDA_CALLABLE half cosh(half x) { return ::coshf(float(x));}
821
- inline CUDA_CALLABLE half tanh(half x) { return ::tanhf(float(x));}
822
- inline CUDA_CALLABLE half degrees(half x) { return x * RAD_TO_DEG;}
823
- inline CUDA_CALLABLE half radians(half x) { return x * DEG_TO_RAD;}
824
-
825
- inline CUDA_CALLABLE float round(float x) { return ::roundf(x); }
826
- inline CUDA_CALLABLE float rint(float x) { return ::rintf(x); }
827
- inline CUDA_CALLABLE float trunc(float x) { return ::truncf(x); }
828
- inline CUDA_CALLABLE float floor(float x) { return ::floorf(x); }
829
- inline CUDA_CALLABLE float ceil(float x) { return ::ceilf(x); }
830
- inline CUDA_CALLABLE float frac(float x) { return x - trunc(x); }
831
-
832
- inline CUDA_CALLABLE double round(double x) { return ::round(x); }
833
- inline CUDA_CALLABLE double rint(double x) { return ::rint(x); }
834
- inline CUDA_CALLABLE double trunc(double x) { return ::trunc(x); }
835
- inline CUDA_CALLABLE double floor(double x) { return ::floor(x); }
836
- inline CUDA_CALLABLE double ceil(double x) { return ::ceil(x); }
837
- inline CUDA_CALLABLE double frac(double x) { return x - trunc(x); }
838
-
839
- inline CUDA_CALLABLE half round(half x) { return ::roundf(float(x)); }
840
- inline CUDA_CALLABLE half rint(half x) { return ::rintf(float(x)); }
841
- inline CUDA_CALLABLE half trunc(half x) { return ::truncf(float(x)); }
842
- inline CUDA_CALLABLE half floor(half x) { return ::floorf(float(x)); }
843
- inline CUDA_CALLABLE half ceil(half x) { return ::ceilf(float(x)); }
844
- inline CUDA_CALLABLE half frac(half x) { return float(x) - trunc(float(x)); }
845
-
846
- #define DECLARE_ADJOINTS(T)\
847
- inline CUDA_CALLABLE void adj_log(T a, T& adj_a, T adj_ret)\
848
- {\
849
- adj_a += (T(1)/a)*adj_ret;\
850
- DO_IF_FPCHECK(if (!isfinite(adj_a))\
851
- {\
852
- printf("%s:%d - adj_log(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
853
- assert(0);\
854
- })\
855
- }\
856
- inline CUDA_CALLABLE void adj_log2(T a, T& adj_a, T adj_ret)\
857
- { \
858
- adj_a += (T(1)/a)*(T(1)/log(T(2)))*adj_ret; \
859
- DO_IF_FPCHECK(if (!isfinite(adj_a))\
860
- {\
861
- printf("%s:%d - adj_log2(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
862
- assert(0);\
863
- }) \
864
- }\
865
- inline CUDA_CALLABLE void adj_log10(T a, T& adj_a, T adj_ret)\
866
- {\
867
- adj_a += (T(1)/a)*(T(1)/log(T(10)))*adj_ret; \
868
- DO_IF_FPCHECK(if (!isfinite(adj_a))\
869
- {\
870
- printf("%s:%d - adj_log10(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
871
- assert(0);\
872
- })\
873
- }\
874
- inline CUDA_CALLABLE void adj_exp(T a, T ret, T& adj_a, T adj_ret) { adj_a += ret*adj_ret; }\
875
- inline CUDA_CALLABLE void adj_pow(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
876
- { \
877
- adj_a += b*pow(a, b-T(1))*adj_ret;\
878
- adj_b += log(a)*ret*adj_ret;\
879
- DO_IF_FPCHECK(if (!isfinite(adj_a) || !isfinite(adj_b))\
880
- {\
881
- printf("%s:%d - adj_pow(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
882
- assert(0);\
883
- })\
884
- }\
885
- inline CUDA_CALLABLE void adj_leaky_min(T a, T b, T r, T& adj_a, T& adj_b, T& adj_r, T adj_ret)\
886
- {\
887
- if (a < b)\
888
- adj_a += adj_ret;\
889
- else\
890
- {\
891
- adj_a += r*adj_ret;\
892
- adj_b += adj_ret;\
893
- }\
894
- }\
895
- inline CUDA_CALLABLE void adj_leaky_max(T a, T b, T r, T& adj_a, T& adj_b, T& adj_r, T adj_ret)\
896
- {\
897
- if (a > b)\
898
- adj_a += adj_ret;\
899
- else\
900
- {\
901
- adj_a += r*adj_ret;\
902
- adj_b += adj_ret;\
903
- }\
904
- }\
905
- inline CUDA_CALLABLE void adj_acos(T x, T& adj_x, T adj_ret)\
906
- {\
907
- T d = sqrt(T(1)-x*x);\
908
- DO_IF_FPCHECK(adj_x -= (T(1)/d)*adj_ret;\
909
- if (!isfinite(d) || !isfinite(adj_x))\
910
- {\
911
- printf("%s:%d - adj_acos(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret)); \
912
- assert(0);\
913
- })\
914
- DO_IF_NO_FPCHECK(if (d > T(0))\
915
- adj_x -= (T(1)/d)*adj_ret;)\
916
- }\
917
- inline CUDA_CALLABLE void adj_asin(T x, T& adj_x, T adj_ret)\
918
- {\
919
- T d = sqrt(T(1)-x*x);\
920
- DO_IF_FPCHECK(adj_x += (T(1)/d)*adj_ret;\
921
- if (!isfinite(d) || !isfinite(adj_x))\
922
- {\
923
- printf("%s:%d - adj_asin(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret)); \
924
- assert(0);\
925
- })\
926
- DO_IF_NO_FPCHECK(if (d > T(0))\
927
- adj_x += (T(1)/d)*adj_ret;)\
928
- }\
929
- inline CUDA_CALLABLE void adj_tan(T x, T& adj_x, T adj_ret)\
930
- {\
931
- T cos_x = cos(x);\
932
- DO_IF_FPCHECK(adj_x += (T(1)/(cos_x*cos_x))*adj_ret;\
933
- if (!isfinite(adj_x) || cos_x == T(0))\
934
- {\
935
- printf("%s:%d - adj_tan(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
936
- assert(0);\
937
- })\
938
- DO_IF_NO_FPCHECK(if (cos_x != T(0))\
939
- adj_x += (T(1)/(cos_x*cos_x))*adj_ret;)\
940
- }\
941
- inline CUDA_CALLABLE void adj_atan(T x, T& adj_x, T adj_ret)\
942
- {\
943
- adj_x += adj_ret /(x*x + T(1));\
944
- }\
945
- inline CUDA_CALLABLE void adj_atan2(T y, T x, T& adj_y, T& adj_x, T adj_ret)\
946
- {\
947
- T d = x*x + y*y;\
948
- DO_IF_FPCHECK(adj_x -= y/d*adj_ret;\
949
- adj_y += x/d*adj_ret;\
950
- if (!isfinite(adj_x) || !isfinite(adj_y) || d == T(0))\
951
- {\
952
- printf("%s:%d - adj_atan2(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(y), float(x), float(adj_y), float(adj_x), float(adj_ret));\
953
- assert(0);\
954
- })\
955
- DO_IF_NO_FPCHECK(if (d > T(0))\
956
- {\
957
- adj_x -= (y/d)*adj_ret;\
958
- adj_y += (x/d)*adj_ret;\
959
- })\
960
- }\
961
- inline CUDA_CALLABLE void adj_sin(T x, T& adj_x, T adj_ret)\
962
- {\
963
- adj_x += cos(x)*adj_ret;\
964
- }\
965
- inline CUDA_CALLABLE void adj_cos(T x, T& adj_x, T adj_ret)\
966
- {\
967
- adj_x -= sin(x)*adj_ret;\
968
- }\
969
- inline CUDA_CALLABLE void adj_sinh(T x, T& adj_x, T adj_ret)\
970
- {\
971
- adj_x += cosh(x)*adj_ret;\
972
- }\
973
- inline CUDA_CALLABLE void adj_cosh(T x, T& adj_x, T adj_ret)\
974
- {\
975
- adj_x += sinh(x)*adj_ret;\
976
- }\
977
- inline CUDA_CALLABLE void adj_tanh(T x, T ret, T& adj_x, T adj_ret)\
978
- {\
979
- adj_x += (T(1) - ret*ret)*adj_ret;\
980
- }\
981
- inline CUDA_CALLABLE void adj_sqrt(T x, T ret, T& adj_x, T adj_ret)\
982
- {\
983
- adj_x += T(0.5)*(T(1)/ret)*adj_ret;\
984
- DO_IF_FPCHECK(if (!isfinite(adj_x))\
985
- {\
986
- printf("%s:%d - adj_sqrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
987
- assert(0);\
988
- })\
989
- }\
990
- inline CUDA_CALLABLE void adj_cbrt(T x, T ret, T& adj_x, T adj_ret)\
991
- {\
992
- adj_x += (T(1)/T(3))*(T(1)/(ret*ret))*adj_ret;\
993
- DO_IF_FPCHECK(if (!isfinite(adj_x))\
994
- {\
995
- printf("%s:%d - adj_cbrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
996
- assert(0);\
997
- })\
998
- }\
999
- inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
1000
- {\
1001
- adj_x += RAD_TO_DEG * adj_ret;\
1002
- }\
1003
- inline CUDA_CALLABLE void adj_radians(T x, T& adj_x, T adj_ret)\
1004
- {\
1005
- adj_x += DEG_TO_RAD * adj_ret;\
1006
- }\
1007
- inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
1008
- inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
1009
- inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
1010
- inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
1011
- inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
1012
- inline CUDA_CALLABLE void adj_frac(T x, T& adj_x, T adj_ret){ }
1013
-
1014
- DECLARE_ADJOINTS(float16)
1015
- DECLARE_ADJOINTS(float32)
1016
- DECLARE_ADJOINTS(float64)
1017
-
1018
- template <typename C, typename T>
1019
- CUDA_CALLABLE inline T select(const C& cond, const T& a, const T& b)
1020
- {
1021
- // The double NOT operator !! casts to bool without compiler warnings.
1022
- return (!!cond) ? b : a;
1023
- }
1024
-
1025
- template <typename C, typename T>
1026
- CUDA_CALLABLE inline void adj_select(const C& cond, const T& a, const T& b, C& adj_cond, T& adj_a, T& adj_b, const T& adj_ret)
1027
- {
1028
- // The double NOT operator !! casts to bool without compiler warnings.
1029
- if (!!cond)
1030
- adj_b += adj_ret;
1031
- else
1032
- adj_a += adj_ret;
1033
- }
1034
-
1035
- template <typename T>
1036
- CUDA_CALLABLE inline T copy(const T& src)
1037
- {
1038
- return src;
1039
- }
1040
-
1041
- template <typename T>
1042
- CUDA_CALLABLE inline void adj_copy(const T& src, T& adj_src, T& adj_dest)
1043
- {
1044
- adj_src = adj_dest;
1045
- adj_dest = T{};
1046
- }
1047
-
1048
- template <typename T>
1049
- CUDA_CALLABLE inline void assign(T& dest, const T& src)
1050
- {
1051
- dest = src;
1052
- }
1053
-
1054
- template <typename T>
1055
- CUDA_CALLABLE inline void adj_assign(T& dest, const T& src, T& adj_dest, T& adj_src)
1056
- {
1057
- // this is generally a non-differentiable operation since it violates SSA,
1058
- // except in read-modify-write statements which are reversible through backpropagation
1059
- adj_src = adj_dest;
1060
- adj_dest = T{};
1061
- }
1062
-
1063
-
1064
- // some helpful operator overloads (just for C++ use, these are not adjointed)
1065
-
1066
- template <typename T>
1067
- CUDA_CALLABLE inline T& operator += (T& a, const T& b) { a = add(a, b); return a; }
1068
-
1069
- template <typename T>
1070
- CUDA_CALLABLE inline T& operator -= (T& a, const T& b) { a = sub(a, b); return a; }
1071
-
1072
- template <typename T>
1073
- CUDA_CALLABLE inline T operator+(const T& a, const T& b) { return add(a, b); }
1074
-
1075
- template <typename T>
1076
- CUDA_CALLABLE inline T operator-(const T& a, const T& b) { return sub(a, b); }
1077
-
1078
- template <typename T>
1079
- CUDA_CALLABLE inline T pos(const T& x) { return x; }
1080
- template <typename T>
1081
- CUDA_CALLABLE inline void adj_pos(const T& x, T& adj_x, const T& adj_ret) { adj_x += T(adj_ret); }
1082
-
1083
- // unary negation implemented as negative multiply, not sure the fp implications of this
1084
- // may be better as 0.0 - x?
1085
- template <typename T>
1086
- CUDA_CALLABLE inline T neg(const T& x) { return T(0.0) - x; }
1087
- template <typename T>
1088
- CUDA_CALLABLE inline void adj_neg(const T& x, T& adj_x, const T& adj_ret) { adj_x += T(-adj_ret); }
1089
-
1090
- // unary boolean negation
1091
- template <typename T>
1092
- CUDA_CALLABLE inline bool unot(const T& b) { return !b; }
1093
- template <typename T>
1094
- CUDA_CALLABLE inline void adj_unot(const T& b, T& adj_b, const bool& adj_ret) { }
1095
-
1096
- const int LAUNCH_MAX_DIMS = 4; // should match types.py
1097
-
1098
- struct launch_bounds_t
1099
- {
1100
- int shape[LAUNCH_MAX_DIMS]; // size of each dimension
1101
- int ndim; // number of valid dimension
1102
- size_t size; // total number of threads
1103
- };
1104
-
1105
- #ifndef __CUDACC__
1106
- static size_t s_threadIdx;
1107
- #endif
1108
-
1109
- inline CUDA_CALLABLE size_t grid_index()
1110
- {
1111
- #ifdef __CUDACC__
1112
- // Need to cast at least one of the variables being multiplied so that type promotion happens before the multiplication
1113
- size_t grid_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
1114
- return grid_index;
1115
- #else
1116
- return s_threadIdx;
1117
- #endif
1118
- }
1119
-
1120
- inline CUDA_CALLABLE int tid(size_t index)
1121
- {
1122
- // For the 1-D tid() we need to warn the user if we're about to provide a truncated index
1123
- // Only do this in _DEBUG when called from device to avoid excessive register allocation
1124
- #if defined(_DEBUG) || !defined(__CUDA_ARCH__)
1125
- if (index > 2147483647) {
1126
- printf("Warp warning: tid() is returning an overflowed int\n");
1127
- }
1128
- #endif
1129
- return static_cast<int>(index);
1130
- }
1131
-
1132
- inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& launch_bounds)
1133
- {
1134
- const size_t n = launch_bounds.shape[1];
1135
-
1136
- // convert to work item
1137
- i = index/n;
1138
- j = index%n;
1139
- }
1140
-
1141
- inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& launch_bounds)
1142
- {
1143
- const size_t n = launch_bounds.shape[1];
1144
- const size_t o = launch_bounds.shape[2];
1145
-
1146
- // convert to work item
1147
- i = index/(n*o);
1148
- j = index%(n*o)/o;
1149
- k = index%o;
1150
- }
1151
-
1152
- inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& launch_bounds)
1153
- {
1154
- const size_t n = launch_bounds.shape[1];
1155
- const size_t o = launch_bounds.shape[2];
1156
- const size_t p = launch_bounds.shape[3];
1157
-
1158
- // convert to work item
1159
- i = index/(n*o*p);
1160
- j = index%(n*o*p)/(o*p);
1161
- k = index%(o*p)/p;
1162
- l = index%p;
1163
- }
1164
-
1165
- template<typename T>
1166
- inline CUDA_CALLABLE T atomic_add(T* buf, T value)
1167
- {
1168
- #if !defined(__CUDA_ARCH__)
1169
- T old = buf[0];
1170
- buf[0] += value;
1171
- return old;
1172
- #else
1173
- return atomicAdd(buf, value);
1174
- #endif
1175
- }
1176
-
1177
- template<>
1178
- inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
1179
- {
1180
- #if !defined(__CUDA_ARCH__)
1181
- float16 old = buf[0];
1182
- buf[0] += value;
1183
- return old;
1184
- #elif defined(__clang__) // CUDA compiled by Clang
1185
- __half r = atomicAdd(reinterpret_cast<__half*>(buf), *reinterpret_cast<__half*>(&value));
1186
- return *reinterpret_cast<float16*>(&r);
1187
- #else // CUDA compiled by NVRTC
1188
- //return atomicAdd(buf, value);
1189
-
1190
- /* Define __PTR for atomicAdd prototypes below, undef after done */
1191
- #if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
1192
- #define __PTR "l"
1193
- #else
1194
- #define __PTR "r"
1195
- #endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/
1196
-
1197
- half r = 0.0;
1198
-
1199
- #if __CUDA_ARCH__ >= 700
1200
-
1201
- asm volatile ("{ atom.add.noftz.f16 %0,[%1],%2; }\n"
1202
- : "=h"(r.u)
1203
- : __PTR(buf), "h"(value.u)
1204
- : "memory");
1205
- #endif
1206
-
1207
- return r;
1208
-
1209
- #undef __PTR
1210
-
1211
- #endif // CUDA compiled by NVRTC
1212
-
1213
- }
1214
-
1215
- // emulate atomic float max
1216
- inline CUDA_CALLABLE float atomic_max(float* address, float val)
1217
- {
1218
- #if defined(__CUDA_ARCH__)
1219
- int *address_as_int = (int*)address;
1220
- int old = *address_as_int, assumed;
1221
-
1222
- while (val > __int_as_float(old))
1223
- {
1224
- assumed = old;
1225
- old = atomicCAS(address_as_int, assumed,
1226
- __float_as_int(val));
1227
- }
1228
-
1229
- return __int_as_float(old);
1230
-
1231
- #else
1232
- float old = *address;
1233
- *address = max(old, val);
1234
- return old;
1235
- #endif
1236
- }
1237
-
1238
- // emulate atomic float min/max with atomicCAS()
1239
- inline CUDA_CALLABLE float atomic_min(float* address, float val)
1240
- {
1241
- #if defined(__CUDA_ARCH__)
1242
- int *address_as_int = (int*)address;
1243
- int old = *address_as_int, assumed;
1244
-
1245
- while (val < __int_as_float(old))
1246
- {
1247
- assumed = old;
1248
- old = atomicCAS(address_as_int, assumed,
1249
- __float_as_int(val));
1250
- }
1251
-
1252
- return __int_as_float(old);
1253
-
1254
- #else
1255
- float old = *address;
1256
- *address = min(old, val);
1257
- return old;
1258
- #endif
1259
- }
1260
-
1261
- inline CUDA_CALLABLE int atomic_max(int* address, int val)
1262
- {
1263
- #if defined(__CUDA_ARCH__)
1264
- return atomicMax(address, val);
1265
-
1266
- #else
1267
- int old = *address;
1268
- *address = max(old, val);
1269
- return old;
1270
- #endif
1271
- }
1272
-
1273
- // atomic int min
1274
- inline CUDA_CALLABLE int atomic_min(int* address, int val)
1275
- {
1276
- #if defined(__CUDA_ARCH__)
1277
- return atomicMin(address, val);
1278
-
1279
- #else
1280
- int old = *address;
1281
- *address = min(old, val);
1282
- return old;
1283
- #endif
1284
- }
1285
-
1286
- // default behavior for adjoint of atomic min/max operation that accumulates gradients for all elements matching the min/max value
1287
- template <typename T>
1288
- CUDA_CALLABLE inline void adj_atomic_minmax(T *addr, T *adj_addr, const T &value, T &adj_value)
1289
- {
1290
- if (value == *addr)
1291
- adj_value += *adj_addr;
1292
- }
1293
-
1294
- // for integral types we do not accumulate gradients
1295
- CUDA_CALLABLE inline void adj_atomic_minmax(int8* buf, int8* adj_buf, const int8 &value, int8 &adj_value) { }
1296
- CUDA_CALLABLE inline void adj_atomic_minmax(uint8* buf, uint8* adj_buf, const uint8 &value, uint8 &adj_value) { }
1297
- CUDA_CALLABLE inline void adj_atomic_minmax(int16* buf, int16* adj_buf, const int16 &value, int16 &adj_value) { }
1298
- CUDA_CALLABLE inline void adj_atomic_minmax(uint16* buf, uint16* adj_buf, const uint16 &value, uint16 &adj_value) { }
1299
- CUDA_CALLABLE inline void adj_atomic_minmax(int32* buf, int32* adj_buf, const int32 &value, int32 &adj_value) { }
1300
- CUDA_CALLABLE inline void adj_atomic_minmax(uint32* buf, uint32* adj_buf, const uint32 &value, uint32 &adj_value) { }
1301
- CUDA_CALLABLE inline void adj_atomic_minmax(int64* buf, int64* adj_buf, const int64 &value, int64 &adj_value) { }
1302
- CUDA_CALLABLE inline void adj_atomic_minmax(uint64* buf, uint64* adj_buf, const uint64 &value, uint64 &adj_value) { }
1303
- CUDA_CALLABLE inline void adj_atomic_minmax(bool* buf, bool* adj_buf, const bool &value, bool &adj_value) { }
1304
-
1305
-
1306
- } // namespace wp
1307
-
1308
-
1309
- // bool and printf are defined outside of the wp namespace in crt.h, hence
1310
- // their adjoint counterparts are also defined in the global namespace.
1311
- template <typename T>
1312
- CUDA_CALLABLE inline void adj_bool(T, T&, bool) {}
1313
- inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
1314
-
1315
-
1316
- #include "vec.h"
1317
- #include "mat.h"
1318
- #include "quat.h"
1319
- #include "spatial.h"
1320
- #include "intersect.h"
1321
- #include "intersect_adj.h"
1322
-
1323
- //--------------
1324
- namespace wp
1325
- {
1326
-
1327
-
1328
- // dot for scalar types just to make some templates compile for scalar/vector
1329
- inline CUDA_CALLABLE float dot(float a, float b) { return mul(a, b); }
1330
- inline CUDA_CALLABLE void adj_dot(float a, float b, float& adj_a, float& adj_b, float adj_ret) { adj_mul(a, b, adj_a, adj_b, adj_ret); }
1331
- inline CUDA_CALLABLE float tensordot(float a, float b) { return mul(a, b); }
1332
-
1333
-
1334
- #define DECLARE_INTERP_FUNCS(T) \
1335
- CUDA_CALLABLE inline T smoothstep(T edge0, T edge1, T x)\
1336
- {\
1337
- x = clamp((x - edge0) / (edge1 - edge0), T(0), T(1));\
1338
- return x * x * (T(3) - T(2) * x);\
1339
- }\
1340
- CUDA_CALLABLE inline void adj_smoothstep(T edge0, T edge1, T x, T& adj_edge0, T& adj_edge1, T& adj_x, T adj_ret)\
1341
- {\
1342
- T ab = edge0 - edge1;\
1343
- T ax = edge0 - x;\
1344
- T bx = edge1 - x;\
1345
- T xb = x - edge1;\
1346
- \
1347
- if (bx / ab >= T(0) || ax / ab <= T(0))\
1348
- {\
1349
- return;\
1350
- }\
1351
- \
1352
- T ab3 = ab * ab * ab;\
1353
- T ab4 = ab3 * ab;\
1354
- adj_edge0 += adj_ret * ((T(6) * ax * bx * bx) / ab4);\
1355
- adj_edge1 += adj_ret * ((T(6) * ax * ax * xb) / ab4);\
1356
- adj_x += adj_ret * ((T(6) * ax * bx ) / ab3);\
1357
- }\
1358
- CUDA_CALLABLE inline T lerp(const T& a, const T& b, T t)\
1359
- {\
1360
- return a*(T(1)-t) + b*t;\
1361
- }\
1362
- CUDA_CALLABLE inline void adj_lerp(const T& a, const T& b, T t, T& adj_a, T& adj_b, T& adj_t, const T& adj_ret)\
1363
- {\
1364
- adj_a += adj_ret*(T(1)-t);\
1365
- adj_b += adj_ret*t;\
1366
- adj_t += b*adj_ret - a*adj_ret;\
1367
- }
1368
-
1369
- DECLARE_INTERP_FUNCS(float16)
1370
- DECLARE_INTERP_FUNCS(float32)
1371
- DECLARE_INTERP_FUNCS(float64)
1372
-
1373
- inline CUDA_CALLABLE void print(const str s)
1374
- {
1375
- printf("%s\n", s);
1376
- }
1377
-
1378
- inline CUDA_CALLABLE void print(int i)
1379
- {
1380
- printf("%d\n", i);
1381
- }
1382
-
1383
- inline CUDA_CALLABLE void print(short i)
1384
- {
1385
- printf("%hd\n", i);
1386
- }
1387
-
1388
- inline CUDA_CALLABLE void print(long i)
1389
- {
1390
- printf("%ld\n", i);
1391
- }
1392
-
1393
- inline CUDA_CALLABLE void print(long long i)
1394
- {
1395
- printf("%lld\n", i);
1396
- }
1397
-
1398
- inline CUDA_CALLABLE void print(unsigned i)
1399
- {
1400
- printf("%u\n", i);
1401
- }
1402
-
1403
- inline CUDA_CALLABLE void print(unsigned short i)
1404
- {
1405
- printf("%hu\n", i);
1406
- }
1407
-
1408
- inline CUDA_CALLABLE void print(unsigned long i)
1409
- {
1410
- printf("%lu\n", i);
1411
- }
1412
-
1413
- inline CUDA_CALLABLE void print(unsigned long long i)
1414
- {
1415
- printf("%llu\n", i);
1416
- }
1417
-
1418
- template<unsigned Length, typename Type>
1419
- inline CUDA_CALLABLE void print(vec_t<Length, Type> v)
1420
- {
1421
- for( unsigned i=0; i < Length; ++i )
1422
- {
1423
- printf("%g ", float(v[i]));
1424
- }
1425
- printf("\n");
1426
- }
1427
-
1428
- template<typename Type>
1429
- inline CUDA_CALLABLE void print(quat_t<Type> i)
1430
- {
1431
- printf("%g %g %g %g\n", float(i.x), float(i.y), float(i.z), float(i.w));
1432
- }
1433
-
1434
- template<unsigned Rows,unsigned Cols,typename Type>
1435
- inline CUDA_CALLABLE void print(const mat_t<Rows,Cols,Type> &m)
1436
- {
1437
- for( unsigned i=0; i< Rows; ++i )
1438
- {
1439
- for( unsigned j=0; j< Cols; ++j )
1440
- {
1441
- printf("%g ",float(m.data[i][j]));
1442
- }
1443
- printf("\n");
1444
- }
1445
- }
1446
-
1447
- template<typename Type>
1448
- inline CUDA_CALLABLE void print(transform_t<Type> t)
1449
- {
1450
- printf("(%g %g %g) (%g %g %g %g)\n", float(t.p[0]), float(t.p[1]), float(t.p[2]), float(t.q.x), float(t.q.y), float(t.q.z), float(t.q.w));
1451
- }
1452
-
1453
- inline CUDA_CALLABLE void adj_print(int i, int adj_i) { printf("%d adj: %d\n", i, adj_i); }
1454
- inline CUDA_CALLABLE void adj_print(float f, float adj_f) { printf("%g adj: %g\n", f, adj_f); }
1455
- inline CUDA_CALLABLE void adj_print(short f, short adj_f) { printf("%hd adj: %hd\n", f, adj_f); }
1456
- inline CUDA_CALLABLE void adj_print(long f, long adj_f) { printf("%ld adj: %ld\n", f, adj_f); }
1457
- inline CUDA_CALLABLE void adj_print(long long f, long long adj_f) { printf("%lld adj: %lld\n", f, adj_f); }
1458
- inline CUDA_CALLABLE void adj_print(unsigned f, unsigned adj_f) { printf("%u adj: %u\n", f, adj_f); }
1459
- inline CUDA_CALLABLE void adj_print(unsigned short f, unsigned short adj_f) { printf("%hu adj: %hu\n", f, adj_f); }
1460
- inline CUDA_CALLABLE void adj_print(unsigned long f, unsigned long adj_f) { printf("%lu adj: %lu\n", f, adj_f); }
1461
- inline CUDA_CALLABLE void adj_print(unsigned long long f, unsigned long long adj_f) { printf("%llu adj: %llu\n", f, adj_f); }
1462
- inline CUDA_CALLABLE void adj_print(half h, half adj_h) { printf("%g adj: %g\n", half_to_float(h), half_to_float(adj_h)); }
1463
- inline CUDA_CALLABLE void adj_print(double f, double adj_f) { printf("%g adj: %g\n", f, adj_f); }
1464
-
1465
- template<unsigned Length, typename Type>
1466
- inline CUDA_CALLABLE void adj_print(vec_t<Length, Type> v, vec_t<Length, Type>& adj_v) { printf("%g %g adj: %g %g \n", v[0], v[1], adj_v[0], adj_v[1]); }
1467
-
1468
- template<unsigned Rows, unsigned Cols, typename Type>
1469
- inline CUDA_CALLABLE void adj_print(mat_t<Rows, Cols, Type> m, mat_t<Rows, Cols, Type>& adj_m) { }
1470
-
1471
- template<typename Type>
1472
- inline CUDA_CALLABLE void adj_print(quat_t<Type> q, quat_t<Type>& adj_q) { printf("%g %g %g %g adj: %g %g %g %g\n", q.x, q.y, q.z, q.w, adj_q.x, adj_q.y, adj_q.z, adj_q.w); }
1473
-
1474
- template<typename Type>
1475
- inline CUDA_CALLABLE void adj_print(transform_t<Type> t, transform_t<Type>& adj_t) {}
1476
-
1477
- inline CUDA_CALLABLE void adj_print(str t, str& adj_t) {}
1478
-
1479
-
1480
- template <typename T>
1481
- inline CUDA_CALLABLE void expect_eq(const T& actual, const T& expected)
1482
- {
1483
- if (!(actual == expected))
1484
- {
1485
- printf("Error, expect_eq() failed:\n");
1486
- printf("\t Expected: "); print(expected);
1487
- printf("\t Actual: "); print(actual);
1488
- }
1489
- }
1490
-
1491
- template <typename T>
1492
- inline CUDA_CALLABLE void adj_expect_eq(const T& a, const T& b, T& adj_a, T& adj_b)
1493
- {
1494
- // nop
1495
- }
1496
-
1497
- template <typename T>
1498
- inline CUDA_CALLABLE void expect_neq(const T& actual, const T& expected)
1499
- {
1500
- if (actual == expected)
1501
- {
1502
- printf("Error, expect_neq() failed:\n");
1503
- printf("\t Expected: "); print(expected);
1504
- printf("\t Actual: "); print(actual);
1505
- }
1506
- }
1507
-
1508
- template <typename T>
1509
- inline CUDA_CALLABLE void adj_expect_neq(const T& a, const T& b, T& adj_a, T& adj_b)
1510
- {
1511
- // nop
1512
- }
1513
-
1514
- template <typename T>
1515
- inline CUDA_CALLABLE void expect_near(const T& actual, const T& expected, const T& tolerance)
1516
- {
1517
- if (abs(actual - expected) > tolerance)
1518
- {
1519
- printf("Error, expect_near() failed with tolerance "); print(tolerance);
1520
- printf("\t Expected: "); print(expected);
1521
- printf("\t Actual: "); print(actual);
1522
- }
1523
- }
1524
-
1525
- inline CUDA_CALLABLE void expect_near(const vec3& actual, const vec3& expected, const float& tolerance)
1526
- {
1527
- const float diff = max(max(abs(actual[0] - expected[0]), abs(actual[1] - expected[1])), abs(actual[2] - expected[2]));
1528
- if (diff > tolerance)
1529
- {
1530
- printf("Error, expect_near() failed with tolerance "); print(tolerance);
1531
- printf("\t Expected: "); print(expected);
1532
- printf("\t Actual: "); print(actual);
1533
- }
1534
- }
1535
-
1536
- template <typename T>
1537
- inline CUDA_CALLABLE void adj_expect_near(const T& actual, const T& expected, const T& tolerance, T& adj_actual, T& adj_expected, T& adj_tolerance)
1538
- {
1539
- // nop
1540
- }
1541
-
1542
- inline CUDA_CALLABLE void adj_expect_near(const vec3& actual, const vec3& expected, float tolerance, vec3& adj_actual, vec3& adj_expected, float adj_tolerance)
1543
- {
1544
- // nop
1545
- }
1546
-
1547
-
1548
- } // namespace wp
1549
-
1550
- // include array.h so we have the print, isfinite functions for the inner array types defined
1551
- #include "array.h"
1552
- #include "mesh.h"
1553
- #include "bvh.h"
1554
- #include "svd.h"
1555
- #include "hashgrid.h"
1556
- #include "volume.h"
1557
- #include "range.h"
1558
- #include "rand.h"
1559
- #include "noise.h"
1560
- #include "matnn.h"
1
+ /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ * NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ * and proprietary rights in and to this software, related documentation
4
+ * and any modifications thereto. Any use, reproduction, disclosure or
5
+ * distribution of this software and related documentation without an express
6
+ * license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+ */
8
+
9
+ #pragma once
10
+
11
+ // All built-in types and functions. To be compatible with runtime NVRTC compilation
12
+ // this header must be independently compilable (i.e.: without external SDK headers)
13
+ // to achieve this we redefine a subset of CRT functions (printf, pow, sin, cos, etc)
14
+
15
+ #include "crt.h"
16
+
17
+ #ifdef _WIN32
18
+ #define __restrict__ __restrict
19
+ #endif
20
+
21
+ #if !defined(__CUDACC__)
22
+ #define CUDA_CALLABLE
23
+ #define CUDA_CALLABLE_DEVICE
24
+ #else
25
+ #define CUDA_CALLABLE __host__ __device__
26
+ #define CUDA_CALLABLE_DEVICE __device__
27
+ #endif
28
+
29
+ #ifdef WP_VERIFY_FP
30
+ #define FP_CHECK 1
31
+ #define DO_IF_FPCHECK(X) {X}
32
+ #define DO_IF_NO_FPCHECK(X)
33
+ #else
34
+ #define FP_CHECK 0
35
+ #define DO_IF_FPCHECK(X)
36
+ #define DO_IF_NO_FPCHECK(X) {X}
37
+ #endif
38
+
39
+ #define RAD_TO_DEG 57.29577951308232087679
40
+ #define DEG_TO_RAD 0.01745329251994329577
41
+
42
+ #if defined(__CUDACC__) && !defined(_MSC_VER)
43
+ __device__ void __debugbreak() {}
44
+ #endif
45
+
46
+ namespace wp
47
+ {
48
+
49
+ // numeric types (used from generated kernels)
50
+ typedef float float32;
51
+ typedef double float64;
52
+
53
+ typedef int8_t int8;
54
+ typedef uint8_t uint8;
55
+
56
+ typedef int16_t int16;
57
+ typedef uint16_t uint16;
58
+
59
+ typedef int32_t int32;
60
+ typedef uint32_t uint32;
61
+
62
+ typedef int64_t int64;
63
+ typedef uint64_t uint64;
64
+
65
+
66
+ // matches Python string type for constant strings
67
+ typedef const char* str;
68
+
69
+
70
+
71
+ struct half;
72
+
73
+ CUDA_CALLABLE half float_to_half(float x);
74
+ CUDA_CALLABLE float half_to_float(half x);
75
+
76
+ struct half
77
+ {
78
+ CUDA_CALLABLE inline half() : u(0) {}
79
+
80
+ CUDA_CALLABLE inline half(float f)
81
+ {
82
+ *this = float_to_half(f);
83
+ }
84
+
85
+ unsigned short u;
86
+
87
+ CUDA_CALLABLE inline bool operator==(const half& h) const { return u == h.u; }
88
+ CUDA_CALLABLE inline bool operator!=(const half& h) const { return u != h.u; }
89
+ CUDA_CALLABLE inline bool operator>(const half& h) const { return half_to_float(*this) > half_to_float(h); }
90
+ CUDA_CALLABLE inline bool operator>=(const half& h) const { return half_to_float(*this) >= half_to_float(h); }
91
+ CUDA_CALLABLE inline bool operator<(const half& h) const { return half_to_float(*this) < half_to_float(h); }
92
+ CUDA_CALLABLE inline bool operator<=(const half& h) const { return half_to_float(*this) <= half_to_float(h); }
93
+
94
+ CUDA_CALLABLE inline bool operator!() const
95
+ {
96
+ return float32(*this) == 0;
97
+ }
98
+
99
+ CUDA_CALLABLE inline half operator*=(const half& h)
100
+ {
101
+ half prod = half(float32(*this) * float32(h));
102
+ this->u = prod.u;
103
+ return *this;
104
+ }
105
+
106
+ CUDA_CALLABLE inline half operator/=(const half& h)
107
+ {
108
+ half quot = half(float32(*this) / float32(h));
109
+ this->u = quot.u;
110
+ return *this;
111
+ }
112
+
113
+ CUDA_CALLABLE inline half operator+=(const half& h)
114
+ {
115
+ half sum = half(float32(*this) + float32(h));
116
+ this->u = sum.u;
117
+ return *this;
118
+ }
119
+
120
+ CUDA_CALLABLE inline half operator-=(const half& h)
121
+ {
122
+ half diff = half(float32(*this) - float32(h));
123
+ this->u = diff.u;
124
+ return *this;
125
+ }
126
+
127
+ CUDA_CALLABLE inline operator float32() const { return float32(half_to_float(*this)); }
128
+ CUDA_CALLABLE inline operator float64() const { return float64(half_to_float(*this)); }
129
+ CUDA_CALLABLE inline operator int8() const { return int8(half_to_float(*this)); }
130
+ CUDA_CALLABLE inline operator uint8() const { return uint8(half_to_float(*this)); }
131
+ CUDA_CALLABLE inline operator int16() const { return int16(half_to_float(*this)); }
132
+ CUDA_CALLABLE inline operator uint16() const { return uint16(half_to_float(*this)); }
133
+ CUDA_CALLABLE inline operator int32() const { return int32(half_to_float(*this)); }
134
+ CUDA_CALLABLE inline operator uint32() const { return uint32(half_to_float(*this)); }
135
+ CUDA_CALLABLE inline operator int64() const { return int64(half_to_float(*this)); }
136
+ CUDA_CALLABLE inline operator uint64() const { return uint64(half_to_float(*this)); }
137
+ };
138
+
139
+ static_assert(sizeof(half) == 2, "Size of half / float16 type must be 2-bytes");
140
+
141
+ typedef half float16;
142
+
143
+ #if defined(__CUDA_ARCH__)
144
+
145
+ CUDA_CALLABLE inline half float_to_half(float x)
146
+ {
147
+ half h;
148
+ asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(h.u) : "f"(x));
149
+ return h;
150
+ }
151
+
152
+ CUDA_CALLABLE inline float half_to_float(half x)
153
+ {
154
+ float val;
155
+ asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(x.u));
156
+ return val;
157
+ }
158
+
159
+ #elif defined(__clang__)
160
+
161
+ // _Float16 is Clang's native half-precision floating-point type
162
+ inline half float_to_half(float x)
163
+ {
164
+
165
+ _Float16 f16 = static_cast<_Float16>(x);
166
+ return *reinterpret_cast<half*>(&f16);
167
+ }
168
+
169
+ inline float half_to_float(half h)
170
+ {
171
+ _Float16 f16 = *reinterpret_cast<_Float16*>(&h);
172
+ return static_cast<float>(f16);
173
+ }
174
+
175
+ #else // Native C++ for Warp builtins outside of kernels
176
+
177
+ extern "C" WP_API uint16_t float_to_half_bits(float x);
178
+ extern "C" WP_API float half_bits_to_float(uint16_t u);
179
+
180
+ inline half float_to_half(float x)
181
+ {
182
+ half h;
183
+ h.u = float_to_half_bits(x);
184
+ return h;
185
+ }
186
+
187
+ inline float half_to_float(half h)
188
+ {
189
+ return half_bits_to_float(h.u);
190
+ }
191
+
192
+ #endif
193
+
194
+
195
+ // BAD operator implementations for fp16 arithmetic...
196
+
197
+ // negation:
198
+ inline CUDA_CALLABLE half operator - (half a)
199
+ {
200
+ return float_to_half( -half_to_float(a) );
201
+ }
202
+
203
+ inline CUDA_CALLABLE half operator + (half a,half b)
204
+ {
205
+ return float_to_half( half_to_float(a) + half_to_float(b) );
206
+ }
207
+
208
+ inline CUDA_CALLABLE half operator - (half a,half b)
209
+ {
210
+ return float_to_half( half_to_float(a) - half_to_float(b) );
211
+ }
212
+
213
+ inline CUDA_CALLABLE half operator * (half a,half b)
214
+ {
215
+ return float_to_half( half_to_float(a) * half_to_float(b) );
216
+ }
217
+
218
+ inline CUDA_CALLABLE half operator * (half a,double b)
219
+ {
220
+ return float_to_half( half_to_float(a) * b );
221
+ }
222
+
223
+ inline CUDA_CALLABLE half operator * (double a,half b)
224
+ {
225
+ return float_to_half( a * half_to_float(b) );
226
+ }
227
+
228
+ inline CUDA_CALLABLE half operator / (half a,half b)
229
+ {
230
+ return float_to_half( half_to_float(a) / half_to_float(b) );
231
+ }
232
+
233
+
234
+
235
+
236
+
237
+ template <typename T>
238
+ CUDA_CALLABLE float cast_float(T x) { return (float)(x); }
239
+
240
+ template <typename T>
241
+ CUDA_CALLABLE int cast_int(T x) { return (int)(x); }
242
+
243
+ template <typename T>
244
+ CUDA_CALLABLE void adj_cast_float(T x, T& adj_x, float adj_ret) { adj_x += T(adj_ret); }
245
+
246
+ template <typename T>
247
+ CUDA_CALLABLE void adj_cast_int(T x, T& adj_x, int adj_ret) { adj_x += adj_ret; }
248
+
249
+ template <typename T>
250
+ CUDA_CALLABLE inline void adj_int8(T, T&, int8) {}
251
+ template <typename T>
252
+ CUDA_CALLABLE inline void adj_uint8(T, T&, uint8) {}
253
+ template <typename T>
254
+ CUDA_CALLABLE inline void adj_int16(T, T&, int16) {}
255
+ template <typename T>
256
+ CUDA_CALLABLE inline void adj_uint16(T, T&, uint16) {}
257
+ template <typename T>
258
+ CUDA_CALLABLE inline void adj_int32(T, T&, int32) {}
259
+ template <typename T>
260
+ CUDA_CALLABLE inline void adj_uint32(T, T&, uint32) {}
261
+ template <typename T>
262
+ CUDA_CALLABLE inline void adj_int64(T, T&, int64) {}
263
+ template <typename T>
264
+ CUDA_CALLABLE inline void adj_uint64(T, T&, uint64) {}
265
+
266
+
267
+ template <typename T>
268
+ CUDA_CALLABLE inline void adj_float16(T x, T& adj_x, float16 adj_ret) { adj_x += T(adj_ret); }
269
+ template <typename T>
270
+ CUDA_CALLABLE inline void adj_float32(T x, T& adj_x, float32 adj_ret) { adj_x += T(adj_ret); }
271
+ template <typename T>
272
+ CUDA_CALLABLE inline void adj_float64(T x, T& adj_x, float64 adj_ret) { adj_x += T(adj_ret); }
273
+
274
+
275
+ #define kEps 0.0f
276
+
277
+ // basic ops for integer types
278
+ #define DECLARE_INT_OPS(T) \
279
+ inline CUDA_CALLABLE T mul(T a, T b) { return a*b; } \
280
+ inline CUDA_CALLABLE T div(T a, T b) { return a/b; } \
281
+ inline CUDA_CALLABLE T add(T a, T b) { return a+b; } \
282
+ inline CUDA_CALLABLE T sub(T a, T b) { return a-b; } \
283
+ inline CUDA_CALLABLE T mod(T a, T b) { return a%b; } \
284
+ inline CUDA_CALLABLE T min(T a, T b) { return a<b?a:b; } \
285
+ inline CUDA_CALLABLE T max(T a, T b) { return a>b?a:b; } \
286
+ inline CUDA_CALLABLE T clamp(T x, T a, T b) { return min(max(a, x), b); } \
287
+ inline CUDA_CALLABLE T floordiv(T a, T b) { return a/b; } \
288
+ inline CUDA_CALLABLE T nonzero(T x) { return x == T(0) ? T(0) : T(1); } \
289
+ inline CUDA_CALLABLE T sqrt(T x) { return 0; } \
290
+ inline CUDA_CALLABLE T bit_and(T a, T b) { return a&b; } \
291
+ inline CUDA_CALLABLE T bit_or(T a, T b) { return a|b; } \
292
+ inline CUDA_CALLABLE T bit_xor(T a, T b) { return a^b; } \
293
+ inline CUDA_CALLABLE T lshift(T a, T b) { return a<<b; } \
294
+ inline CUDA_CALLABLE T rshift(T a, T b) { return a>>b; } \
295
+ inline CUDA_CALLABLE T invert(T x) { return ~x; } \
296
+ inline CUDA_CALLABLE bool isfinite(T x) { return true; } \
297
+ inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
298
+ inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret) { } \
299
+ inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
300
+ inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
301
+ inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
302
+ inline CUDA_CALLABLE void adj_min(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
303
+ inline CUDA_CALLABLE void adj_max(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
304
+ inline CUDA_CALLABLE void adj_abs(T x, T adj_x, T& adj_ret) { } \
305
+ inline CUDA_CALLABLE void adj_sign(T x, T adj_x, T& adj_ret) { } \
306
+ inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b, T adj_ret) { } \
307
+ inline CUDA_CALLABLE void adj_floordiv(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
308
+ inline CUDA_CALLABLE void adj_step(T x, T& adj_x, T adj_ret) { } \
309
+ inline CUDA_CALLABLE void adj_nonzero(T x, T& adj_x, T adj_ret) { } \
310
+ inline CUDA_CALLABLE void adj_sqrt(T x, T adj_x, T& adj_ret) { } \
311
+ inline CUDA_CALLABLE void adj_bit_and(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
312
+ inline CUDA_CALLABLE void adj_bit_or(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
313
+ inline CUDA_CALLABLE void adj_bit_xor(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
314
+ inline CUDA_CALLABLE void adj_lshift(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
315
+ inline CUDA_CALLABLE void adj_rshift(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
316
+ inline CUDA_CALLABLE void adj_invert(T x, T adj_x, T& adj_ret) { }
317
+
318
+ inline CUDA_CALLABLE int8 abs(int8 x) { return ::abs(x); }
319
+ inline CUDA_CALLABLE int16 abs(int16 x) { return ::abs(x); }
320
+ inline CUDA_CALLABLE int32 abs(int32 x) { return ::abs(x); }
321
+ inline CUDA_CALLABLE int64 abs(int64 x) { return ::llabs(x); }
322
+ inline CUDA_CALLABLE uint8 abs(uint8 x) { return x; }
323
+ inline CUDA_CALLABLE uint16 abs(uint16 x) { return x; }
324
+ inline CUDA_CALLABLE uint32 abs(uint32 x) { return x; }
325
+ inline CUDA_CALLABLE uint64 abs(uint64 x) { return x; }
326
+
327
+ DECLARE_INT_OPS(int8)
328
+ DECLARE_INT_OPS(int16)
329
+ DECLARE_INT_OPS(int32)
330
+ DECLARE_INT_OPS(int64)
331
+ DECLARE_INT_OPS(uint8)
332
+ DECLARE_INT_OPS(uint16)
333
+ DECLARE_INT_OPS(uint32)
334
+ DECLARE_INT_OPS(uint64)
335
+
336
+
337
+ inline CUDA_CALLABLE int8 step(int8 x) { return x < 0 ? 1 : 0; }
338
+ inline CUDA_CALLABLE int16 step(int16 x) { return x < 0 ? 1 : 0; }
339
+ inline CUDA_CALLABLE int32 step(int32 x) { return x < 0 ? 1 : 0; }
340
+ inline CUDA_CALLABLE int64 step(int64 x) { return x < 0 ? 1 : 0; }
341
+ inline CUDA_CALLABLE uint8 step(uint8 x) { return 0; }
342
+ inline CUDA_CALLABLE uint16 step(uint16 x) { return 0; }
343
+ inline CUDA_CALLABLE uint32 step(uint32 x) { return 0; }
344
+ inline CUDA_CALLABLE uint64 step(uint64 x) { return 0; }
345
+
346
+
347
+ inline CUDA_CALLABLE int8 sign(int8 x) { return x < 0 ? -1 : 1; }
348
+ inline CUDA_CALLABLE int8 sign(int16 x) { return x < 0 ? -1 : 1; }
349
+ inline CUDA_CALLABLE int8 sign(int32 x) { return x < 0 ? -1 : 1; }
350
+ inline CUDA_CALLABLE int8 sign(int64 x) { return x < 0 ? -1 : 1; }
351
+ inline CUDA_CALLABLE uint8 sign(uint8 x) { return 1; }
352
+ inline CUDA_CALLABLE uint16 sign(uint16 x) { return 1; }
353
+ inline CUDA_CALLABLE uint32 sign(uint32 x) { return 1; }
354
+ inline CUDA_CALLABLE uint64 sign(uint64 x) { return 1; }
355
+
356
+
357
+ // Catch-all for non-float types
358
+ template<typename T>
359
+ inline bool CUDA_CALLABLE isfinite(const T&)
360
+ {
361
+ return true;
362
+ }
363
+
364
+ inline bool CUDA_CALLABLE isfinite(half x)
365
+ {
366
+ return ::isfinite(float(x));
367
+ }
368
+ inline bool CUDA_CALLABLE isfinite(float x)
369
+ {
370
+ return ::isfinite(x);
371
+ }
372
+ inline bool CUDA_CALLABLE isfinite(double x)
373
+ {
374
+ return ::isfinite(x);
375
+ }
376
+
377
+ template<typename T>
378
+ inline CUDA_CALLABLE void print(const T&)
379
+ {
380
+ printf("<type without print implementation>\n");
381
+ }
382
+
383
+ inline CUDA_CALLABLE void print(float16 f)
384
+ {
385
+ printf("%g\n", half_to_float(f));
386
+ }
387
+
388
+ inline CUDA_CALLABLE void print(float f)
389
+ {
390
+ printf("%g\n", f);
391
+ }
392
+
393
+ inline CUDA_CALLABLE void print(double f)
394
+ {
395
+ printf("%g\n", f);
396
+ }
397
+
398
+
399
+ // basic ops for float types
400
+ #define DECLARE_FLOAT_OPS(T) \
401
+ inline CUDA_CALLABLE T mul(T a, T b) { return a*b; } \
402
+ inline CUDA_CALLABLE T add(T a, T b) { return a+b; } \
403
+ inline CUDA_CALLABLE T sub(T a, T b) { return a-b; } \
404
+ inline CUDA_CALLABLE T min(T a, T b) { return a<b?a:b; } \
405
+ inline CUDA_CALLABLE T max(T a, T b) { return a>b?a:b; } \
406
+ inline CUDA_CALLABLE T sign(T x) { return x < T(0) ? -1 : 1; } \
407
+ inline CUDA_CALLABLE T step(T x) { return x < T(0) ? T(1) : T(0); }\
408
+ inline CUDA_CALLABLE T nonzero(T x) { return x == T(0) ? T(0) : T(1); }\
409
+ inline CUDA_CALLABLE T clamp(T x, T a, T b) { return min(max(a, x), b); }\
410
+ inline CUDA_CALLABLE void adj_abs(T x, T& adj_x, T adj_ret) \
411
+ {\
412
+ if (x < T(0))\
413
+ adj_x -= adj_ret;\
414
+ else\
415
+ adj_x += adj_ret;\
416
+ }\
417
+ inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += b*adj_ret; adj_b += a*adj_ret; } \
418
+ inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += adj_ret; adj_b += adj_ret; } \
419
+ inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += adj_ret; adj_b -= adj_ret; } \
420
+ inline CUDA_CALLABLE void adj_min(T a, T b, T& adj_a, T& adj_b, T adj_ret) \
421
+ { \
422
+ if (a < b) \
423
+ adj_a += adj_ret; \
424
+ else \
425
+ adj_b += adj_ret; \
426
+ } \
427
+ inline CUDA_CALLABLE void adj_max(T a, T b, T& adj_a, T& adj_b, T adj_ret) \
428
+ { \
429
+ if (a > b) \
430
+ adj_a += adj_ret; \
431
+ else \
432
+ adj_b += adj_ret; \
433
+ } \
434
+ inline CUDA_CALLABLE void adj_floordiv(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
435
+ inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret){ adj_a += adj_ret; }\
436
+ inline CUDA_CALLABLE void adj_sign(T x, T adj_x, T& adj_ret) { }\
437
+ inline CUDA_CALLABLE void adj_step(T x, T& adj_x, T adj_ret) { }\
438
+ inline CUDA_CALLABLE void adj_nonzero(T x, T& adj_x, T adj_ret) { }\
439
+ inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b, T adj_ret)\
440
+ {\
441
+ if (x < a)\
442
+ adj_a += adj_ret;\
443
+ else if (x > b)\
444
+ adj_b += adj_ret;\
445
+ else\
446
+ adj_x += adj_ret;\
447
+ }\
448
+ inline CUDA_CALLABLE T div(T a, T b)\
449
+ {\
450
+ DO_IF_FPCHECK(\
451
+ if (!isfinite(a) || !isfinite(b) || b == T(0))\
452
+ {\
453
+ printf("%s:%d div(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));\
454
+ assert(0);\
455
+ })\
456
+ return a/b;\
457
+ }\
458
+ inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
459
+ {\
460
+ adj_a += adj_ret/b;\
461
+ adj_b -= adj_ret*(ret)/b;\
462
+ DO_IF_FPCHECK(\
463
+ if (!isfinite(adj_a) || !isfinite(adj_b))\
464
+ {\
465
+ printf("%s:%d - adj_div(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
466
+ assert(0);\
467
+ })\
468
+ }\
469
+
470
+ DECLARE_FLOAT_OPS(float16)
471
+ DECLARE_FLOAT_OPS(float32)
472
+ DECLARE_FLOAT_OPS(float64)
473
+
474
+
475
+
476
+ // basic ops for float types
477
+ inline CUDA_CALLABLE float16 mod(float16 a, float16 b)
478
+ {
479
+ #if FP_CHECK
480
+ if (!isfinite(a) || !isfinite(b) || float(b) == 0.0f)
481
+ {
482
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));
483
+ assert(0);
484
+ }
485
+ #endif
486
+ return fmodf(float(a), float(b));
487
+ }
488
+
489
+ inline CUDA_CALLABLE float32 mod(float32 a, float32 b)
490
+ {
491
+ #if FP_CHECK
492
+ if (!isfinite(a) || !isfinite(b) || b == 0.0f)
493
+ {
494
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
495
+ assert(0);
496
+ }
497
+ #endif
498
+ return fmodf(a, b);
499
+ }
500
+
501
+ inline CUDA_CALLABLE double mod(double a, double b)
502
+ {
503
+ #if FP_CHECK
504
+ if (!isfinite(a) || !isfinite(b) || b == 0.0f)
505
+ {
506
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
507
+ assert(0);
508
+ }
509
+ #endif
510
+ return fmod(a, b);
511
+ }
512
+
513
+ inline CUDA_CALLABLE half log(half a)
514
+ {
515
+ #if FP_CHECK
516
+ if (!isfinite(a) || float(a) < 0.0f)
517
+ {
518
+ printf("%s:%d log(%f)\n", __FILE__, __LINE__, float(a));
519
+ assert(0);
520
+ }
521
+ #endif
522
+ return ::logf(a);
523
+ }
524
+
525
+ inline CUDA_CALLABLE float log(float a)
526
+ {
527
+ #if FP_CHECK
528
+ if (!isfinite(a) || a < 0.0f)
529
+ {
530
+ printf("%s:%d log(%f)\n", __FILE__, __LINE__, a);
531
+ assert(0);
532
+ }
533
+ #endif
534
+ return ::logf(a);
535
+ }
536
+
537
+ inline CUDA_CALLABLE double log(double a)
538
+ {
539
+ #if FP_CHECK
540
+ if (!isfinite(a) || a < 0.0)
541
+ {
542
+ printf("%s:%d log(%f)\n", __FILE__, __LINE__, a);
543
+ assert(0);
544
+ }
545
+ #endif
546
+ return ::log(a);
547
+ }
548
+
549
+ inline CUDA_CALLABLE half log2(half a)
550
+ {
551
+ #if FP_CHECK
552
+ if (!isfinite(a) || float(a) < 0.0f)
553
+ {
554
+ printf("%s:%d log2(%f)\n", __FILE__, __LINE__, float(a));
555
+ assert(0);
556
+ }
557
+ #endif
558
+
559
+ return ::log2f(float(a));
560
+ }
561
+
562
+ inline CUDA_CALLABLE float log2(float a)
563
+ {
564
+ #if FP_CHECK
565
+ if (!isfinite(a) || a < 0.0f)
566
+ {
567
+ printf("%s:%d log2(%f)\n", __FILE__, __LINE__, a);
568
+ assert(0);
569
+ }
570
+ #endif
571
+
572
+ return ::log2f(a);
573
+ }
574
+
575
+ inline CUDA_CALLABLE double log2(double a)
576
+ {
577
+ #if FP_CHECK
578
+ if (!isfinite(a) || a < 0.0)
579
+ {
580
+ printf("%s:%d log2(%f)\n", __FILE__, __LINE__, a);
581
+ assert(0);
582
+ }
583
+ #endif
584
+
585
+ return ::log2(a);
586
+ }
587
+
588
+ inline CUDA_CALLABLE half log10(half a)
589
+ {
590
+ #if FP_CHECK
591
+ if (!isfinite(a) || float(a) < 0.0f)
592
+ {
593
+ printf("%s:%d log10(%f)\n", __FILE__, __LINE__, float(a));
594
+ assert(0);
595
+ }
596
+ #endif
597
+
598
+ return ::log10f(float(a));
599
+ }
600
+
601
+ inline CUDA_CALLABLE float log10(float a)
602
+ {
603
+ #if FP_CHECK
604
+ if (!isfinite(a) || a < 0.0f)
605
+ {
606
+ printf("%s:%d log10(%f)\n", __FILE__, __LINE__, a);
607
+ assert(0);
608
+ }
609
+ #endif
610
+
611
+ return ::log10f(a);
612
+ }
613
+
614
+ inline CUDA_CALLABLE double log10(double a)
615
+ {
616
+ #if FP_CHECK
617
+ if (!isfinite(a) || a < 0.0)
618
+ {
619
+ printf("%s:%d log10(%f)\n", __FILE__, __LINE__, a);
620
+ assert(0);
621
+ }
622
+ #endif
623
+
624
+ return ::log10(a);
625
+ }
626
+
627
+ inline CUDA_CALLABLE half exp(half a)
628
+ {
629
+ half result = ::expf(float(a));
630
+ #if FP_CHECK
631
+ if (!isfinite(a) || !isfinite(result))
632
+ {
633
+ printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, float(a), float(result));
634
+ assert(0);
635
+ }
636
+ #endif
637
+ return result;
638
+ }
639
+ inline CUDA_CALLABLE float exp(float a)
640
+ {
641
+ float result = ::expf(a);
642
+ #if FP_CHECK
643
+ if (!isfinite(a) || !isfinite(result))
644
+ {
645
+ printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, a, result);
646
+ assert(0);
647
+ }
648
+ #endif
649
+ return result;
650
+ }
651
+ inline CUDA_CALLABLE double exp(double a)
652
+ {
653
+ double result = ::exp(a);
654
+ #if FP_CHECK
655
+ if (!isfinite(a) || !isfinite(result))
656
+ {
657
+ printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, a, result);
658
+ assert(0);
659
+ }
660
+ #endif
661
+ return result;
662
+ }
663
+
664
+ inline CUDA_CALLABLE half pow(half a, half b)
665
+ {
666
+ float result = ::powf(float(a), float(b));
667
+ #if FP_CHECK
668
+ if (!isfinite(float(a)) || !isfinite(float(b)) || !isfinite(result))
669
+ {
670
+ printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, float(a), float(b), result);
671
+ assert(0);
672
+ }
673
+ #endif
674
+ return result;
675
+ }
676
+
677
+ inline CUDA_CALLABLE float pow(float a, float b)
678
+ {
679
+ float result = ::powf(a, b);
680
+ #if FP_CHECK
681
+ if (!isfinite(a) || !isfinite(b) || !isfinite(result))
682
+ {
683
+ printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, a, b, result);
684
+ assert(0);
685
+ }
686
+ #endif
687
+ return result;
688
+ }
689
+
690
+ inline CUDA_CALLABLE double pow(double a, double b)
691
+ {
692
+ double result = ::pow(a, b);
693
+ #if FP_CHECK
694
+ if (!isfinite(a) || !isfinite(b) || !isfinite(result))
695
+ {
696
+ printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, a, b, result);
697
+ assert(0);
698
+ }
699
+ #endif
700
+ return result;
701
+ }
702
+
703
+ inline CUDA_CALLABLE half floordiv(half a, half b)
704
+ {
705
+ #if FP_CHECK
706
+ if (!isfinite(a) || !isfinite(b) || float(b) == 0.0f)
707
+ {
708
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));
709
+ assert(0);
710
+ }
711
+ #endif
712
+ return floorf(float(a/b));
713
+ }
714
+ inline CUDA_CALLABLE float floordiv(float a, float b)
715
+ {
716
+ #if FP_CHECK
717
+ if (!isfinite(a) || !isfinite(b) || b == 0.0f)
718
+ {
719
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
720
+ assert(0);
721
+ }
722
+ #endif
723
+ return floorf(a/b);
724
+ }
725
+ inline CUDA_CALLABLE double floordiv(double a, double b)
726
+ {
727
+ #if FP_CHECK
728
+ if (!isfinite(a) || !isfinite(b) || b == 0.0)
729
+ {
730
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
731
+ assert(0);
732
+ }
733
+ #endif
734
+ return ::floor(a/b);
735
+ }
736
+
737
// "Leaky" min/max: the forward pass is identical to plain min/max and the leak
// factor r is unused here — it only affects the adjoints (adj_leaky_min /
// adj_leaky_max, generated by DECLARE_ADJOINTS below), where the non-selected
// branch still receives r-scaled gradient.
inline CUDA_CALLABLE float leaky_min(float a, float b, float r) { return min(a, b); }
inline CUDA_CALLABLE float leaky_max(float a, float b, float r) { return max(a, b); }

// Absolute value overloads; half is computed in float precision via fabsf.
inline CUDA_CALLABLE half abs(half x) { return ::fabsf(float(x)); }
inline CUDA_CALLABLE float abs(float x) { return ::fabsf(x); }
inline CUDA_CALLABLE double abs(double x) { return ::fabs(x); }

// Inverse trig: acos/asin clamp the argument to [-1, 1] so inputs that drift
// slightly out of the domain due to rounding do not produce NaN.
inline CUDA_CALLABLE float acos(float x){ return ::acosf(min(max(x, -1.0f), 1.0f)); }
inline CUDA_CALLABLE float asin(float x){ return ::asinf(min(max(x, -1.0f), 1.0f)); }
inline CUDA_CALLABLE float atan(float x) { return ::atanf(x); }
inline CUDA_CALLABLE float atan2(float y, float x) { return ::atan2f(y, x); }
inline CUDA_CALLABLE float sin(float x) { return ::sinf(x); }
inline CUDA_CALLABLE float cos(float x) { return ::cosf(x); }

// Double-precision versions, same clamping policy for acos/asin.
inline CUDA_CALLABLE double acos(double x){ return ::acos(min(max(x, -1.0), 1.0)); }
inline CUDA_CALLABLE double asin(double x){ return ::asin(min(max(x, -1.0), 1.0)); }
inline CUDA_CALLABLE double atan(double x) { return ::atan(x); }
inline CUDA_CALLABLE double atan2(double y, double x) { return ::atan2(y, x); }
inline CUDA_CALLABLE double sin(double x) { return ::sin(x); }
inline CUDA_CALLABLE double cos(double x) { return ::cos(x); }

// Half versions compute through float and rely on implicit half<->float
// conversion for the return value.
inline CUDA_CALLABLE half acos(half x){ return ::acosf(min(max(float(x), -1.0f), 1.0f)); }
inline CUDA_CALLABLE half asin(half x){ return ::asinf(min(max(float(x), -1.0f), 1.0f)); }
inline CUDA_CALLABLE half atan(half x) { return ::atanf(float(x)); }
inline CUDA_CALLABLE half atan2(half y, half x) { return ::atan2f(float(y), float(x)); }
inline CUDA_CALLABLE half sin(half x) { return ::sinf(float(x)); }
inline CUDA_CALLABLE half cos(half x) { return ::cosf(float(x)); }
764
+
765
+
766
// sqrt overloads with optional FP_CHECK diagnostics: when FP_CHECK is enabled,
// a negative argument is reported (file:line plus the value) and trapped;
// otherwise it falls through to the CRT sqrt, which yields NaN for negative
// inputs on IEEE systems.
inline CUDA_CALLABLE float sqrt(float x)
{
#if FP_CHECK
    if (x < 0.0f)
    {
        printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, x);
        assert(0);
    }
#endif
    return ::sqrtf(x);
}
inline CUDA_CALLABLE double sqrt(double x)
{
#if FP_CHECK
    if (x < 0.0)
    {
        printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, x);
        assert(0);
    }
#endif
    return ::sqrt(x);
}
// Half version: checks and computes in float precision.
inline CUDA_CALLABLE half sqrt(half x)
{
#if FP_CHECK
    if (float(x) < 0.0f)
    {
        printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, float(x));
        assert(0);
    }
#endif
    return ::sqrtf(float(x));
}

// Cube root — no FP_CHECK needed since cbrt is defined for negative inputs.
inline CUDA_CALLABLE float cbrt(float x) { return ::cbrtf(x); }
inline CUDA_CALLABLE double cbrt(double x) { return ::cbrt(x); }
inline CUDA_CALLABLE half cbrt(half x) { return ::cbrtf(float(x)); }
803
+
804
// Remaining transcendental / angle-conversion overloads. degrees()/radians()
// scale by the RAD_TO_DEG / DEG_TO_RAD constants defined elsewhere in this header.
inline CUDA_CALLABLE float tan(float x) { return ::tanf(x); }
inline CUDA_CALLABLE float sinh(float x) { return ::sinhf(x);}
inline CUDA_CALLABLE float cosh(float x) { return ::coshf(x);}
inline CUDA_CALLABLE float tanh(float x) { return ::tanhf(x);}
inline CUDA_CALLABLE float degrees(float x) { return x * RAD_TO_DEG;}
inline CUDA_CALLABLE float radians(float x) { return x * DEG_TO_RAD;}

inline CUDA_CALLABLE double tan(double x) { return ::tan(x); }
inline CUDA_CALLABLE double sinh(double x) { return ::sinh(x);}
inline CUDA_CALLABLE double cosh(double x) { return ::cosh(x);}
inline CUDA_CALLABLE double tanh(double x) { return ::tanh(x);}
inline CUDA_CALLABLE double degrees(double x) { return x * RAD_TO_DEG;}
inline CUDA_CALLABLE double radians(double x) { return x * DEG_TO_RAD;}

// Half overloads route through float for the math-library calls.
inline CUDA_CALLABLE half tan(half x) { return ::tanf(float(x)); }
inline CUDA_CALLABLE half sinh(half x) { return ::sinhf(float(x));}
inline CUDA_CALLABLE half cosh(half x) { return ::coshf(float(x));}
inline CUDA_CALLABLE half tanh(half x) { return ::tanhf(float(x));}
inline CUDA_CALLABLE half degrees(half x) { return x * RAD_TO_DEG;}
inline CUDA_CALLABLE half radians(half x) { return x * DEG_TO_RAD;}

// Rounding family. Note the distinct semantics: round() rounds halfway cases
// away from zero, rint() rounds using the current rounding mode, trunc() drops
// the fraction toward zero, and frac() returns the signed fractional part
// x - trunc(x).
inline CUDA_CALLABLE float round(float x) { return ::roundf(x); }
inline CUDA_CALLABLE float rint(float x) { return ::rintf(x); }
inline CUDA_CALLABLE float trunc(float x) { return ::truncf(x); }
inline CUDA_CALLABLE float floor(float x) { return ::floorf(x); }
inline CUDA_CALLABLE float ceil(float x) { return ::ceilf(x); }
inline CUDA_CALLABLE float frac(float x) { return x - trunc(x); }

inline CUDA_CALLABLE double round(double x) { return ::round(x); }
inline CUDA_CALLABLE double rint(double x) { return ::rint(x); }
inline CUDA_CALLABLE double trunc(double x) { return ::trunc(x); }
inline CUDA_CALLABLE double floor(double x) { return ::floor(x); }
inline CUDA_CALLABLE double ceil(double x) { return ::ceil(x); }
inline CUDA_CALLABLE double frac(double x) { return x - trunc(x); }

inline CUDA_CALLABLE half round(half x) { return ::roundf(float(x)); }
inline CUDA_CALLABLE half rint(half x) { return ::rintf(float(x)); }
inline CUDA_CALLABLE half trunc(half x) { return ::truncf(float(x)); }
inline CUDA_CALLABLE half floor(half x) { return ::floorf(float(x)); }
inline CUDA_CALLABLE half ceil(half x) { return ::ceilf(float(x)); }
inline CUDA_CALLABLE half frac(half x) { return float(x) - trunc(float(x)); }
845
+
846
// DECLARE_ADJOINTS(T) stamps out the backward-mode (adjoint) counterparts of
// the scalar math functions above for one floating-point type T.
// Conventions used throughout:
//  - every adj_* ACCUMULATES (+=/-=) into the adjoint of its inputs; it never
//    overwrites, so multiple uses of a value sum their gradients.
//  - adjoints of ops whose derivative is cheaply expressed via the forward
//    result (exp, pow, tanh, sqrt, cbrt) take that result as `ret`.
//  - DO_IF_FPCHECK(...) expands its argument only in FP_CHECK builds and traps
//    non-finite adjoints / singular points with a printf + assert;
//    DO_IF_NO_FPCHECK(...) is the release-build alternative that instead
//    guards the update (e.g. skips it when the derivative denominator is 0).
//  - adj_leaky_min/adj_leaky_max leak gradient r into the non-selected input.
//  - round/rint/trunc/floor/ceil/frac are piecewise-constant, so their
//    adjoints are intentionally empty.
// NOTE: no comments may appear inside the macro body — a '\' at the end of a
// '//' comment line would splice the next code line into the comment.
#define DECLARE_ADJOINTS(T)\
inline CUDA_CALLABLE void adj_log(T a, T& adj_a, T adj_ret)\
{\
    adj_a += (T(1)/a)*adj_ret;\
    DO_IF_FPCHECK(if (!isfinite(adj_a))\
    {\
        printf("%s:%d - adj_log(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
        assert(0);\
    })\
}\
inline CUDA_CALLABLE void adj_log2(T a, T& adj_a, T adj_ret)\
{ \
    adj_a += (T(1)/a)*(T(1)/log(T(2)))*adj_ret; \
    DO_IF_FPCHECK(if (!isfinite(adj_a))\
    {\
        printf("%s:%d - adj_log2(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
        assert(0);\
    }) \
}\
inline CUDA_CALLABLE void adj_log10(T a, T& adj_a, T adj_ret)\
{\
    adj_a += (T(1)/a)*(T(1)/log(T(10)))*adj_ret; \
    DO_IF_FPCHECK(if (!isfinite(adj_a))\
    {\
        printf("%s:%d - adj_log10(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
        assert(0);\
    })\
}\
inline CUDA_CALLABLE void adj_exp(T a, T ret, T& adj_a, T adj_ret) { adj_a += ret*adj_ret; }\
inline CUDA_CALLABLE void adj_pow(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
{ \
    adj_a += b*pow(a, b-T(1))*adj_ret;\
    adj_b += log(a)*ret*adj_ret;\
    DO_IF_FPCHECK(if (!isfinite(adj_a) || !isfinite(adj_b))\
    {\
        printf("%s:%d - adj_pow(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
        assert(0);\
    })\
}\
inline CUDA_CALLABLE void adj_leaky_min(T a, T b, T r, T& adj_a, T& adj_b, T& adj_r, T adj_ret)\
{\
    if (a < b)\
        adj_a += adj_ret;\
    else\
    {\
        adj_a += r*adj_ret;\
        adj_b += adj_ret;\
    }\
}\
inline CUDA_CALLABLE void adj_leaky_max(T a, T b, T r, T& adj_a, T& adj_b, T& adj_r, T adj_ret)\
{\
    if (a > b)\
        adj_a += adj_ret;\
    else\
    {\
        adj_a += r*adj_ret;\
        adj_b += adj_ret;\
    }\
}\
inline CUDA_CALLABLE void adj_acos(T x, T& adj_x, T adj_ret)\
{\
    T d = sqrt(T(1)-x*x);\
    DO_IF_FPCHECK(adj_x -= (T(1)/d)*adj_ret;\
    if (!isfinite(d) || !isfinite(adj_x))\
    {\
        printf("%s:%d - adj_acos(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret)); \
        assert(0);\
    })\
    DO_IF_NO_FPCHECK(if (d > T(0))\
        adj_x -= (T(1)/d)*adj_ret;)\
}\
inline CUDA_CALLABLE void adj_asin(T x, T& adj_x, T adj_ret)\
{\
    T d = sqrt(T(1)-x*x);\
    DO_IF_FPCHECK(adj_x += (T(1)/d)*adj_ret;\
    if (!isfinite(d) || !isfinite(adj_x))\
    {\
        printf("%s:%d - adj_asin(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret)); \
        assert(0);\
    })\
    DO_IF_NO_FPCHECK(if (d > T(0))\
        adj_x += (T(1)/d)*adj_ret;)\
}\
inline CUDA_CALLABLE void adj_tan(T x, T& adj_x, T adj_ret)\
{\
    T cos_x = cos(x);\
    DO_IF_FPCHECK(adj_x += (T(1)/(cos_x*cos_x))*adj_ret;\
    if (!isfinite(adj_x) || cos_x == T(0))\
    {\
        printf("%s:%d - adj_tan(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
        assert(0);\
    })\
    DO_IF_NO_FPCHECK(if (cos_x != T(0))\
        adj_x += (T(1)/(cos_x*cos_x))*adj_ret;)\
}\
inline CUDA_CALLABLE void adj_atan(T x, T& adj_x, T adj_ret)\
{\
    adj_x += adj_ret /(x*x + T(1));\
}\
inline CUDA_CALLABLE void adj_atan2(T y, T x, T& adj_y, T& adj_x, T adj_ret)\
{\
    T d = x*x + y*y;\
    DO_IF_FPCHECK(adj_x -= y/d*adj_ret;\
    adj_y += x/d*adj_ret;\
    if (!isfinite(adj_x) || !isfinite(adj_y) || d == T(0))\
    {\
        printf("%s:%d - adj_atan2(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(y), float(x), float(adj_y), float(adj_x), float(adj_ret));\
        assert(0);\
    })\
    DO_IF_NO_FPCHECK(if (d > T(0))\
    {\
        adj_x -= (y/d)*adj_ret;\
        adj_y += (x/d)*adj_ret;\
    })\
}\
inline CUDA_CALLABLE void adj_sin(T x, T& adj_x, T adj_ret)\
{\
    adj_x += cos(x)*adj_ret;\
}\
inline CUDA_CALLABLE void adj_cos(T x, T& adj_x, T adj_ret)\
{\
    adj_x -= sin(x)*adj_ret;\
}\
inline CUDA_CALLABLE void adj_sinh(T x, T& adj_x, T adj_ret)\
{\
    adj_x += cosh(x)*adj_ret;\
}\
inline CUDA_CALLABLE void adj_cosh(T x, T& adj_x, T adj_ret)\
{\
    adj_x += sinh(x)*adj_ret;\
}\
inline CUDA_CALLABLE void adj_tanh(T x, T ret, T& adj_x, T adj_ret)\
{\
    adj_x += (T(1) - ret*ret)*adj_ret;\
}\
inline CUDA_CALLABLE void adj_sqrt(T x, T ret, T& adj_x, T adj_ret)\
{\
    adj_x += T(0.5)*(T(1)/ret)*adj_ret;\
    DO_IF_FPCHECK(if (!isfinite(adj_x))\
    {\
        printf("%s:%d - adj_sqrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
        assert(0);\
    })\
}\
inline CUDA_CALLABLE void adj_cbrt(T x, T ret, T& adj_x, T adj_ret)\
{\
    adj_x += (T(1)/T(3))*(T(1)/(ret*ret))*adj_ret;\
    DO_IF_FPCHECK(if (!isfinite(adj_x))\
    {\
        printf("%s:%d - adj_cbrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
        assert(0);\
    })\
}\
inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
{\
    adj_x += RAD_TO_DEG * adj_ret;\
}\
inline CUDA_CALLABLE void adj_radians(T x, T& adj_x, T adj_ret)\
{\
    adj_x += DEG_TO_RAD * adj_ret;\
}\
inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
inline CUDA_CALLABLE void adj_frac(T x, T& adj_x, T adj_ret){ }

// Instantiate the adjoints for every supported floating-point width.
DECLARE_ADJOINTS(float16)
DECLARE_ADJOINTS(float32)
DECLARE_ADJOINTS(float64)
1017
+
1018
// Conditional select. NB argument order: returns `b` when cond is truthy and
// `a` otherwise — i.e. select(cond, if_false, if_true).
template <typename C, typename T>
CUDA_CALLABLE inline T select(const C& cond, const T& a, const T& b)
{
    // The double NOT operator !! casts to bool without compiler warnings.
    return (!!cond) ? b : a;
}

// Adjoint of select: gradient flows only to the branch that was chosen in the
// forward pass; the condition itself receives no gradient.
template <typename C, typename T>
CUDA_CALLABLE inline void adj_select(const C& cond, const T& a, const T& b, C& adj_cond, T& adj_a, T& adj_b, const T& adj_ret)
{
    // The double NOT operator !! casts to bool without compiler warnings.
    if (!!cond)
        adj_b += adj_ret;
    else
        adj_a += adj_ret;
}

// Identity copy (value semantics).
template <typename T>
CUDA_CALLABLE inline T copy(const T& src)
{
    return src;
}

// Adjoint of copy: move the destination's adjoint onto the source and clear it,
// so gradient is not double-counted on subsequent reverse steps.
template <typename T>
CUDA_CALLABLE inline void adj_copy(const T& src, T& adj_src, T& adj_dest)
{
    adj_src += adj_dest;
    adj_dest = T{};
}

// In-place assignment helper used by generated kernel code.
template <typename T>
CUDA_CALLABLE inline void assign(T& dest, const T& src)
{
    dest = src;
}

template <typename T>
CUDA_CALLABLE inline void adj_assign(T& dest, const T& src, T& adj_dest, T& adj_src)
{
    // this is generally a non-differentiable operation since it violates SSA,
    // except in read-modify-write statements which are reversible through backpropagation
    adj_src = adj_dest;
    adj_dest = T{};
}
1062
+
1063
+
1064
// some helpful operator overloads (just for C++ use, these are not adjointed)
// These forward to the add()/sub() free functions so any type that defines
// those automatically gains the operators.

template <typename T>
CUDA_CALLABLE inline T& operator += (T& a, const T& b) { a = add(a, b); return a; }

template <typename T>
CUDA_CALLABLE inline T& operator -= (T& a, const T& b) { a = sub(a, b); return a; }

template <typename T>
CUDA_CALLABLE inline T operator+(const T& a, const T& b) { return add(a, b); }

template <typename T>
CUDA_CALLABLE inline T operator-(const T& a, const T& b) { return sub(a, b); }

// Unary plus: identity; its adjoint just passes the gradient through.
template <typename T>
CUDA_CALLABLE inline T pos(const T& x) { return x; }
template <typename T>
CUDA_CALLABLE inline void adj_pos(const T& x, T& adj_x, const T& adj_ret) { adj_x += T(adj_ret); }

// unary negation implemented as negative multiply, not sure the fp implications of this
// may be better as 0.0 - x?
template <typename T>
CUDA_CALLABLE inline T neg(const T& x) { return T(0.0) - x; }
template <typename T>
CUDA_CALLABLE inline void adj_neg(const T& x, T& adj_x, const T& adj_ret) { adj_x += T(-adj_ret); }

// unary boolean negation; non-differentiable so the adjoint is a no-op
template <typename T>
CUDA_CALLABLE inline bool unot(const T& b) { return !b; }
template <typename T>
CUDA_CALLABLE inline void adj_unot(const T& b, T& adj_b, const bool& adj_ret) { }
1095
+
1096
const int LAUNCH_MAX_DIMS = 4; // should match types.py

// Describes the logical launch grid of a kernel; filled in by the host side.
struct launch_bounds_t
{
    int shape[LAUNCH_MAX_DIMS]; // size of each dimension
    int ndim; // number of valid dimension
    size_t size; // total number of threads
};

// CPU fallback for the current thread index: a file-local global that the host
// execution loop advances between kernel invocations (not used when compiling
// with nvcc).
#ifndef __CUDACC__
static size_t s_threadIdx;
#endif

// Returns this thread's global linear index: computed from the CUDA built-ins
// on device, or read from s_threadIdx on the host.
inline CUDA_CALLABLE size_t grid_index()
{
#ifdef __CUDACC__
    // Need to cast at least one of the variables being multiplied so that type promotion happens before the multiplication
    size_t grid_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
    return grid_index;
#else
    return s_threadIdx;
#endif
}
1119
+
1120
// 1-D thread id: narrows the 64-bit linear index to int (the public wp.tid()
// type), warning when the value exceeds INT_MAX and would be truncated.
inline CUDA_CALLABLE int tid(size_t index)
{
    // For the 1-D tid() we need to warn the user if we're about to provide a truncated index
    // Only do this in _DEBUG when called from device to avoid excessive register allocation
#if defined(_DEBUG) || !defined(__CUDA_ARCH__)
    if (index > 2147483647) {
        printf("Warp warning: tid() is returning an overflowed int\n");
    }
#endif
    return static_cast<int>(index);
}

// 2-D thread id: row-major decomposition of the linear index using the inner
// dimension size launch_bounds.shape[1] (shape[0] is implied by the quotient).
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& launch_bounds)
{
    const size_t n = launch_bounds.shape[1];

    // convert to work item
    i = index/n;
    j = index%n;
}

// 3-D thread id: same row-major scheme over shape[1] x shape[2].
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& launch_bounds)
{
    const size_t n = launch_bounds.shape[1];
    const size_t o = launch_bounds.shape[2];

    // convert to work item
    i = index/(n*o);
    j = index%(n*o)/o;
    k = index%o;
}

// 4-D thread id: row-major over shape[1] x shape[2] x shape[3].
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& launch_bounds)
{
    const size_t n = launch_bounds.shape[1];
    const size_t o = launch_bounds.shape[2];
    const size_t p = launch_bounds.shape[3];

    // convert to work item
    i = index/(n*o*p);
    j = index%(n*o*p)/(o*p);
    k = index%(o*p)/p;
    l = index%p;
}
1164
+
1165
// Atomic add returning the previous value. On device this maps to CUDA's
// atomicAdd; on host it is a plain read-modify-write.
// NOTE(review): the host path is not atomic — presumably host kernels run
// single-threaded per task; verify before relying on concurrent host writes.
template<typename T>
inline CUDA_CALLABLE T atomic_add(T* buf, T value)
{
#if !defined(__CUDA_ARCH__)
    T old = buf[0];
    buf[0] += value;
    return old;
#else
    return atomicAdd(buf, value);
#endif
}

// float16 specialization. Three paths:
//  - host: plain read-modify-write as above;
//  - CUDA via Clang: reinterpret as __half and use the intrinsic atomicAdd;
//  - CUDA via NVRTC: hand-written PTX, since the __half overload is unavailable.
// NOTE(review): in the NVRTC path the asm is only emitted for
// __CUDA_ARCH__ >= 700, so on older architectures r stays 0 and NO add is
// performed — confirm this silent no-op is intended for pre-Volta targets.
template<>
inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
{
#if !defined(__CUDA_ARCH__)
    float16 old = buf[0];
    buf[0] += value;
    return old;
#elif defined(__clang__) // CUDA compiled by Clang
    __half r = atomicAdd(reinterpret_cast<__half*>(buf), *reinterpret_cast<__half*>(&value));
    return *reinterpret_cast<float16*>(&r);
#else // CUDA compiled by NVRTC
    //return atomicAdd(buf, value);

    /* Define __PTR for atomicAdd prototypes below, undef after done */
#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
#define __PTR "l"
#else
#define __PTR "r"
#endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/

    half r = 0.0;

#if __CUDA_ARCH__ >= 700

    // atom.add.noftz.f16 operates on the raw 16-bit payload (the .u member).
    asm volatile ("{ atom.add.noftz.f16 %0,[%1],%2; }\n"
                  : "=h"(r.u)
                  : __PTR(buf), "h"(value.u)
                  : "memory");
#endif

    return r;

#undef __PTR

#endif // CUDA compiled by NVRTC

}
1214
+
1215
// emulate atomic float max
// Device path: classic atomicCAS loop on the float's bit pattern — retry while
// val still beats the currently stored value. Returns the previous value.
// NOTE(review): the comparison-based loop does not handle NaN or -0.0/+0.0
// bit-pattern subtleties explicitly — verify against callers' expectations.
inline CUDA_CALLABLE float atomic_max(float* address, float val)
{
#if defined(__CUDA_ARCH__)
    int *address_as_int = (int*)address;
    int old = *address_as_int, assumed;

    while (val > __int_as_float(old))
    {
        assumed = old;
        old = atomicCAS(address_as_int, assumed,
                        __float_as_int(val));
    }

    return __int_as_float(old);

#else
    // Host fallback: plain read-modify-write (not atomic).
    float old = *address;
    *address = max(old, val);
    return old;
#endif
}

// emulate atomic float min/max with atomicCAS()
inline CUDA_CALLABLE float atomic_min(float* address, float val)
{
#if defined(__CUDA_ARCH__)
    int *address_as_int = (int*)address;
    int old = *address_as_int, assumed;

    while (val < __int_as_float(old))
    {
        assumed = old;
        old = atomicCAS(address_as_int, assumed,
                        __float_as_int(val));
    }

    return __int_as_float(old);

#else
    float old = *address;
    *address = min(old, val);
    return old;
#endif
}

// Integer variants map directly to the native atomics on device.
inline CUDA_CALLABLE int atomic_max(int* address, int val)
{
#if defined(__CUDA_ARCH__)
    return atomicMax(address, val);

#else
    int old = *address;
    *address = max(old, val);
    return old;
#endif
}

// atomic int min
inline CUDA_CALLABLE int atomic_min(int* address, int val)
{
#if defined(__CUDA_ARCH__)
    return atomicMin(address, val);

#else
    int old = *address;
    *address = min(old, val);
    return old;
#endif
}
1285
+
1286
// default behavior for adjoint of atomic min/max operation that accumulates gradients for all elements matching the min/max value
template <typename T>
CUDA_CALLABLE inline void adj_atomic_minmax(T *addr, T *adj_addr, const T &value, T &adj_value)
{
    // Only the contribution(s) equal to the stored extremum receive gradient.
    if (value == *addr)
        adj_value += *adj_addr;
}

// for integral types we do not accumulate gradients
CUDA_CALLABLE inline void adj_atomic_minmax(int8* buf, int8* adj_buf, const int8 &value, int8 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(uint8* buf, uint8* adj_buf, const uint8 &value, uint8 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(int16* buf, int16* adj_buf, const int16 &value, int16 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(uint16* buf, uint16* adj_buf, const uint16 &value, uint16 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(int32* buf, int32* adj_buf, const int32 &value, int32 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(uint32* buf, uint32* adj_buf, const uint32 &value, uint32 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(int64* buf, int64* adj_buf, const int64 &value, int64 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(uint64* buf, uint64* adj_buf, const uint64 &value, uint64 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(bool* buf, bool* adj_buf, const bool &value, bool &adj_value) { }
1304
+
1305
+
1306
+ } // namespace wp
1307
+
1308
+
1309
// bool and printf are defined outside of the wp namespace in crt.h, hence
// their adjoint counterparts are also defined in the global namespace.
// Both are no-ops: a bool cast and a printf carry no gradient.
template <typename T>
CUDA_CALLABLE inline void adj_bool(T, T&, bool) {}
inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
1314
+
1315
+
1316
+ #include "vec.h"
1317
+ #include "mat.h"
1318
+ #include "quat.h"
1319
+ #include "spatial.h"
1320
+ #include "intersect.h"
1321
+ #include "intersect_adj.h"
1322
+
1323
+ //--------------
1324
+ namespace wp
1325
+ {
1326
+
1327
+
1328
// dot for scalar types just to make some templates compile for scalar/vector
// (a scalar "dot product" degenerates to multiplication, with the matching adjoint)
inline CUDA_CALLABLE float dot(float a, float b) { return mul(a, b); }
inline CUDA_CALLABLE void adj_dot(float a, float b, float& adj_a, float& adj_b, float adj_ret) { adj_mul(a, b, adj_a, adj_b, adj_ret); }
inline CUDA_CALLABLE float tensordot(float a, float b) { return mul(a, b); }
1332
+
1333
+
1334
// DECLARE_INTERP_FUNCS(T) stamps out interpolation helpers for one float type:
//  - smoothstep: the standard cubic Hermite 3x^2 - 2x^3 after clamping the
//    normalized position to [0, 1];
//  - adj_smoothstep: analytic adjoint; the early return skips the gradient
//    when x lies outside the (edge0, edge1) interval, where smoothstep is
//    constant (clamped) and the derivative is zero;
//  - lerp / adj_lerp: linear interpolation a*(1-t) + b*t and its adjoint.
// NOTE: no comments may appear inside the macro body — a trailing '\' in a
// '//' comment line would splice the following code line into the comment.
#define DECLARE_INTERP_FUNCS(T) \
CUDA_CALLABLE inline T smoothstep(T edge0, T edge1, T x)\
{\
    x = clamp((x - edge0) / (edge1 - edge0), T(0), T(1));\
    return x * x * (T(3) - T(2) * x);\
}\
CUDA_CALLABLE inline void adj_smoothstep(T edge0, T edge1, T x, T& adj_edge0, T& adj_edge1, T& adj_x, T adj_ret)\
{\
    T ab = edge0 - edge1;\
    T ax = edge0 - x;\
    T bx = edge1 - x;\
    T xb = x - edge1;\
    \
    if (bx / ab >= T(0) || ax / ab <= T(0))\
    {\
        return;\
    }\
    \
    T ab3 = ab * ab * ab;\
    T ab4 = ab3 * ab;\
    adj_edge0 += adj_ret * ((T(6) * ax * bx * bx) / ab4);\
    adj_edge1 += adj_ret * ((T(6) * ax * ax * xb) / ab4);\
    adj_x += adj_ret * ((T(6) * ax * bx ) / ab3);\
}\
CUDA_CALLABLE inline T lerp(const T& a, const T& b, T t)\
{\
    return a*(T(1)-t) + b*t;\
}\
CUDA_CALLABLE inline void adj_lerp(const T& a, const T& b, T t, T& adj_a, T& adj_b, T& adj_t, const T& adj_ret)\
{\
    adj_a += adj_ret*(T(1)-t);\
    adj_b += adj_ret*t;\
    adj_t += b*adj_ret - a*adj_ret;\
}

// Instantiate for every supported floating-point width.
DECLARE_INTERP_FUNCS(float16)
DECLARE_INTERP_FUNCS(float32)
DECLARE_INTERP_FUNCS(float64)
1372
+
1373
// print() overloads used by the generated kernel code for wp.print().
// Each writes one newline-terminated line via printf with the format
// specifier matching the argument type.
inline CUDA_CALLABLE void print(const str s)
{
    printf("%s\n", s);
}

inline CUDA_CALLABLE void print(int i)
{
    printf("%d\n", i);
}

inline CUDA_CALLABLE void print(short i)
{
    printf("%hd\n", i);
}

inline CUDA_CALLABLE void print(long i)
{
    printf("%ld\n", i);
}

inline CUDA_CALLABLE void print(long long i)
{
    printf("%lld\n", i);
}

inline CUDA_CALLABLE void print(unsigned i)
{
    printf("%u\n", i);
}

inline CUDA_CALLABLE void print(unsigned short i)
{
    printf("%hu\n", i);
}

inline CUDA_CALLABLE void print(unsigned long i)
{
    printf("%lu\n", i);
}

inline CUDA_CALLABLE void print(unsigned long long i)
{
    printf("%llu\n", i);
}

// Vector: space-separated components on one line; each component is converted
// to float for the %g format.
template<unsigned Length, typename Type>
inline CUDA_CALLABLE void print(vec_t<Length, Type> v)
{
    for( unsigned i=0; i < Length; ++i )
    {
        printf("%g ", float(v[i]));
    }
    printf("\n");
}

// Quaternion: x y z w on one line.
template<typename Type>
inline CUDA_CALLABLE void print(quat_t<Type> i)
{
    printf("%g %g %g %g\n", float(i.x), float(i.y), float(i.z), float(i.w));
}

// Matrix: one line per row, space-separated columns.
template<unsigned Rows,unsigned Cols,typename Type>
inline CUDA_CALLABLE void print(const mat_t<Rows,Cols,Type> &m)
{
    for( unsigned i=0; i< Rows; ++i )
    {
        for( unsigned j=0; j< Cols; ++j )
        {
            printf("%g ",float(m.data[i][j]));
        }
        printf("\n");
    }
}

// Transform: "(position) (quaternion)".
template<typename Type>
inline CUDA_CALLABLE void print(transform_t<Type> t)
{
    printf("(%g %g %g) (%g %g %g %g)\n", float(t.p[0]), float(t.p[1]), float(t.p[2]), float(t.q.x), float(t.q.y), float(t.q.z), float(t.q.w));
}
1452
+
1453
// adj_print() overloads: backward-pass counterparts of print() that show the
// primal value alongside its accumulated adjoint. Several compound-type
// overloads are intentionally empty (mat, transform, str).
inline CUDA_CALLABLE void adj_print(int i, int adj_i) { printf("%d adj: %d\n", i, adj_i); }
inline CUDA_CALLABLE void adj_print(float f, float adj_f) { printf("%g adj: %g\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(short f, short adj_f) { printf("%hd adj: %hd\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(long f, long adj_f) { printf("%ld adj: %ld\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(long long f, long long adj_f) { printf("%lld adj: %lld\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(unsigned f, unsigned adj_f) { printf("%u adj: %u\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(unsigned short f, unsigned short adj_f) { printf("%hu adj: %hu\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(unsigned long f, unsigned long adj_f) { printf("%lu adj: %lu\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(unsigned long long f, unsigned long long adj_f) { printf("%llu adj: %llu\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(half h, half adj_h) { printf("%g adj: %g\n", half_to_float(h), half_to_float(adj_h)); }
inline CUDA_CALLABLE void adj_print(double f, double adj_f) { printf("%g adj: %g\n", f, adj_f); }

// NOTE(review): prints only the first two components regardless of Length, and
// passes raw Type values through printf varargs (unlike print(), which casts
// to float) — for half element types that is undefined behavior, and for
// Length < 2 it reads out of bounds. Verify whether this overload is only
// instantiated for float vectors of Length >= 2.
template<unsigned Length, typename Type>
inline CUDA_CALLABLE void adj_print(vec_t<Length, Type> v, vec_t<Length, Type>& adj_v) { printf("%g %g adj: %g %g \n", v[0], v[1], adj_v[0], adj_v[1]); }

template<unsigned Rows, unsigned Cols, typename Type>
inline CUDA_CALLABLE void adj_print(mat_t<Rows, Cols, Type> m, mat_t<Rows, Cols, Type>& adj_m) { }

// NOTE(review): same varargs concern as the vec overload — q.x etc. are passed
// as Type, not cast to float; verify for non-float Type.
template<typename Type>
inline CUDA_CALLABLE void adj_print(quat_t<Type> q, quat_t<Type>& adj_q) { printf("%g %g %g %g adj: %g %g %g %g\n", q.x, q.y, q.z, q.w, adj_q.x, adj_q.y, adj_q.z, adj_q.w); }

template<typename Type>
inline CUDA_CALLABLE void adj_print(transform_t<Type> t, transform_t<Type>& adj_t) {}

inline CUDA_CALLABLE void adj_print(str t, str& adj_t) {}
1478
+
1479
+
1480
// Test assertion helpers: on failure they print a diagnostic (via the print()
// overloads above) but do NOT abort — execution continues.
template <typename T>
inline CUDA_CALLABLE void expect_eq(const T& actual, const T& expected)
{
    if (!(actual == expected))
    {
        printf("Error, expect_eq() failed:\n");
        printf("\t Expected: "); print(expected);
        printf("\t Actual: "); print(actual);
    }
}

// Adjoint is a no-op: assertions carry no gradient.
template <typename T>
inline CUDA_CALLABLE void adj_expect_eq(const T& a, const T& b, T& adj_a, T& adj_b)
{
    // nop
}

template <typename T>
inline CUDA_CALLABLE void expect_neq(const T& actual, const T& expected)
{
    if (actual == expected)
    {
        printf("Error, expect_neq() failed:\n");
        printf("\t Expected: "); print(expected);
        printf("\t Actual: "); print(actual);
    }
}

template <typename T>
inline CUDA_CALLABLE void adj_expect_neq(const T& a, const T& b, T& adj_a, T& adj_b)
{
    // nop
}

// Scalar tolerance check: |actual - expected| must not exceed tolerance.
template <typename T>
inline CUDA_CALLABLE void expect_near(const T& actual, const T& expected, const T& tolerance)
{
    if (abs(actual - expected) > tolerance)
    {
        printf("Error, expect_near() failed with tolerance "); print(tolerance);
        printf("\t Expected: "); print(expected);
        printf("\t Actual: "); print(actual);
    }
}

// vec3 overload: uses the largest component-wise absolute difference
// (Chebyshev/infinity norm) against a scalar tolerance.
inline CUDA_CALLABLE void expect_near(const vec3& actual, const vec3& expected, const float& tolerance)
{
    const float diff = max(max(abs(actual[0] - expected[0]), abs(actual[1] - expected[1])), abs(actual[2] - expected[2]));
    if (diff > tolerance)
    {
        printf("Error, expect_near() failed with tolerance "); print(tolerance);
        printf("\t Expected: "); print(expected);
        printf("\t Actual: "); print(actual);
    }
}

template <typename T>
inline CUDA_CALLABLE void adj_expect_near(const T& actual, const T& expected, const T& tolerance, T& adj_actual, T& adj_expected, T& adj_tolerance)
{
    // nop
}

inline CUDA_CALLABLE void adj_expect_near(const vec3& actual, const vec3& expected, float tolerance, vec3& adj_actual, vec3& adj_expected, float adj_tolerance)
{
    // nop
}
1546
+
1547
+
1548
+ } // namespace wp
1549
+
1550
+ // include array.h so we have the print, isfinite functions for the inner array types defined
1551
+ #include "array.h"
1552
+ #include "mesh.h"
1553
+ #include "bvh.h"
1554
+ #include "svd.h"
1555
+ #include "hashgrid.h"
1556
+ #include "volume.h"
1557
+ #include "range.h"
1558
+ #include "rand.h"
1559
+ #include "noise.h"
1560
+ #include "matnn.h"