warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/native/builtin.h ADDED
@@ -0,0 +1,2327 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #pragma once
19
+
20
+ // All built-in types and functions. To be compatible with runtime NVRTC compilation
21
+ // this header must be independently compilable (i.e.: without external SDK headers)
22
+ // to achieve this we redefine a subset of CRT functions (printf, pow, sin, cos, etc)
23
+
24
+ #include "crt.h"
25
+
26
+ #ifdef _WIN32
27
+ #define __restrict__ __restrict
28
+ #endif
29
+
30
+ #if !defined(__CUDACC__)
31
+ #define CUDA_CALLABLE
32
+ #define CUDA_CALLABLE_DEVICE
33
+ #else
34
+ #define CUDA_CALLABLE __host__ __device__
35
+ #define CUDA_CALLABLE_DEVICE __device__
36
+ #endif
37
+
38
+ // Tile block dimension used while building the warp core library
39
+ #ifndef WP_TILE_BLOCK_DIM
40
+ #define WP_TILE_BLOCK_DIM 256
41
+ #endif
42
+
43
+ #ifdef WP_VERIFY_FP
44
+ #define FP_CHECK 1
45
+ #define DO_IF_FPCHECK(X) {X}
46
+ #define DO_IF_NO_FPCHECK(X)
47
+ #else
48
+ #define FP_CHECK 0
49
+ #define DO_IF_FPCHECK(X)
50
+ #define DO_IF_NO_FPCHECK(X) {X}
51
+ #endif
52
+
53
+ #define RAD_TO_DEG 57.29577951308232087679
54
+ #define DEG_TO_RAD 0.01745329251994329577
55
+
56
+ #ifndef M_PI_F
57
+ #define M_PI_F 3.14159265358979323846f
58
+ #endif
59
+
60
+ #ifndef M_2_SQRT_PI_F
61
+ #define M_2_SQRT_PI_F 1.1283791670955125739f // 2/sqrt(pi)
62
+ #endif
63
+
64
+ #ifndef M_2_SQRT_PI
65
+ #define M_2_SQRT_PI 1.1283791670955125738961589031215 // 2/sqrt(pi)
66
+ #endif
67
+
68
+ #ifndef M_SQRT_PI_F_2
69
+ #define M_SQRT_PI_F_2 0.88622692545275801364f // sqrt(pi)/2
70
+ #endif
71
+
72
+ #ifndef M_SQRT_PI_2
73
+ #define M_SQRT_PI_2 0.88622692545275801364908374167057 // sqrt(pi)/2
74
+ #endif
75
+
76
+ #if defined(__CUDACC__) && !defined(_MSC_VER)
77
+ __device__ inline void __debugbreak() { __brkpt(); }
78
+ #endif
79
+
80
+ #if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
81
+ // clang compiling CUDA code, device mode (NOTE: Used when building core library with Clang)
82
+ #include <cuda_fp16.h>
83
+ #endif
84
+
85
+ namespace wp
86
+ {
87
+
88
+ // numeric types (used from generated kernels)
89
+ typedef float float32;
90
+ typedef double float64;
91
+
92
+ typedef int8_t int8;
93
+ typedef uint8_t uint8;
94
+
95
+ typedef int16_t int16;
96
+ typedef uint16_t uint16;
97
+
98
+ typedef int32_t int32;
99
+ typedef uint32_t uint32;
100
+
101
+ typedef int64_t int64;
102
+ typedef uint64_t uint64;
103
+
104
+
105
+ // matches Python string type for constant strings
106
+ typedef const char* str;
107
+
108
+
109
+
110
+ struct half;
111
+
112
+ CUDA_CALLABLE half float_to_half(float x);
113
+ CUDA_CALLABLE float half_to_float(half x);
114
+
115
+ struct half
116
+ {
117
+ CUDA_CALLABLE inline half() : u(0) {}
118
+
119
+ CUDA_CALLABLE inline half(float f)
120
+ {
121
+ *this = float_to_half(f);
122
+ }
123
+
124
+ unsigned short u;
125
+
126
+ CUDA_CALLABLE inline bool operator==(const half& h) const
127
+ {
128
+ // Use float32 to get IEEE 754 behavior in case of a NaN
129
+ return float32(h) == float32(*this);
130
+ }
131
+
132
+ CUDA_CALLABLE inline bool operator!=(const half& h) const
133
+ {
134
+ // Use float32 to get IEEE 754 behavior in case of a NaN
135
+ return float32(h) != float32(*this);
136
+ }
137
+ CUDA_CALLABLE inline bool operator>(const half& h) const { return half_to_float(*this) > half_to_float(h); }
138
+ CUDA_CALLABLE inline bool operator>=(const half& h) const { return half_to_float(*this) >= half_to_float(h); }
139
+ CUDA_CALLABLE inline bool operator<(const half& h) const { return half_to_float(*this) < half_to_float(h); }
140
+ CUDA_CALLABLE inline bool operator<=(const half& h) const { return half_to_float(*this) <= half_to_float(h); }
141
+
142
+ CUDA_CALLABLE inline bool operator!() const
143
+ {
144
+ return float32(*this) == 0;
145
+ }
146
+
147
+ CUDA_CALLABLE inline half operator*=(const half& h)
148
+ {
149
+ half prod = half(float32(*this) * float32(h));
150
+ this->u = prod.u;
151
+ return *this;
152
+ }
153
+
154
+ CUDA_CALLABLE inline half operator/=(const half& h)
155
+ {
156
+ half quot = half(float32(*this) / float32(h));
157
+ this->u = quot.u;
158
+ return *this;
159
+ }
160
+
161
+ CUDA_CALLABLE inline half operator+=(const half& h)
162
+ {
163
+ half sum = half(float32(*this) + float32(h));
164
+ this->u = sum.u;
165
+ return *this;
166
+ }
167
+
168
+ CUDA_CALLABLE inline half operator-=(const half& h)
169
+ {
170
+ half diff = half(float32(*this) - float32(h));
171
+ this->u = diff.u;
172
+ return *this;
173
+ }
174
+
175
+ CUDA_CALLABLE inline operator float32() const { return float32(half_to_float(*this)); }
176
+ CUDA_CALLABLE inline operator float64() const { return float64(half_to_float(*this)); }
177
+ CUDA_CALLABLE inline operator int8() const { return int8(half_to_float(*this)); }
178
+ CUDA_CALLABLE inline operator uint8() const { return uint8(half_to_float(*this)); }
179
+ CUDA_CALLABLE inline operator int16() const { return int16(half_to_float(*this)); }
180
+ CUDA_CALLABLE inline operator uint16() const { return uint16(half_to_float(*this)); }
181
+ CUDA_CALLABLE inline operator int32() const { return int32(half_to_float(*this)); }
182
+ CUDA_CALLABLE inline operator uint32() const { return uint32(half_to_float(*this)); }
183
+ CUDA_CALLABLE inline operator int64() const { return int64(half_to_float(*this)); }
184
+ CUDA_CALLABLE inline operator uint64() const { return uint64(half_to_float(*this)); }
185
+ };
186
+
187
+ static_assert(sizeof(half) == 2, "Size of half / float16 type must be 2-bytes");
188
+
189
+ typedef half float16;
190
+
191
+ #if defined(__CUDA_ARCH__)
192
+
193
+ CUDA_CALLABLE inline half float_to_half(float x)
194
+ {
195
+ half h;
196
+ asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(h.u) : "f"(x));
197
+ return h;
198
+ }
199
+
200
+ CUDA_CALLABLE inline float half_to_float(half x)
201
+ {
202
+ float val;
203
+ asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(x.u));
204
+ return val;
205
+ }
206
+
207
+ #elif defined(__clang__)
208
+
209
+ // _Float16 is Clang's native half-precision floating-point type
210
+ CUDA_CALLABLE inline half float_to_half(float x)
211
+ {
212
+
213
+ _Float16 f16 = static_cast<_Float16>(x);
214
+ return *reinterpret_cast<half*>(&f16);
215
+ }
216
+
217
+ CUDA_CALLABLE inline float half_to_float(half h)
218
+ {
219
+ _Float16 f16 = *reinterpret_cast<_Float16*>(&h);
220
+ return static_cast<float>(f16);
221
+ }
222
+
223
+ #else // Native C++ for Warp builtins outside of kernels
224
+
225
+ extern "C" WP_API uint16_t wp_float_to_half_bits(float x);
226
+ extern "C" WP_API float wp_half_bits_to_float(uint16_t u);
227
+
228
+ inline half float_to_half(float x)
229
+ {
230
+ half h;
231
+ h.u = wp_float_to_half_bits(x);
232
+ return h;
233
+ }
234
+
235
+ inline float half_to_float(half h)
236
+ {
237
+ return wp_half_bits_to_float(h.u);
238
+ }
239
+
240
+ #endif
241
+
242
+
243
+ // BAD operator implementations for fp16 arithmetic...
244
+
245
+ // negation:
246
+ inline CUDA_CALLABLE half operator - (half a)
247
+ {
248
+ return float_to_half( -half_to_float(a) );
249
+ }
250
+
251
+ inline CUDA_CALLABLE half operator + (half a,half b)
252
+ {
253
+ return float_to_half( half_to_float(a) + half_to_float(b) );
254
+ }
255
+
256
+ inline CUDA_CALLABLE half operator - (half a,half b)
257
+ {
258
+ return float_to_half( half_to_float(a) - half_to_float(b) );
259
+ }
260
+
261
+ inline CUDA_CALLABLE half operator * (half a,half b)
262
+ {
263
+ return float_to_half( half_to_float(a) * half_to_float(b) );
264
+ }
265
+
266
+ inline CUDA_CALLABLE half operator * (half a,float b)
267
+ {
268
+ return float_to_half( half_to_float(a) * b );
269
+ }
270
+
271
+ inline CUDA_CALLABLE half operator * (float a,half b)
272
+ {
273
+ return float_to_half( a * half_to_float(b) );
274
+ }
275
+
276
+ inline CUDA_CALLABLE half operator * (half a,double b)
277
+ {
278
+ return float_to_half( half_to_float(a) * b );
279
+ }
280
+
281
+ inline CUDA_CALLABLE half operator * (double a,half b)
282
+ {
283
+ return float_to_half( a * half_to_float(b) );
284
+ }
285
+
286
+ inline CUDA_CALLABLE half operator / (half a,half b)
287
+ {
288
+ return float_to_half( half_to_float(a) / half_to_float(b) );
289
+ }
290
+
291
+
292
+
293
+
294
+
295
+ template<typename TRet, typename T>
296
+ inline CUDA_CALLABLE TRet cast(T a)
297
+ {
298
+ static_assert(sizeof(TRet) == sizeof(T), "source and destination must have the same size");
299
+ return *reinterpret_cast<TRet*>(&a);
300
+ }
301
+
302
// Value casts to float/int plus their adjoint (reverse-mode) counterparts.

template <typename T>
CUDA_CALLABLE inline float cast_float(T x) { return (float)(x); }

template <typename T>
CUDA_CALLABLE inline int cast_int(T x) { return (int)(x); }

// Generic cast_float adjoint: no gradient (covers integer sources).
template <typename T>
CUDA_CALLABLE inline void adj_cast_float(T x, T& adj_x, float adj_ret) {}

// Floating-point sources: the cast is differentiable, gradient passes
// straight through (converted back to the source precision).
CUDA_CALLABLE inline void adj_cast_float(float16 x, float16& adj_x, float adj_ret) { adj_x += float16(adj_ret); }
CUDA_CALLABLE inline void adj_cast_float(float32 x, float32& adj_x, float adj_ret) { adj_x += float32(adj_ret); }
CUDA_CALLABLE inline void adj_cast_float(float64 x, float64& adj_x, float adj_ret) { adj_x += float64(adj_ret); }

// Casting to int is not differentiable: adjoint is a no-op.
template <typename T>
CUDA_CALLABLE inline void adj_cast_int(T x, T& adj_x, int adj_ret) {}

// Adjoints of the integer constructor-style casts (int8(x), uint8(x), ...):
// all no-ops, since no gradient flows through integer conversions.
template <typename T>
CUDA_CALLABLE inline void adj_int8(T, T&, int8) {}
template <typename T>
CUDA_CALLABLE inline void adj_uint8(T, T&, uint8) {}
template <typename T>
CUDA_CALLABLE inline void adj_int16(T, T&, int16) {}
template <typename T>
CUDA_CALLABLE inline void adj_uint16(T, T&, uint16) {}
template <typename T>
CUDA_CALLABLE inline void adj_int32(T, T&, int32) {}
template <typename T>
CUDA_CALLABLE inline void adj_uint32(T, T&, uint32) {}
template <typename T>
CUDA_CALLABLE inline void adj_int64(T, T&, int64) {}
template <typename T>
CUDA_CALLABLE inline void adj_uint64(T, T&, uint64) {}


// Adjoints of the float constructor-style casts: gradient passes through,
// converted back to the argument's type.
template <typename T>
CUDA_CALLABLE inline void adj_float16(T x, T& adj_x, float16 adj_ret) { adj_x += T(adj_ret); }
template <typename T>
CUDA_CALLABLE inline void adj_float32(T x, T& adj_x, float32 adj_ret) { adj_x += T(adj_ret); }
template <typename T>
CUDA_CALLABLE inline void adj_float64(T x, T& adj_x, float64 adj_ret) { adj_x += T(adj_ret); }


#define kEps 0.0f
345
+
346
// basic ops for integer types
// One macro instantiation per fixed-width integer type generates the full
// forward op set plus adjoints. Notes:
//  - integer sqrt() is stubbed to 0 (not a real integer square root);
//  - isfinite/isnan/isinf promote to double, so for integer values they
//    always return true/false/false respectively;
//  - every adj_* is an intentional no-op: integers carry no gradients.
#define DECLARE_INT_OPS(T) \
inline CUDA_CALLABLE T mul(T a, T b) { return a*b; } \
inline CUDA_CALLABLE T div(T a, T b) { return a/b; } \
inline CUDA_CALLABLE T add(T a, T b) { return a+b; } \
inline CUDA_CALLABLE T sub(T a, T b) { return a-b; } \
inline CUDA_CALLABLE T mod(T a, T b) { return a%b; } \
inline CUDA_CALLABLE T min(T a, T b) { return a<b?a:b; } \
inline CUDA_CALLABLE T max(T a, T b) { return a>b?a:b; } \
inline CUDA_CALLABLE T clamp(T x, T a, T b) { return min(max(a, x), b); } \
inline CUDA_CALLABLE T floordiv(T a, T b) { return a/b; } \
inline CUDA_CALLABLE T nonzero(T x) { return x == T(0) ? T(0) : T(1); } \
inline CUDA_CALLABLE T sqrt(T x) { return 0; } \
inline CUDA_CALLABLE T bit_and(T a, T b) { return a&b; } \
inline CUDA_CALLABLE T bit_or(T a, T b) { return a|b; } \
inline CUDA_CALLABLE T bit_xor(T a, T b) { return a^b; } \
inline CUDA_CALLABLE T lshift(T a, T b) { return a<<b; } \
inline CUDA_CALLABLE T rshift(T a, T b) { return a>>b; } \
inline CUDA_CALLABLE T invert(T x) { return ~x; } \
inline CUDA_CALLABLE bool isfinite(T x) { return ::isfinite(double(x)); } \
inline CUDA_CALLABLE bool isnan(T x) { return ::isnan(double(x)); } \
inline CUDA_CALLABLE bool isinf(T x) { return ::isinf(double(x)); } \
inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_min(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_max(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_abs(T x, T adj_x, T& adj_ret) { } \
inline CUDA_CALLABLE void adj_sign(T x, T adj_x, T& adj_ret) { } \
inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_floordiv(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_step(T x, T& adj_x, T adj_ret) { } \
inline CUDA_CALLABLE void adj_nonzero(T x, T& adj_x, T adj_ret) { } \
inline CUDA_CALLABLE void adj_sqrt(T x, T adj_x, T& adj_ret) { } \
inline CUDA_CALLABLE void adj_bit_and(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_bit_or(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_bit_xor(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_lshift(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_rshift(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_invert(T x, T adj_x, T& adj_ret) { } \
inline CUDA_CALLABLE void adj_isnan(const T&, T&, bool) { } \
inline CUDA_CALLABLE void adj_isinf(const T&, T&, bool) { } \
inline CUDA_CALLABLE void adj_isfinite(const T&, T&, bool) { }

// abs() is declared ahead of the instantiations; unsigned abs is identity.
inline CUDA_CALLABLE int8 abs(int8 x) { return ::abs(x); }
inline CUDA_CALLABLE int16 abs(int16 x) { return ::abs(x); }
inline CUDA_CALLABLE int32 abs(int32 x) { return ::abs(x); }
inline CUDA_CALLABLE int64 abs(int64 x) { return ::llabs(x); }
inline CUDA_CALLABLE uint8 abs(uint8 x) { return x; }
inline CUDA_CALLABLE uint16 abs(uint16 x) { return x; }
inline CUDA_CALLABLE uint32 abs(uint32 x) { return x; }
inline CUDA_CALLABLE uint64 abs(uint64 x) { return x; }

DECLARE_INT_OPS(int8)
DECLARE_INT_OPS(int16)
DECLARE_INT_OPS(int32)
DECLARE_INT_OPS(int64)
DECLARE_INT_OPS(uint8)
DECLARE_INT_OPS(uint16)
DECLARE_INT_OPS(uint32)
DECLARE_INT_OPS(uint64)
409
+
410
+
411
+ inline CUDA_CALLABLE int8 step(int8 x) { return x < 0 ? 1 : 0; }
412
+ inline CUDA_CALLABLE int16 step(int16 x) { return x < 0 ? 1 : 0; }
413
+ inline CUDA_CALLABLE int32 step(int32 x) { return x < 0 ? 1 : 0; }
414
+ inline CUDA_CALLABLE int64 step(int64 x) { return x < 0 ? 1 : 0; }
415
+ inline CUDA_CALLABLE uint8 step(uint8 x) { return 0; }
416
+ inline CUDA_CALLABLE uint16 step(uint16 x) { return 0; }
417
+ inline CUDA_CALLABLE uint32 step(uint32 x) { return 0; }
418
+ inline CUDA_CALLABLE uint64 step(uint64 x) { return 0; }
419
+
420
+
421
+ inline CUDA_CALLABLE int8 sign(int8 x) { return x < 0 ? -1 : 1; }
422
+ inline CUDA_CALLABLE int8 sign(int16 x) { return x < 0 ? -1 : 1; }
423
+ inline CUDA_CALLABLE int8 sign(int32 x) { return x < 0 ? -1 : 1; }
424
+ inline CUDA_CALLABLE int8 sign(int64 x) { return x < 0 ? -1 : 1; }
425
+ inline CUDA_CALLABLE uint8 sign(uint8 x) { return 1; }
426
+ inline CUDA_CALLABLE uint16 sign(uint16 x) { return 1; }
427
+ inline CUDA_CALLABLE uint32 sign(uint32 x) { return 1; }
428
+ inline CUDA_CALLABLE uint64 sign(uint64 x) { return 1; }
429
+
430
+
431
// Catch-all for non-float, non-integer types
// (e.g. composite types without their own classification): always finite.
template<typename T>
inline bool CUDA_CALLABLE isfinite(const T&)
{
    return true;
}

// Floating-point classification; half is widened to float first.
inline bool CUDA_CALLABLE isfinite(half x)
{
    return ::isfinite(float(x));
}
inline bool CUDA_CALLABLE isfinite(float x)
{
    return ::isfinite(x);
}
inline bool CUDA_CALLABLE isfinite(double x)
{
    return ::isfinite(x);
}

inline bool CUDA_CALLABLE isnan(half x)
{
    return ::isnan(float(x));
}
inline bool CUDA_CALLABLE isnan(float x)
{
    return ::isnan(x);
}
inline bool CUDA_CALLABLE isnan(double x)
{
    return ::isnan(x);
}

inline bool CUDA_CALLABLE isinf(half x)
{
    return ::isinf(float(x));
}
inline bool CUDA_CALLABLE isinf(float x)
{
    return ::isinf(x);
}
inline bool CUDA_CALLABLE isinf(double x)
{
    return ::isinf(x);
}

// print() overloads used by generated kernel code; the generic fallback
// emits a placeholder for types without a specialization.
template<typename T>
inline CUDA_CALLABLE void print(const T&)
{
    printf("<type without print implementation>\n");
}

inline CUDA_CALLABLE void print(float16 f)
{
    printf("%g\n", half_to_float(f));
}

inline CUDA_CALLABLE void print(float f)
{
    printf("%g\n", f);
}

inline CUDA_CALLABLE void print(double f)
{
    printf("%g\n", f);
}
497
+
498
+
499
// basic ops for float types
// One macro instantiation per floating-point type generates the shared
// forward ops and their reverse-mode adjoints. Notes:
//  - sign(0) is +1; step(0) is 0;
//  - adj_min/adj_max route the gradient to the selected operand (ties go
//    to the second operand);
//  - adj_clamp routes the gradient to whichever bound is active, else to x;
//  - div() and adj_div() validate operands under FP_CHECK builds via
//    DO_IF_FPCHECK and abort on non-finite input or division by zero.
#define DECLARE_FLOAT_OPS(T) \
inline CUDA_CALLABLE T mul(T a, T b) { return a*b; } \
inline CUDA_CALLABLE T add(T a, T b) { return a+b; } \
inline CUDA_CALLABLE T sub(T a, T b) { return a-b; } \
inline CUDA_CALLABLE T min(T a, T b) { return a<b?a:b; } \
inline CUDA_CALLABLE T max(T a, T b) { return a>b?a:b; } \
inline CUDA_CALLABLE T sign(T x) { return x < T(0) ? -1 : 1; } \
inline CUDA_CALLABLE T step(T x) { return x < T(0) ? T(1) : T(0); }\
inline CUDA_CALLABLE T nonzero(T x) { return x == T(0) ? T(0) : T(1); }\
inline CUDA_CALLABLE T clamp(T x, T a, T b) { return min(max(a, x), b); }\
inline CUDA_CALLABLE void adj_abs(T x, T& adj_x, T adj_ret) \
{\
    if (x < T(0))\
        adj_x -= adj_ret;\
    else\
        adj_x += adj_ret;\
}\
inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += b*adj_ret; adj_b += a*adj_ret; } \
inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += adj_ret; adj_b += adj_ret; } \
inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += adj_ret; adj_b -= adj_ret; } \
inline CUDA_CALLABLE void adj_min(T a, T b, T& adj_a, T& adj_b, T adj_ret) \
{ \
    if (a < b) \
        adj_a += adj_ret; \
    else \
        adj_b += adj_ret; \
} \
inline CUDA_CALLABLE void adj_max(T a, T b, T& adj_a, T& adj_b, T adj_ret) \
{ \
    if (a > b) \
        adj_a += adj_ret; \
    else \
        adj_b += adj_ret; \
} \
inline CUDA_CALLABLE void adj_floordiv(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret){ adj_a += adj_ret; }\
inline CUDA_CALLABLE void adj_sign(T x, T adj_x, T& adj_ret) { }\
inline CUDA_CALLABLE void adj_step(T x, T& adj_x, T adj_ret) { }\
inline CUDA_CALLABLE void adj_nonzero(T x, T& adj_x, T adj_ret) { }\
inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b, T adj_ret)\
{\
    if (x < a)\
        adj_a += adj_ret;\
    else if (x > b)\
        adj_b += adj_ret;\
    else\
        adj_x += adj_ret;\
}\
inline CUDA_CALLABLE T div(T a, T b)\
{\
    DO_IF_FPCHECK(\
    if (!isfinite(a) || !isfinite(b) || b == T(0))\
    {\
        printf("%s:%d div(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));\
        assert(0);\
    })\
    return a/b;\
}\
inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
{\
    adj_a += adj_ret/b;\
    adj_b -= adj_ret*(ret)/b;\
    DO_IF_FPCHECK(\
    if (!isfinite(adj_a) || !isfinite(adj_b))\
    {\
        printf("%s:%d - adj_div(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
        assert(0);\
    })\
}\
inline CUDA_CALLABLE void adj_isnan(const T&, T&, bool) { }\
inline CUDA_CALLABLE void adj_isinf(const T&, T&, bool) { }\
inline CUDA_CALLABLE void adj_isfinite(const T&, T&, bool) { }

DECLARE_FLOAT_OPS(float16)
DECLARE_FLOAT_OPS(float32)
DECLARE_FLOAT_OPS(float64)
576
+
577
+
578
+
579
// basic ops for float types
// Floating-point modulo (C-style fmod semantics: result has the sign of a).
// half is computed in float32 and rounded back via the implicit
// float -> half conversion on return. FP_CHECK builds abort on
// non-finite input or a zero divisor.
inline CUDA_CALLABLE float16 mod(float16 a, float16 b)
{
#if FP_CHECK
    if (!isfinite(a) || !isfinite(b) || float(b) == 0.0f)
    {
        printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));
        assert(0);
    }
#endif
    return fmodf(float(a), float(b));
}

inline CUDA_CALLABLE float32 mod(float32 a, float32 b)
{
#if FP_CHECK
    if (!isfinite(a) || !isfinite(b) || b == 0.0f)
    {
        printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
        assert(0);
    }
#endif
    return fmodf(a, b);
}

inline CUDA_CALLABLE double mod(double a, double b)
{
#if FP_CHECK
    // note: comparing a double against the 0.0f literal is fine here —
    // the literal promotes to 0.0
    if (!isfinite(a) || !isfinite(b) || b == 0.0f)
    {
        printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
        assert(0);
    }
#endif
    return fmod(a, b);
}
615
+
616
+ inline CUDA_CALLABLE half log(half a)
617
+ {
618
+ #if FP_CHECK
619
+ if (!isfinite(a) || float(a) < 0.0f)
620
+ {
621
+ printf("%s:%d log(%f)\n", __FILE__, __LINE__, float(a));
622
+ assert(0);
623
+ }
624
+ #endif
625
+ return ::logf(a);
626
+ }
627
+
628
// Natural, base-2 and base-10 logarithms for float/double/half.
// half overloads compute in float32 and round back on return.
// FP_CHECK builds abort on non-finite or negative input; log(0) -> -inf
// is not trapped by the `< 0` test.
inline CUDA_CALLABLE float log(float a)
{
#if FP_CHECK
    if (!isfinite(a) || a < 0.0f)
    {
        printf("%s:%d log(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif
    return ::logf(a);
}

inline CUDA_CALLABLE double log(double a)
{
#if FP_CHECK
    if (!isfinite(a) || a < 0.0)
    {
        printf("%s:%d log(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif
    return ::log(a);
}

inline CUDA_CALLABLE half log2(half a)
{
#if FP_CHECK
    if (!isfinite(a) || float(a) < 0.0f)
    {
        printf("%s:%d log2(%f)\n", __FILE__, __LINE__, float(a));
        assert(0);
    }
#endif

    return ::log2f(float(a));
}

inline CUDA_CALLABLE float log2(float a)
{
#if FP_CHECK
    if (!isfinite(a) || a < 0.0f)
    {
        printf("%s:%d log2(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif

    return ::log2f(a);
}

inline CUDA_CALLABLE double log2(double a)
{
#if FP_CHECK
    if (!isfinite(a) || a < 0.0)
    {
        printf("%s:%d log2(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif

    return ::log2(a);
}

inline CUDA_CALLABLE half log10(half a)
{
#if FP_CHECK
    if (!isfinite(a) || float(a) < 0.0f)
    {
        printf("%s:%d log10(%f)\n", __FILE__, __LINE__, float(a));
        assert(0);
    }
#endif

    return ::log10f(float(a));
}

inline CUDA_CALLABLE float log10(float a)
{
#if FP_CHECK
    if (!isfinite(a) || a < 0.0f)
    {
        printf("%s:%d log10(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif

    return ::log10f(a);
}

inline CUDA_CALLABLE double log10(double a)
{
#if FP_CHECK
    if (!isfinite(a) || a < 0.0)
    {
        printf("%s:%d log10(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif

    return ::log10(a);
}
729
+
730
// Exponential and power for half/float/double.
// FP_CHECK builds validate both input AND result, so overflow to inf
// (e.g. exp of a large value) is trapped as well.
inline CUDA_CALLABLE half exp(half a)
{
    // computed in float32; the assignment rounds to half, and the check
    // below therefore sees the rounded result
    half result = ::expf(float(a));
#if FP_CHECK
    if (!isfinite(a) || !isfinite(result))
    {
        printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, float(a), float(result));
        assert(0);
    }
#endif
    return result;
}
inline CUDA_CALLABLE float exp(float a)
{
    float result = ::expf(a);
#if FP_CHECK
    if (!isfinite(a) || !isfinite(result))
    {
        printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, a, result);
        assert(0);
    }
#endif
    return result;
}
inline CUDA_CALLABLE double exp(double a)
{
    double result = ::exp(a);
#if FP_CHECK
    if (!isfinite(a) || !isfinite(result))
    {
        printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, a, result);
        assert(0);
    }
#endif
    return result;
}

inline CUDA_CALLABLE half pow(half a, half b)
{
    float result = ::powf(float(a), float(b));
#if FP_CHECK
    if (!isfinite(float(a)) || !isfinite(float(b)) || !isfinite(result))
    {
        printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, float(a), float(b), result);
        assert(0);
    }
#endif
    return result;
}

inline CUDA_CALLABLE float pow(float a, float b)
{
    float result = ::powf(a, b);
#if FP_CHECK
    if (!isfinite(a) || !isfinite(b) || !isfinite(result))
    {
        printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, a, b, result);
        assert(0);
    }
#endif
    return result;
}

inline CUDA_CALLABLE double pow(double a, double b)
{
    double result = ::pow(a, b);
#if FP_CHECK
    if (!isfinite(a) || !isfinite(b) || !isfinite(result))
    {
        printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, a, b, result);
        assert(0);
    }
#endif
    return result;
}
805
+
806
// Floor division (Python // semantics: quotient rounded toward -inf).
// FP_CHECK builds abort on non-finite input or a zero divisor.
inline CUDA_CALLABLE half floordiv(half a, half b)
{
#if FP_CHECK
    if (!isfinite(a) || !isfinite(b) || float(b) == 0.0f)
    {
        printf("%s:%d floordiv(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));
        assert(0);
    }
#endif
    // a/b uses the emulated half operator/ above, then floors in float32
    return floorf(float(a/b));
}

inline CUDA_CALLABLE float floordiv(float a, float b)
{
#if FP_CHECK
    if (!isfinite(a) || !isfinite(b) || b == 0.0f)
    {
        printf("%s:%d floordiv(%f, %f)\n", __FILE__, __LINE__, a, b);
        assert(0);
    }
#endif
    return floorf(a/b);
}

inline CUDA_CALLABLE double floordiv(double a, double b)
{
#if FP_CHECK
    if (!isfinite(a) || !isfinite(b) || b == 0.0)
    {
        printf("%s:%d floordiv(%f, %f)\n", __FILE__, __LINE__, a, b);
        assert(0);
    }
#endif
    return ::floor(a/b);
}
841
+
842
// Error function family (erf, erfc) and their inverses, with reverse-mode
// adjoints. half overloads compute in float32 and round back on return.
// Derivatives used below:
//   d/dx erf(x)  =  (2/sqrt(pi)) * exp(-x^2)     (M_2_SQRT_PI[_F])
//   d/dx erfc(x) = -(2/sqrt(pi)) * exp(-x^2)
//   d/dx erfinv(x)  =  (sqrt(pi)/2) * exp(erfinv(x)^2)  (M_SQRT_PI[_F]_2)
//   d/dx erfcinv(x) = -(sqrt(pi)/2) * exp(erfcinv(x)^2)
// The inverse adjoints reuse the saved forward result `ret` instead of
// recomputing the inverse.
inline CUDA_CALLABLE half erf(half a)
{
    return erff(float(a));
}

inline CUDA_CALLABLE float erf(float a)
{
    return erff(a);
}

inline CUDA_CALLABLE double erf(double a)
{
    return ::erf(a);
}

inline CUDA_CALLABLE half erfc(half a)
{
    return erfcf(float(a));
}

inline CUDA_CALLABLE float erfc(float a)
{
    return erfcf(a);
}

inline CUDA_CALLABLE double erfc(double a)
{
    return ::erfc(a);
}

// erfinv is only defined on [-1, 1]; FP_CHECK builds abort outside it.
inline CUDA_CALLABLE half erfinv(half a)
{
#if FP_CHECK
    if (float(a) < -1.0f || float(a) > 1.0f)
    {
        printf("%s:%d erfinv(%f)\n", __FILE__, __LINE__, float(a));
        assert(0);
    }
#endif
    return ::erfinvf(float(a));
}

inline CUDA_CALLABLE float erfinv(float a)
{
#if FP_CHECK
    if (a < -1.0f || a > 1.0f)
    {
        printf("%s:%d erfinv(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif
    return ::erfinvf(a);
}

inline CUDA_CALLABLE double erfinv(double a)
{
#if FP_CHECK
    if (a < -1.0 || a > 1.0)
    {
        printf("%s:%d erfinv(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif
    return ::erfinv(a);
}

// erfcinv is only defined on [0, 2]; FP_CHECK builds abort outside it.
inline CUDA_CALLABLE half erfcinv(half a)
{
#if FP_CHECK
    if (float(a) < 0.0f || float(a) > 2.0f)
    {
        printf("%s:%d erfcinv(%f)\n", __FILE__, __LINE__, float(a));
        assert(0);
    }
#endif
    return ::erfcinvf(float(a));
}

inline CUDA_CALLABLE float erfcinv(float a)
{
#if FP_CHECK
    if (a < 0.0f || a > 2.0f)
    {
        printf("%s:%d erfcinv(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif
    return ::erfcinvf(a);
}

inline CUDA_CALLABLE double erfcinv(double a)
{
#if FP_CHECK
    if (a < 0.0 || a > 2.0)
    {
        printf("%s:%d erfcinv(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif
    return ::erfcinv(a);
}

inline CUDA_CALLABLE void adj_erf(half a, half& adj_a, half adj_ret)
{
    adj_a += half(M_2_SQRT_PI_F * ::expf(-float(a)*float(a))) * adj_ret;
}

inline CUDA_CALLABLE void adj_erf(float a, float& adj_a, float adj_ret)
{
    adj_a += M_2_SQRT_PI_F * ::expf(-a*a) * adj_ret;
}

inline CUDA_CALLABLE void adj_erf(double a, double& adj_a, double adj_ret)
{
    adj_a += M_2_SQRT_PI * ::exp(-a*a) * adj_ret;
}

// erfc' = -erf', hence the subtraction.
inline CUDA_CALLABLE void adj_erfc(half a, half& adj_a, half adj_ret)
{
    adj_a -= half(M_2_SQRT_PI_F * ::expf(-float(a)*float(a))) * adj_ret;
}

inline CUDA_CALLABLE void adj_erfc(float a, float& adj_a, float adj_ret)
{
    adj_a -= M_2_SQRT_PI_F * ::expf(-a*a) * adj_ret;
}

inline CUDA_CALLABLE void adj_erfc(double a, double& adj_a, double adj_ret)
{
    adj_a -= M_2_SQRT_PI * ::exp(-a*a) * adj_ret;
}

inline CUDA_CALLABLE void adj_erfinv(half a, half ret, half& adj_a, half adj_ret)
{
#if FP_CHECK
    if (float(a) < -1.0f || float(a) > 1.0f)
    {
        printf("%s:%d adj_erfinv(%f)\n", __FILE__, __LINE__, float(a));
        assert(0);
    }
#endif
    adj_a += half(M_SQRT_PI_F_2 * ::expf(float(ret)*float(ret))) * adj_ret;
}

inline CUDA_CALLABLE void adj_erfinv(float a, float ret, float& adj_a, float adj_ret)
{
#if FP_CHECK
    if (a < -1.0f || a > 1.0f)
    {
        printf("%s:%d adj_erfinv(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif
    adj_a += M_SQRT_PI_F_2 * ::expf(ret*ret) * adj_ret;
}

inline CUDA_CALLABLE void adj_erfinv(double a, double ret, double& adj_a, double adj_ret)
{
#if FP_CHECK
    if (a < -1.0 || a > 1.0)
    {
        printf("%s:%d adj_erfinv(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif
    adj_a += M_SQRT_PI_2 * ::exp(ret*ret) * adj_ret;
}

inline CUDA_CALLABLE void adj_erfcinv(half a, half ret, half& adj_a, half adj_ret)
{
#if FP_CHECK
    if (float(a) < 0.0f || float(a) > 2.0f)
    {
        printf("%s:%d adj_erfcinv(%f)\n", __FILE__, __LINE__, float(a));
        assert(0);
    }
#endif
    adj_a -= half(M_SQRT_PI_F_2 * ::expf(float(ret)*float(ret))) * adj_ret;
}

inline CUDA_CALLABLE void adj_erfcinv(float a, float ret, float& adj_a, float adj_ret)
{
#if FP_CHECK
    if (a < 0.0f || a > 2.0f)
    {
        printf("%s:%d adj_erfcinv(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif
    adj_a -= M_SQRT_PI_F_2 * ::expf(ret*ret) * adj_ret;
}

inline CUDA_CALLABLE void adj_erfcinv(double a, double ret, double& adj_a, double adj_ret)
{
#if FP_CHECK
    if (a < 0.0 || a > 2.0)
    {
        printf("%s:%d adj_erfcinv(%f)\n", __FILE__, __LINE__, a);
        assert(0);
    }
#endif
    adj_a -= M_SQRT_PI_2 * ::exp(ret*ret) * adj_ret;
}
1045
+
1046
// Leaky min/max: forward pass is the ordinary min/max; the relaxation
// parameter r is intentionally unused here — it only affects the adjoint
// (see adj_leaky_min/adj_leaky_max in DECLARE_ADJOINTS below).
inline CUDA_CALLABLE float leaky_min(float a, float b, float r) { return min(a, b); }
inline CUDA_CALLABLE float leaky_max(float a, float b, float r) { return max(a, b); }

inline CUDA_CALLABLE half abs(half x) { return ::fabsf(float(x)); }
inline CUDA_CALLABLE float abs(float x) { return ::fabsf(x); }
inline CUDA_CALLABLE double abs(double x) { return ::fabs(x); }

// Trig for float/double/half. acos/asin clamp the argument into [-1, 1]
// so slightly-out-of-range inputs (from rounding) don't produce NaN.
// half overloads compute in float32 and round back on return.
inline CUDA_CALLABLE float acos(float x){ return ::acosf(min(max(x, -1.0f), 1.0f)); }
inline CUDA_CALLABLE float asin(float x){ return ::asinf(min(max(x, -1.0f), 1.0f)); }
inline CUDA_CALLABLE float atan(float x) { return ::atanf(x); }
inline CUDA_CALLABLE float atan2(float y, float x) { return ::atan2f(y, x); }
inline CUDA_CALLABLE float sin(float x) { return ::sinf(x); }
inline CUDA_CALLABLE float cos(float x) { return ::cosf(x); }

inline CUDA_CALLABLE double acos(double x){ return ::acos(min(max(x, -1.0), 1.0)); }
inline CUDA_CALLABLE double asin(double x){ return ::asin(min(max(x, -1.0), 1.0)); }
inline CUDA_CALLABLE double atan(double x) { return ::atan(x); }
inline CUDA_CALLABLE double atan2(double y, double x) { return ::atan2(y, x); }
inline CUDA_CALLABLE double sin(double x) { return ::sin(x); }
inline CUDA_CALLABLE double cos(double x) { return ::cos(x); }

inline CUDA_CALLABLE half acos(half x){ return ::acosf(min(max(float(x), -1.0f), 1.0f)); }
inline CUDA_CALLABLE half asin(half x){ return ::asinf(min(max(float(x), -1.0f), 1.0f)); }
inline CUDA_CALLABLE half atan(half x) { return ::atanf(float(x)); }
inline CUDA_CALLABLE half atan2(half y, half x) { return ::atan2f(float(y), float(x)); }
inline CUDA_CALLABLE half sin(half x) { return ::sinf(float(x)); }
inline CUDA_CALLABLE half cos(half x) { return ::cosf(float(x)); }


// sqrt: FP_CHECK builds abort on negative input (which would yield NaN).
inline CUDA_CALLABLE float sqrt(float x)
{
#if FP_CHECK
    if (x < 0.0f)
    {
        printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, x);
        assert(0);
    }
#endif
    return ::sqrtf(x);
}
inline CUDA_CALLABLE double sqrt(double x)
{
#if FP_CHECK
    if (x < 0.0)
    {
        printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, x);
        assert(0);
    }
#endif
    return ::sqrt(x);
}
inline CUDA_CALLABLE half sqrt(half x)
{
#if FP_CHECK
    if (float(x) < 0.0f)
    {
        printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, float(x));
        assert(0);
    }
#endif
    return ::sqrtf(float(x));
}

// cbrt is defined for negative input, so no FP_CHECK guard is needed.
inline CUDA_CALLABLE float cbrt(float x) { return ::cbrtf(x); }
inline CUDA_CALLABLE double cbrt(double x) { return ::cbrt(x); }
inline CUDA_CALLABLE half cbrt(half x) { return ::cbrtf(float(x)); }

inline CUDA_CALLABLE float tan(float x) { return ::tanf(x); }
inline CUDA_CALLABLE float sinh(float x) { return ::sinhf(x);}
inline CUDA_CALLABLE float cosh(float x) { return ::coshf(x);}
inline CUDA_CALLABLE float tanh(float x) { return ::tanhf(x);}
inline CUDA_CALLABLE float degrees(float x) { return x * RAD_TO_DEG;}
inline CUDA_CALLABLE float radians(float x) { return x * DEG_TO_RAD;}

inline CUDA_CALLABLE double tan(double x) { return ::tan(x); }
inline CUDA_CALLABLE double sinh(double x) { return ::sinh(x);}
inline CUDA_CALLABLE double cosh(double x) { return ::cosh(x);}
inline CUDA_CALLABLE double tanh(double x) { return ::tanh(x);}
inline CUDA_CALLABLE double degrees(double x) { return x * RAD_TO_DEG;}
inline CUDA_CALLABLE double radians(double x) { return x * DEG_TO_RAD;}

inline CUDA_CALLABLE half tan(half x) { return ::tanf(float(x)); }
inline CUDA_CALLABLE half sinh(half x) { return ::sinhf(float(x));}
inline CUDA_CALLABLE half cosh(half x) { return ::coshf(float(x));}
inline CUDA_CALLABLE half tanh(half x) { return ::tanhf(float(x));}
// half degrees/radians use the mixed half*float operator defined above
inline CUDA_CALLABLE half degrees(half x) { return x * RAD_TO_DEG;}
inline CUDA_CALLABLE half radians(half x) { return x * DEG_TO_RAD;}

// Rounding family. round() rounds halves away from zero; rint() rounds
// halves to even (current rounding mode); trunc() rounds toward zero;
// frac(x) = x - trunc(x) keeps the sign of x.
inline CUDA_CALLABLE float round(float x) { return ::roundf(x); }
inline CUDA_CALLABLE float rint(float x) { return ::rintf(x); }
inline CUDA_CALLABLE float trunc(float x) { return ::truncf(x); }
inline CUDA_CALLABLE float floor(float x) { return ::floorf(x); }
inline CUDA_CALLABLE float ceil(float x) { return ::ceilf(x); }
inline CUDA_CALLABLE float frac(float x) { return x - trunc(x); }

inline CUDA_CALLABLE double round(double x) { return ::round(x); }
inline CUDA_CALLABLE double rint(double x) { return ::rint(x); }
inline CUDA_CALLABLE double trunc(double x) { return ::trunc(x); }
inline CUDA_CALLABLE double floor(double x) { return ::floor(x); }
inline CUDA_CALLABLE double ceil(double x) { return ::ceil(x); }
inline CUDA_CALLABLE double frac(double x) { return x - trunc(x); }

inline CUDA_CALLABLE half round(half x) { return ::roundf(float(x)); }
inline CUDA_CALLABLE half rint(half x) { return ::rintf(float(x)); }
inline CUDA_CALLABLE half trunc(half x) { return ::truncf(float(x)); }
inline CUDA_CALLABLE half floor(half x) { return ::floorf(float(x)); }
inline CUDA_CALLABLE half ceil(half x) { return ::ceilf(float(x)); }
inline CUDA_CALLABLE half frac(half x) { return float(x) - trunc(float(x)); }
1154
+
1155
+ #define DECLARE_ADJOINTS(T)\
1156
+ inline CUDA_CALLABLE void adj_log(T a, T& adj_a, T adj_ret)\
1157
+ {\
1158
+ adj_a += (T(1)/a)*adj_ret;\
1159
+ DO_IF_FPCHECK(if (!isfinite(adj_a))\
1160
+ {\
1161
+ printf("%s:%d - adj_log(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
1162
+ assert(0);\
1163
+ })\
1164
+ }\
1165
+ inline CUDA_CALLABLE void adj_log2(T a, T& adj_a, T adj_ret)\
1166
+ { \
1167
+ adj_a += (T(1)/a)*(T(1)/log(T(2)))*adj_ret; \
1168
+ DO_IF_FPCHECK(if (!isfinite(adj_a))\
1169
+ {\
1170
+ printf("%s:%d - adj_log2(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
1171
+ assert(0);\
1172
+ }) \
1173
+ }\
1174
+ inline CUDA_CALLABLE void adj_log10(T a, T& adj_a, T adj_ret)\
1175
+ {\
1176
+ adj_a += (T(1)/a)*(T(1)/log(T(10)))*adj_ret; \
1177
+ DO_IF_FPCHECK(if (!isfinite(adj_a))\
1178
+ {\
1179
+ printf("%s:%d - adj_log10(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
1180
+ assert(0);\
1181
+ })\
1182
+ }\
1183
+ inline CUDA_CALLABLE void adj_exp(T a, T ret, T& adj_a, T adj_ret) { adj_a += ret*adj_ret; }\
1184
+ inline CUDA_CALLABLE void adj_pow(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
1185
+ { \
1186
+ adj_a += b*pow(a, b-T(1))*adj_ret;\
1187
+ adj_b += log(a)*ret*adj_ret;\
1188
+ DO_IF_FPCHECK(if (!isfinite(adj_a) || !isfinite(adj_b))\
1189
+ {\
1190
+ printf("%s:%d - adj_pow(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
1191
+ assert(0);\
1192
+ })\
1193
+ }\
1194
+ inline CUDA_CALLABLE void adj_leaky_min(T a, T b, T r, T& adj_a, T& adj_b, T& adj_r, T adj_ret)\
1195
+ {\
1196
+ if (a < b)\
1197
+ adj_a += adj_ret;\
1198
+ else\
1199
+ {\
1200
+ adj_a += r*adj_ret;\
1201
+ adj_b += adj_ret;\
1202
+ }\
1203
+ }\
1204
+ inline CUDA_CALLABLE void adj_leaky_max(T a, T b, T r, T& adj_a, T& adj_b, T& adj_r, T adj_ret)\
1205
+ {\
1206
+ if (a > b)\
1207
+ adj_a += adj_ret;\
1208
+ else\
1209
+ {\
1210
+ adj_a += r*adj_ret;\
1211
+ adj_b += adj_ret;\
1212
+ }\
1213
+ }\
1214
+ inline CUDA_CALLABLE void adj_acos(T x, T& adj_x, T adj_ret)\
1215
+ {\
1216
+ T d = sqrt(T(1)-x*x);\
1217
+ DO_IF_FPCHECK(adj_x -= (T(1)/d)*adj_ret;\
1218
+ if (!isfinite(d) || !isfinite(adj_x))\
1219
+ {\
1220
+ printf("%s:%d - adj_acos(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret)); \
1221
+ assert(0);\
1222
+ })\
1223
+ DO_IF_NO_FPCHECK(if (d > T(0))\
1224
+ adj_x -= (T(1)/d)*adj_ret;)\
1225
+ }\
1226
+ inline CUDA_CALLABLE void adj_asin(T x, T& adj_x, T adj_ret)\
1227
+ {\
1228
+ T d = sqrt(T(1)-x*x);\
1229
+ DO_IF_FPCHECK(adj_x += (T(1)/d)*adj_ret;\
1230
+ if (!isfinite(d) || !isfinite(adj_x))\
1231
+ {\
1232
+ printf("%s:%d - adj_asin(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret)); \
1233
+ assert(0);\
1234
+ })\
1235
+ DO_IF_NO_FPCHECK(if (d > T(0))\
1236
+ adj_x += (T(1)/d)*adj_ret;)\
1237
+ }\
1238
+ inline CUDA_CALLABLE void adj_tan(T x, T& adj_x, T adj_ret)\
1239
+ {\
1240
+ T cos_x = cos(x);\
1241
+ DO_IF_FPCHECK(adj_x += (T(1)/(cos_x*cos_x))*adj_ret;\
1242
+ if (!isfinite(adj_x) || cos_x == T(0))\
1243
+ {\
1244
+ printf("%s:%d - adj_tan(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
1245
+ assert(0);\
1246
+ })\
1247
+ DO_IF_NO_FPCHECK(if (cos_x != T(0))\
1248
+ adj_x += (T(1)/(cos_x*cos_x))*adj_ret;)\
1249
+ }\
1250
+ inline CUDA_CALLABLE void adj_atan(T x, T& adj_x, T adj_ret)\
1251
+ {\
1252
+ adj_x += adj_ret /(x*x + T(1));\
1253
+ }\
1254
+ inline CUDA_CALLABLE void adj_atan2(T y, T x, T& adj_y, T& adj_x, T adj_ret)\
1255
+ {\
1256
+ T d = x*x + y*y;\
1257
+ DO_IF_FPCHECK(adj_x -= y/d*adj_ret;\
1258
+ adj_y += x/d*adj_ret;\
1259
+ if (!isfinite(adj_x) || !isfinite(adj_y) || d == T(0))\
1260
+ {\
1261
+ printf("%s:%d - adj_atan2(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(y), float(x), float(adj_y), float(adj_x), float(adj_ret));\
1262
+ assert(0);\
1263
+ })\
1264
+ DO_IF_NO_FPCHECK(if (d > T(0))\
1265
+ {\
1266
+ adj_x -= (y/d)*adj_ret;\
1267
+ adj_y += (x/d)*adj_ret;\
1268
+ })\
1269
+ }\
1270
+ inline CUDA_CALLABLE void adj_sin(T x, T& adj_x, T adj_ret)\
1271
+ {\
1272
+ adj_x += cos(x)*adj_ret;\
1273
+ }\
1274
+ inline CUDA_CALLABLE void adj_cos(T x, T& adj_x, T adj_ret)\
1275
+ {\
1276
+ adj_x -= sin(x)*adj_ret;\
1277
+ }\
1278
+ inline CUDA_CALLABLE void adj_sinh(T x, T& adj_x, T adj_ret)\
1279
+ {\
1280
+ adj_x += cosh(x)*adj_ret;\
1281
+ }\
1282
+ inline CUDA_CALLABLE void adj_cosh(T x, T& adj_x, T adj_ret)\
1283
+ {\
1284
+ adj_x += sinh(x)*adj_ret;\
1285
+ }\
1286
+ inline CUDA_CALLABLE void adj_tanh(T x, T ret, T& adj_x, T adj_ret)\
1287
+ {\
1288
+ adj_x += (T(1) - ret*ret)*adj_ret;\
1289
+ }\
1290
+ inline CUDA_CALLABLE void adj_sqrt(T x, T ret, T& adj_x, T adj_ret)\
1291
+ {\
1292
+ adj_x += T(0.5)*(T(1)/ret)*adj_ret;\
1293
+ DO_IF_FPCHECK(if (!isfinite(adj_x))\
1294
+ {\
1295
+ printf("%s:%d - adj_sqrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
1296
+ assert(0);\
1297
+ })\
1298
+ }\
1299
+ inline CUDA_CALLABLE void adj_cbrt(T x, T ret, T& adj_x, T adj_ret)\
1300
+ {\
1301
+ adj_x += (T(1)/T(3))*(T(1)/(ret*ret))*adj_ret;\
1302
+ DO_IF_FPCHECK(if (!isfinite(adj_x))\
1303
+ {\
1304
+ printf("%s:%d - adj_cbrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
1305
+ assert(0);\
1306
+ })\
1307
+ }\
1308
+ inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
1309
+ {\
1310
+ adj_x += RAD_TO_DEG * adj_ret;\
1311
+ }\
1312
+ inline CUDA_CALLABLE void adj_radians(T x, T& adj_x, T adj_ret)\
1313
+ {\
1314
+ adj_x += DEG_TO_RAD * adj_ret;\
1315
+ }\
1316
+ inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
1317
+ inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
1318
+ inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
1319
+ inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
1320
+ inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
1321
+ inline CUDA_CALLABLE void adj_frac(T x, T& adj_x, T adj_ret){ }
1322
+
1323
// Instantiate the scalar math adjoints declared by the DECLARE_ADJOINTS
// macro above for each supported floating-point scalar type.
DECLARE_ADJOINTS(float16)
DECLARE_ADJOINTS(float32)
DECLARE_ADJOINTS(float64)
1326
+
1327
+ template <typename C, typename T>
1328
+ CUDA_CALLABLE inline T where(const C& cond, const T& a, const T& b)
1329
+ {
1330
+ // The double NOT operator !! casts to bool without compiler warnings.
1331
+ return (!!cond) ? a : b;
1332
+ }
1333
+
1334
+ template <typename C, typename TA, typename TB, typename TRet>
1335
+ CUDA_CALLABLE inline void adj_where(const C& cond, const TA& a, const TB& b, C& adj_cond, TA& adj_a, TB& adj_b, const TRet& adj_ret)
1336
+ {
1337
+ // The double NOT operator !! casts to bool without compiler warnings.
1338
+ if (!!cond)
1339
+ adj_a += adj_ret;
1340
+ else
1341
+ adj_b += adj_ret;
1342
+ }
1343
+
1344
// Forward value copy: returns the source unchanged.
template <typename T>
CUDA_CALLABLE inline T copy(const T& src)
{
    return src;
}

// Adjoint of copy(): transfers the destination's gradient to the source and
// then clears the destination's gradient, since the copy consumed it.
template <typename T>
CUDA_CALLABLE inline void adj_copy(const T& src, T& adj_src, T& adj_dest)
{
    adj_src += adj_dest;
    adj_dest = T{};
}
1356
+
1357
// In-place assignment helper used by generated kernel code.
template <typename T>
CUDA_CALLABLE inline void assign(T& dest, const T& src)
{
    dest = src;
}

// Adjoint of assign(): moves (not accumulates) the destination's gradient
// onto the source and zeroes the destination's gradient.
template <typename T>
CUDA_CALLABLE inline void adj_assign(T& dest, const T& src, T& adj_dest, T& adj_src)
{
    // this is generally a non-differentiable operation since it violates SSA,
    // except in read-modify-write statements which are reversible through backpropagation
    adj_src = adj_dest;
    adj_dest = T{};
}
1371
+
1372
+
1373
// some helpful operator overloads (just for C++ use, these are not adjointed)
// Each operator simply forwards to the corresponding named builtin
// (add/sub/bit_and/bit_or/bit_xor) defined earlier in this file.

template <typename T>
CUDA_CALLABLE inline T& operator += (T& a, const T& b) { a = add(a, b); return a; }

template <typename T>
CUDA_CALLABLE inline T& operator -= (T& a, const T& b) { a = sub(a, b); return a; }

template <typename T>
CUDA_CALLABLE inline T& operator &= (T& a, const T& b) { a = bit_and(a, b); return a; }

template <typename T>
CUDA_CALLABLE inline T& operator |= (T& a, const T& b) { a = bit_or(a, b); return a; }

template <typename T>
CUDA_CALLABLE inline T& operator ^= (T& a, const T& b) { a = bit_xor(a, b); return a; }

template <typename T>
CUDA_CALLABLE inline T operator+(const T& a, const T& b) { return add(a, b); }

template <typename T>
CUDA_CALLABLE inline T operator-(const T& a, const T& b) { return sub(a, b); }

template <typename T>
CUDA_CALLABLE inline T operator&(const T& a, const T& b) { return bit_and(a, b); }

template <typename T>
CUDA_CALLABLE inline T operator|(const T& a, const T& b) { return bit_or(a, b); }

template <typename T>
CUDA_CALLABLE inline T operator^(const T& a, const T& b) { return bit_xor(a, b); }

// unary plus: identity; its adjoint passes the gradient straight through
template <typename T>
CUDA_CALLABLE inline T pos(const T& x) { return x; }
template <typename T>
CUDA_CALLABLE inline void adj_pos(const T& x, T& adj_x, const T& adj_ret) { adj_x += T(adj_ret); }

// unary negation implemented as negative multiply, not sure the fp implications of this
// may be better as 0.0 - x?
template <typename T>
CUDA_CALLABLE inline T neg(const T& x) { return T(0.0) - x; }
template <typename T>
CUDA_CALLABLE inline void adj_neg(const T& x, T& adj_x, const T& adj_ret) { adj_x += T(-adj_ret); }

// unary boolean negation; non-differentiable, so the adjoint is a no-op
template <typename T>
CUDA_CALLABLE inline bool unot(const T& b) { return !b; }
template <typename T>
CUDA_CALLABLE inline void adj_unot(const T& b, T& adj_b, const bool& adj_ret) { }
1422
+
1423
// maximum number of launch grid dimensions; should match types.py
const int LAUNCH_MAX_DIMS = 4;

// describes the shape of a kernel launch grid
struct launch_bounds_t
{
    int shape[LAUNCH_MAX_DIMS]; // size of each dimension
    int ndim;                   // number of valid dimension
    size_t size;                // total number of threads
};

// represents coordinate in the launch grid (unused trailing dims are zero)
struct launch_coord_t
{
    int i;
    int j;
    int k;
    int l;
};
1440
+
1441
+ // unravels a linear thread index to the corresponding launch grid coord (up to 4d)
1442
+ inline CUDA_CALLABLE launch_coord_t launch_coord(size_t linear, const launch_bounds_t& bounds)
1443
+ {
1444
+ launch_coord_t coord = {0, 0, 0, 0};
1445
+
1446
+ if (bounds.ndim > 3)
1447
+ {
1448
+ coord.l = linear%bounds.shape[3];
1449
+ linear /= bounds.shape[3];
1450
+ }
1451
+
1452
+ if (bounds.ndim > 2)
1453
+ {
1454
+ coord.k = linear%bounds.shape[2];
1455
+ linear /= bounds.shape[2];
1456
+ }
1457
+
1458
+ if (bounds.ndim > 1)
1459
+ {
1460
+ coord.j = linear%bounds.shape[1];
1461
+ linear /= bounds.shape[1];
1462
+ }
1463
+
1464
+ if (bounds.ndim > 0)
1465
+ {
1466
+ coord.i = linear;
1467
+ }
1468
+
1469
+ return coord;
1470
+ }
1471
+
1472
// Returns the number of threads per block when compiled for device,
// and 1 on the host (CPU execution is single-threaded per launch slot).
inline CUDA_CALLABLE int block_dim()
{
#if defined(__CUDA_ARCH__)
    return blockDim.x;
#else
    return 1;
#endif
}
1480
+
1481
// Returns the outermost (1-D) launch coordinate for the given linear thread
// index, truncated to int.
inline CUDA_CALLABLE int tid(size_t index, const launch_bounds_t& bounds)
{
    // For the 1-D tid() we need to warn the user if we're about to provide a truncated index
    // Only do this in _DEBUG when called from device to avoid excessive register allocation
#if defined(_DEBUG) || !defined(__CUDA_ARCH__)
    if (index > 2147483647) {
        printf("Warp warning: tid() is returning an overflowed int\n");
    }
#endif

    launch_coord_t c = launch_coord(index, bounds);
    return static_cast<int>(c.i);
}
1494
+
1495
// 2-, 3- and 4-dimensional tid() variants: unravel the linear thread index
// into per-dimension coordinates via launch_coord().
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& bounds)
{
    launch_coord_t c = launch_coord(index, bounds);
    i = c.i;
    j = c.j;
}

inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& bounds)
{
    launch_coord_t c = launch_coord(index, bounds);
    i = c.i;
    j = c.j;
    k = c.k;
}

inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& bounds)
{
    launch_coord_t c = launch_coord(index, bounds);
    i = c.i;
    j = c.j;
    k = c.k;
    l = c.l;
}
1518
+
1519
// should match types.py
// Sentinel values marking an "open" slice end (Python's `a[:]`).
static const int SLICE_BEGIN = (1LL << (sizeof(int) * 8 - 1)) - 1; // std::numeric_limits<int>::max()
static const int SLICE_END = -(1LL << (sizeof(int) * 8 - 1)); // std::numeric_limits<int>::min()

// Python-style slice (start/stop/step); default-constructs to the full range
// with step 1.
struct slice_t
{
    int start;
    int stop;
    int step;

    CUDA_CALLABLE inline slice_t()
        : start(SLICE_BEGIN), stop(SLICE_END), step(1)
    {}

    CUDA_CALLABLE inline slice_t(int start, int stop, int step)
        : start(start), stop(stop), step(step)
    {}
};
1537
+
1538
+ CUDA_CALLABLE inline slice_t slice_adjust_indices(const slice_t& slice, int length)
1539
+ {
1540
+ #ifndef NDEBUG
1541
+ if (slice.step == 0)
1542
+ {
1543
+ printf("%s:%d slice step cannot be 0\n", __FILE__, __LINE__);
1544
+ assert(0);
1545
+ }
1546
+ #endif
1547
+
1548
+ int start, stop;
1549
+
1550
+ if (slice.start == SLICE_BEGIN)
1551
+ {
1552
+ start = slice.step < 0 ? length - 1 : 0;
1553
+ }
1554
+ else
1555
+ {
1556
+ start = min(max(slice.start, -length), length);
1557
+ start = start < 0 ? start + length : start;
1558
+ }
1559
+
1560
+ if (slice.stop == SLICE_END)
1561
+ {
1562
+ stop = slice.step < 0 ? -1 : length;
1563
+ }
1564
+ else
1565
+ {
1566
+ stop = min(max(slice.stop, -length), length);
1567
+ stop = stop < 0 ? stop + length : stop;
1568
+ }
1569
+
1570
+ return {start, stop, slice.step};
1571
+ }
1572
+
1573
+ CUDA_CALLABLE inline int slice_get_length(const slice_t& slice)
1574
+ {
1575
+ #ifndef NDEBUG
1576
+ if (slice.step == 0)
1577
+ {
1578
+ printf("%s:%d slice step cannot be 0\n", __FILE__, __LINE__);
1579
+ assert(0);
1580
+ }
1581
+ #endif
1582
+
1583
+ if (slice.step > 0 && slice.start < slice.stop)
1584
+ {
1585
+ return 1 + (slice.stop - slice.start - 1) / slice.step;
1586
+ }
1587
+
1588
+ if (slice.step < 0 && slice.start > slice.stop)
1589
+ {
1590
+ return 1 + (slice.start - slice.stop - 1) / (-slice.step);
1591
+ }
1592
+
1593
+ return 0;
1594
+ }
1595
+
1596
// Atomically adds `value` to `buf[0]` and returns the previous value.
// NOTE(review): the host (!__CUDA_ARCH__) path is a plain read-modify-write,
// so it is only safe when a single CPU thread touches the address — confirm
// the host launcher guarantees this.
template<typename T>
inline CUDA_CALLABLE T atomic_add(T* buf, T value)
{
#if !defined(__CUDA_ARCH__)
    T old = buf[0];
    buf[0] += value;
    return old;
#else
    return atomicAdd(buf, value);
#endif
}

// int64 specialization: CUDA exposes atomicAdd() for unsigned long long, so
// the buffer and value are reinterpreted; two's-complement addition gives the
// same bit result for signed values.
template <>
inline CUDA_CALLABLE int64 atomic_add(int64* buf, int64 value)
{
#if !defined(__CUDA_ARCH__)
    int64 old = buf[0];
    buf[0] += value;
    return old;
#else // CUDA compiled by NVRTC
    unsigned long long int *buf_as_ull = (unsigned long long int*)buf;
    unsigned long long int unsigned_value = static_cast<unsigned long long int>(value);
    unsigned long long int result = atomicAdd(buf_as_ull, unsigned_value);
    return static_cast<int64>(result);
#endif
}

// float16 specialization: uses the native f16 atomic on compute capability
// >= 7.0 — via the __half intrinsic under Clang, or inline PTX under NVRTC.
template<>
inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
{
#if !defined(__CUDA_ARCH__)
    float16 old = buf[0];
    buf[0] += value;
    return old;
#else // CUDA compiled by NVRTC
#if __CUDA_ARCH__ >= 700
#if defined(__clang__) // CUDA compiled by Clang
    __half r = atomicAdd(reinterpret_cast<__half*>(buf), *reinterpret_cast<__half*>(&value));
    return *reinterpret_cast<float16*>(&r);
#else // CUDA compiled by NVRTC
    /* Define __PTR for atomicAdd prototypes below, undef after done */
#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
#define __PTR "l"
#else
#define __PTR "r"
#endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/

    half r = 0.0;

    asm volatile ("{ atom.add.noftz.f16 %0,[%1],%2; }\n"
        : "=h"(r.u)
        : __PTR(buf), "h"(value.u)
        : "memory");

    return r;

#undef __PTR
#endif
#else
    // No native __half atomic support on compute capability < 7.0
    // NOTE(review): on < 7.0 this silently performs no addition and returns 0.
    return float16(0.0f);
#endif
#endif
}

// float64 specialization: Clang exposes atomicAdd(double*) directly; under
// NVRTC the add is issued as inline PTX, available on compute capability >= 6.0.
template<>
inline CUDA_CALLABLE float64 atomic_add(float64* buf, float64 value)
{
#if !defined(__CUDA_ARCH__)
    float64 old = buf[0];
    buf[0] += value;
    return old;
#elif defined(__clang__) // CUDA compiled by Clang
    return atomicAdd(buf, value);
#else // CUDA compiled by NVRTC

    /* Define __PTR for atomicAdd prototypes below, undef after done */
#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
#define __PTR "l"
#else
#define __PTR "r"
#endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/

    double r = 0.0;

#if __CUDA_ARCH__ >= 600

    asm volatile ("{ atom.add.f64 %0,[%1],%2; }\n"
        : "=d"(r)
        : __PTR(buf), "d"(value)
        : "memory");
#endif

    // NOTE(review): on < 6.0 the asm block is compiled out, so no addition is
    // performed and 0.0 is returned.
    return r;

#undef __PTR

#endif // CUDA compiled by NVRTC
}
1695
+
1696
// Atomically stores min(*address, val) and returns the previous value.
// Generic version relies on CUDA's integer atomicMin(); the host path is a
// plain read-modify-write (single-threaded host execution assumed).
template <typename T>
inline CUDA_CALLABLE T atomic_min(T* address, T val)
{
#if defined(__CUDA_ARCH__)
    return atomicMin(address, val);

#else
    T old = *address;
    *address = min(old, val);
    return old;
#endif
}

// emulate atomic float min with atomicCAS()
// NOTE(review): the loop condition `val < __int_as_float(old)` never fires
// for NaN inputs, so NaNs are effectively ignored — confirm this matches the
// intended min semantics.
template <>
inline CUDA_CALLABLE float atomic_min(float* address, float val)
{
#if defined(__CUDA_ARCH__)
    int *address_as_int = (int*)address;
    int old = *address_as_int, assumed;

    while (val < __int_as_float(old))
    {
        assumed = old;
        old = atomicCAS(address_as_int, assumed,
            __float_as_int(val));
    }

    return __int_as_float(old);

#else
    float old = *address;
    *address = min(old, val);
    return old;
#endif
}

// emulate atomic double min with atomicCAS()
template <>
inline CUDA_CALLABLE double atomic_min(double* address, double val)
{
#if defined(__CUDA_ARCH__)
    unsigned long long int *address_as_ull = (unsigned long long int*)address;
    unsigned long long int old = *address_as_ull, assumed;

    while (val < __longlong_as_double(old))
    {
        assumed = old;
        old = atomicCAS(address_as_ull, assumed,
            __double_as_longlong(val));
    }

    return __longlong_as_double(old);

#else
    double old = *address;
    *address = min(old, val);
    return old;
#endif
}

// Atomically stores max(*address, val) and returns the previous value.
template <typename T>
inline CUDA_CALLABLE T atomic_max(T* address, T val)
{
#if defined(__CUDA_ARCH__)
    return atomicMax(address, val);

#else
    T old = *address;
    *address = max(old, val);
    return old;
#endif
}

// emulate atomic float max with atomicCAS()
template<>
inline CUDA_CALLABLE float atomic_max(float* address, float val)
{
#if defined(__CUDA_ARCH__)
    int *address_as_int = (int*)address;
    int old = *address_as_int, assumed;

    while (val > __int_as_float(old))
    {
        assumed = old;
        old = atomicCAS(address_as_int, assumed,
            __float_as_int(val));
    }

    return __int_as_float(old);

#else
    float old = *address;
    *address = max(old, val);
    return old;
#endif
}

// emulate atomic double max with atomicCAS()
template<>
inline CUDA_CALLABLE double atomic_max(double* address, double val)
{
#if defined(__CUDA_ARCH__)
    unsigned long long int *address_as_ull = (unsigned long long int*)address;
    unsigned long long int old = *address_as_ull, assumed;

    while (val > __longlong_as_double(old))
    {
        assumed = old;
        old = atomicCAS(address_as_ull, assumed,
            __double_as_longlong(val));
    }

    return __longlong_as_double(old);

#else
    double old = *address;
    *address = max(old, val);
    return old;
#endif
}
1817
+
1818
// default behavior for adjoint of atomic min/max operation that accumulates gradients for all elements matching the min/max value
template <typename T>
CUDA_CALLABLE inline void adj_atomic_minmax(T *addr, T *adj_addr, const T &value, T &adj_value)
{
    if (value == *addr)
        adj_value += *adj_addr;
}

// for integral types we do not accumulate gradients
CUDA_CALLABLE inline void adj_atomic_minmax(int8* buf, int8* adj_buf, const int8 &value, int8 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(uint8* buf, uint8* adj_buf, const uint8 &value, uint8 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(int16* buf, int16* adj_buf, const int16 &value, int16 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(uint16* buf, uint16* adj_buf, const uint16 &value, uint16 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(int32* buf, int32* adj_buf, const int32 &value, int32 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(uint32* buf, uint32* adj_buf, const uint32 &value, uint32 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(int64* buf, int64* adj_buf, const int64 &value, int64 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(uint64* buf, uint64* adj_buf, const uint64 &value, uint64 &adj_value) { }
CUDA_CALLABLE inline void adj_atomic_minmax(bool* buf, bool* adj_buf, const bool &value, bool &adj_value) { }
1836
+
1837
+
1838
// Atomic compare-and-swap: if *address == compare, stores val; always returns
// the previous value. Host path is a non-atomic read-modify-write.
template<typename T>
inline CUDA_CALLABLE T atomic_cas(T* address, T compare, T val)
{
#if defined(__CUDA_ARCH__)
    return atomicCAS(address, compare, val);
#else
    T old = *address;
    if (old == compare)
    {
        *address = val;
    }
    return old;
#endif
}

// float specialization: CUDA has no float atomicCAS, so the 32-bit pattern is
// reinterpreted as unsigned int.
// NOTE(review): the device path compares raw bit patterns while the host path
// uses floating-point == (differs for NaN and -0.0) — confirm intended.
template<>
inline CUDA_CALLABLE float atomic_cas(float* address, float compare, float val)
{
#if defined(__CUDA_ARCH__)
    auto result = atomicCAS(reinterpret_cast<unsigned int*>(address),
                            reinterpret_cast<unsigned int&>(compare),
                            reinterpret_cast<unsigned int&>(val));
    return reinterpret_cast<float&>(result);
#else
    float old = *address;
    if (old == compare)
    {
        *address = val;
    }
    return old;
#endif
}

// double specialization: reinterprets the 64-bit pattern as unsigned long long.
template<>
inline CUDA_CALLABLE double atomic_cas(double* address, double compare, double val)
{
#if defined(__CUDA_ARCH__)
    auto result = atomicCAS(reinterpret_cast<unsigned long long int *>(address),
                            reinterpret_cast<unsigned long long int &>(compare),
                            reinterpret_cast<unsigned long long int &>(val));
    return reinterpret_cast<double&>(result);
#else
    double old = *address;
    if (old == compare)
    {
        *address = val;
    }
    return old;
#endif
}

// int64 specialization: CUDA only provides atomicCAS for unsigned long long.
template<>
inline CUDA_CALLABLE int64 atomic_cas(int64* address, int64 compare, int64 val)
{
#if defined(__CUDA_ARCH__)
    auto result = atomicCAS(reinterpret_cast<unsigned long long int *>(address),
                            reinterpret_cast<unsigned long long int &>(compare),
                            reinterpret_cast<unsigned long long int &>(val));
    return reinterpret_cast<int64&>(result);
#else
    int64 old = *address;
    if (old == compare)
    {
        *address = val;
    }
    return old;
#endif
}

// Atomic exchange: stores val and returns the previous value.
template<typename T>
inline CUDA_CALLABLE T atomic_exch(T* address, T val)
{
#if defined(__CUDA_ARCH__)
    return atomicExch(address, val);
#else
    T old = *address;
    *address = val;
    return old;
#endif
}

// double specialization via 64-bit reinterpretation.
template<>
inline CUDA_CALLABLE double atomic_exch(double* address, double val)
{
#if defined(__CUDA_ARCH__)
    auto result = atomicExch(reinterpret_cast<unsigned long long int*>(address),
                             reinterpret_cast<unsigned long long int&>(val));
    return reinterpret_cast<double&>(result);
#else
    double old = *address;
    *address = val;
    return old;
#endif
}

// int64 specialization via 64-bit reinterpretation.
template<>
inline CUDA_CALLABLE int64 atomic_exch(int64* address, int64 val)
{
#if defined(__CUDA_ARCH__)
    auto result = atomicExch(reinterpret_cast<unsigned long long int*>(address),
                             reinterpret_cast<unsigned long long int&>(val));
    return reinterpret_cast<int64&>(result);
#else
    int64 old = *address;
    *address = val;
    return old;
#endif
}


// Adjoints for CAS/exchange are intentionally empty: these operations are
// not differentiated.
template<typename T>
CUDA_CALLABLE inline void adj_atomic_cas(T* address, T compare, T val, T* adj_address, T& adj_compare, T& adj_val, T adj_ret)
{
    // Not implemented
}

template<typename T>
CUDA_CALLABLE inline void adj_atomic_exch(T* address, T val, T* adj_address, T& adj_val, T adj_ret)
{
    // Not implemented
}
1959
+
1960
+
1961
// Atomic bitwise AND: returns the previous value of buf[0].
// Host paths are plain read-modify-writes (single-threaded host assumed).
template<typename T>
inline CUDA_CALLABLE T atomic_and(T* buf, T value)
{
#if defined(__CUDA_ARCH__)
    return atomicAnd(buf, value);
#else
    T old = buf[0];
    buf[0] &= value;
    return old;
#endif
}

// Atomic bitwise OR: returns the previous value of buf[0].
template<typename T>
inline CUDA_CALLABLE T atomic_or(T* buf, T value)
{
#if defined(__CUDA_ARCH__)
    return atomicOr(buf, value);
#else
    T old = buf[0];
    buf[0] |= value;
    return old;
#endif
}

// Atomic bitwise XOR: returns the previous value of buf[0].
template<typename T>
inline CUDA_CALLABLE T atomic_xor(T* buf, T value)
{
#if defined(__CUDA_ARCH__)
    return atomicXor(buf, value);
#else
    T old = buf[0];
    buf[0] ^= value;
    return old;
#endif
}


// for bitwise operations we do not accumulate gradients
template<typename T>
CUDA_CALLABLE inline void adj_atomic_and(T* buf, T* adj_buf, T &value, T &adj_value) { }
template<typename T>
CUDA_CALLABLE inline void adj_atomic_or(T* buf, T* adj_buf, T &value, T &adj_value) { }
template<typename T>
CUDA_CALLABLE inline void adj_atomic_xor(T* buf, T* adj_buf, T &value, T &adj_value) { }
2005
+
2006
+
2007
+ } // namespace wp
2008
+
2009
+
2010
// bool and printf are defined outside of the wp namespace in crt.h, hence
// their adjoint counterparts are also defined in the global namespace.
// Neither casting to bool nor printing carries gradient, so both are no-ops.
template <typename T>
CUDA_CALLABLE inline void adj_bool(T, T&, bool) {}
inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
2015
+
2016
+
2017
+ #include "vec.h"
2018
+ #include "mat.h"
2019
+ #include "quat.h"
2020
+ #include "spatial.h"
2021
+ #include "intersect.h"
2022
+ #include "intersect_adj.h"
2023
+
2024
+ //--------------
2025
+ namespace wp
2026
+ {
2027
+
2028
+
2029
// dot for scalar types just to make some templates compile for scalar/vector
// For scalars, dot and tensordot both reduce to plain multiplication.
inline CUDA_CALLABLE float dot(float a, float b) { return mul(a, b); }
inline CUDA_CALLABLE void adj_dot(float a, float b, float& adj_a, float& adj_b, float adj_ret) { adj_mul(a, b, adj_a, adj_b, adj_ret); }
inline CUDA_CALLABLE float tensordot(float a, float b) { return mul(a, b); }
2033
+
2034
+
2035
// Declares smoothstep()/lerp() and their adjoints for a scalar type T.
// smoothstep(): Hermite interpolation of x between edge0 and edge1, clamped
// to [0, 1]; its adjoint returns early (zero gradient) when x is outside the
// open interval (edge0, edge1), i.e. on the clamped plateaus.
// lerp(): linear interpolation a*(1-t) + b*t.
// (Comments cannot appear inside the macro body: a // comment would swallow
// the line-continuation backslash.)
#define DECLARE_INTERP_FUNCS(T) \
CUDA_CALLABLE inline T smoothstep(T edge0, T edge1, T x)\
{\
    x = clamp((x - edge0) / (edge1 - edge0), T(0), T(1));\
    return x * x * (T(3) - T(2) * x);\
}\
CUDA_CALLABLE inline void adj_smoothstep(T edge0, T edge1, T x, T& adj_edge0, T& adj_edge1, T& adj_x, T adj_ret)\
{\
    T ab = edge0 - edge1;\
    T ax = edge0 - x;\
    T bx = edge1 - x;\
    T xb = x - edge1;\
    \
    if (bx / ab >= T(0) || ax / ab <= T(0))\
    {\
        return;\
    }\
    \
    T ab3 = ab * ab * ab;\
    T ab4 = ab3 * ab;\
    adj_edge0 += adj_ret * ((T(6) * ax * bx * bx) / ab4);\
    adj_edge1 += adj_ret * ((T(6) * ax * ax * xb) / ab4);\
    adj_x += adj_ret * ((T(6) * ax * bx ) / ab3);\
}\
CUDA_CALLABLE inline T lerp(const T& a, const T& b, T t)\
{\
    return a*(T(1)-t) + b*t;\
}\
CUDA_CALLABLE inline void adj_lerp(const T& a, const T& b, T t, T& adj_a, T& adj_b, T& adj_t, const T& adj_ret)\
{\
    adj_a += adj_ret*(T(1)-t);\
    adj_b += adj_ret*t;\
    adj_t += b*adj_ret - a*adj_ret;\
}
2069
+
2070
// Instantiate the interpolation functions for each floating-point scalar type.
DECLARE_INTERP_FUNCS(float16)
DECLARE_INTERP_FUNCS(float32)
DECLARE_INTERP_FUNCS(float64)
2073
+
2074
// print() overloads for Warp's built-in types; each writes a single line to
// stdout. Scalar overloads cover every C integer width so generated code can
// call print() on any of Warp's typedefs.
inline CUDA_CALLABLE void print(const str s)
{
    printf("%s\n", s);
}

inline CUDA_CALLABLE void print(signed char i)
{
    printf("%d\n", i);
}

inline CUDA_CALLABLE void print(short i)
{
    printf("%d\n", i);
}

inline CUDA_CALLABLE void print(int i)
{
    printf("%d\n", i);
}

inline CUDA_CALLABLE void print(long i)
{
    printf("%ld\n", i);
}

inline CUDA_CALLABLE void print(long long i)
{
    printf("%lld\n", i);
}

inline CUDA_CALLABLE void print(unsigned char i)
{
    printf("%u\n", i);
}

inline CUDA_CALLABLE void print(unsigned short i)
{
    printf("%u\n", i);
}

inline CUDA_CALLABLE void print(unsigned int i)
{
    printf("%u\n", i);
}

inline CUDA_CALLABLE void print(unsigned long i)
{
    printf("%lu\n", i);
}

inline CUDA_CALLABLE void print(unsigned long long i)
{
    printf("%llu\n", i);
}

// prints Python-style True/False
inline CUDA_CALLABLE void print(bool b)
{
    printf(b ? "True\n" : "False\n");
}

// vector: space-separated components on one line (values widened to float for %g)
template<unsigned Length, typename Type>
inline CUDA_CALLABLE void print(vec_t<Length, Type> v)
{
    for( unsigned i=0; i < Length; ++i )
    {
        printf("%g ", float(v[i]));
    }
    printf("\n");
}

// quaternion: x y z w on one line
template<typename Type>
inline CUDA_CALLABLE void print(quat_t<Type> i)
{
    printf("%g %g %g %g\n", float(i.x), float(i.y), float(i.z), float(i.w));
}

// matrix: one row per line
template<unsigned Rows,unsigned Cols,typename Type>
inline CUDA_CALLABLE void print(const mat_t<Rows,Cols,Type> &m)
{
    for( unsigned i=0; i< Rows; ++i )
    {
        for( unsigned j=0; j< Cols; ++j )
        {
            printf("%g ",float(m.data[i][j]));
        }
        printf("\n");
    }
}

// transform: (position) (quaternion)
template<typename Type>
inline CUDA_CALLABLE void print(transform_t<Type> t)
{
    printf("(%g %g %g) (%g %g %g %g)\n", float(t.p[0]), float(t.p[1]), float(t.p[2]), float(t.q.x), float(t.q.y), float(t.q.z), float(t.q.w));
}
2168
+
2169
// Fallback adjoint printer for types without a dedicated overload below.
template<typename T>
inline CUDA_CALLABLE void adj_print(const T& x, const T& adj_x)
{
    printf("adj: <type without print implementation>\n");
}

// note: adj_print() only prints the adjoint value, since the value itself gets printed in replay print()
inline CUDA_CALLABLE void adj_print(half x, half adj_x) { printf("adj: %g\n", half_to_float(adj_x)); }
inline CUDA_CALLABLE void adj_print(float x, float adj_x) { printf("adj: %g\n", adj_x); }
inline CUDA_CALLABLE void adj_print(double x, double adj_x) { printf("adj: %g\n", adj_x); }

inline CUDA_CALLABLE void adj_print(signed char x, signed char adj_x) { printf("adj: %d\n", adj_x); }
inline CUDA_CALLABLE void adj_print(short x, short adj_x) { printf("adj: %d\n", adj_x); }
inline CUDA_CALLABLE void adj_print(int x, int adj_x) { printf("adj: %d\n", adj_x); }
inline CUDA_CALLABLE void adj_print(long x, long adj_x) { printf("adj: %ld\n", adj_x); }
inline CUDA_CALLABLE void adj_print(long long x, long long adj_x) { printf("adj: %lld\n", adj_x); }

inline CUDA_CALLABLE void adj_print(unsigned char x, unsigned char adj_x) { printf("adj: %u\n", adj_x); }
inline CUDA_CALLABLE void adj_print(unsigned short x, unsigned short adj_x) { printf("adj: %u\n", adj_x); }
inline CUDA_CALLABLE void adj_print(unsigned x, unsigned adj_x) { printf("adj: %u\n", adj_x); }
inline CUDA_CALLABLE void adj_print(unsigned long x, unsigned long adj_x) { printf("adj: %lu\n", adj_x); }
inline CUDA_CALLABLE void adj_print(unsigned long long x, unsigned long long adj_x) { printf("adj: %llu\n", adj_x); }

inline CUDA_CALLABLE void adj_print(bool x, bool adj_x) { printf("adj: %s\n", (adj_x ? "True" : "False")); }

// vector adjoint: space-separated components after the "adj:" prefix
template<unsigned Length, typename Type>
inline CUDA_CALLABLE void adj_print(const vec_t<Length, Type>& v, const vec_t<Length, Type>& adj_v)
{
    printf("adj:");
    for (unsigned i = 0; i < Length; i++)
        printf(" %g", float(adj_v[i]));
    printf("\n");
}

// matrix adjoint: first row prefixed with "adj:", subsequent rows aligned
template<unsigned Rows, unsigned Cols, typename Type>
inline CUDA_CALLABLE void adj_print(const mat_t<Rows, Cols, Type>& m, const mat_t<Rows, Cols, Type>& adj_m)
{
    for (unsigned i = 0; i < Rows; i++)
    {
        if (i == 0)
            printf("adj:");
        else
            printf("    ");
        for (unsigned j = 0; j < Cols; j++)
            printf(" %g", float(adj_m.data[i][j]));
        printf("\n");
    }
}

template<typename Type>
inline CUDA_CALLABLE void adj_print(const quat_t<Type>& q, const quat_t<Type>& adj_q)
{
    printf("adj: %g %g %g %g\n", float(adj_q.x), float(adj_q.y), float(adj_q.z), float(adj_q.w));
}

template<typename Type>
inline CUDA_CALLABLE void adj_print(const transform_t<Type>& t, const transform_t<Type>& adj_t)
{
    printf("adj: (%g %g %g) (%g %g %g %g)\n",
           float(adj_t.p[0]), float(adj_t.p[1]), float(adj_t.p[2]),
           float(adj_t.q.x), float(adj_t.q.y), float(adj_t.q.z), float(adj_t.q.w));
}

// NOTE(review): this prints the string value `t`, not `adj_t` — presumably
// intentional since strings carry no numeric adjoint; confirm.
inline CUDA_CALLABLE void adj_print(str t, str& adj_t)
{
    printf("adj: %s\n", t);
}
2236
+
2237
// Test helper: prints an error report when `actual` does not compare equal to
// `expected`. Does not abort execution.
template <typename T>
inline CUDA_CALLABLE void expect_eq(const T& actual, const T& expected)
{
    if (!(actual == expected))
    {
        printf("Error, expect_eq() failed:\n");
        printf("\t Expected: "); print(expected);
        printf("\t Actual: "); print(actual);
    }
}

// expect_*() helpers carry no gradient; adjoints are no-ops.
template <typename T>
inline CUDA_CALLABLE void adj_expect_eq(const T& a, const T& b, T& adj_a, T& adj_b)
{
    // nop
}

// Test helper: prints an error report when `actual` compares equal to
// `expected` (the values were expected to differ).
template <typename T>
inline CUDA_CALLABLE void expect_neq(const T& actual, const T& expected)
{
    if (actual == expected)
    {
        printf("Error, expect_neq() failed:\n");
        printf("\t Expected: "); print(expected);
        printf("\t Actual: "); print(actual);
    }
}

template <typename T>
inline CUDA_CALLABLE void adj_expect_neq(const T& a, const T& b, T& adj_a, T& adj_b)
{
    // nop
}

// Test helper: reports when |actual - expected| exceeds the tolerance.
template <typename T>
inline CUDA_CALLABLE void expect_near(const T& actual, const T& expected, const T& tolerance)
{
    if (abs(actual - expected) > tolerance)
    {
        printf("Error, expect_near() failed with tolerance "); print(tolerance);
        printf("    Expected: "); print(expected);
        printf("    Actual: "); print(actual);
        printf("    Absolute difference: "); print(abs(actual - expected));
    }
}

// vec3 overload: compares the maximum per-component absolute difference
// against a scalar tolerance.
inline CUDA_CALLABLE void expect_near(const vec3& actual, const vec3& expected, const float& tolerance)
{
    const float diff = max(max(abs(actual[0] - expected[0]), abs(actual[1] - expected[1])), abs(actual[2] - expected[2]));
    if (diff > tolerance)
    {
        printf("Error, expect_near() failed with tolerance "); print(tolerance);
        printf("    Expected: "); print(expected);
        printf("    Actual: "); print(actual);
        printf("    Max absolute difference: "); print(diff);
    }
}

template <typename T>
inline CUDA_CALLABLE void adj_expect_near(const T& actual, const T& expected, const T& tolerance, T& adj_actual, T& adj_expected, T& adj_tolerance)
{
    // nop
}

inline CUDA_CALLABLE void adj_expect_near(const vec3& actual, const vec3& expected, float tolerance, vec3& adj_actual, vec3& adj_expected, float adj_tolerance)
{
    // nop
}
2305
+
2306
+
2307
+ } // namespace wp
2308
+
2309
+ // include array.h so we have the print, isfinite functions for the inner array types defined
2310
+ #include "array.h"
2311
+ #include "tuple.h"
2312
+ #include "mesh.h"
2313
+ #include "bvh.h"
2314
+ #include "svd.h"
2315
+ #include "hashgrid.h"
2316
+ #include "volume.h"
2317
+ #include "range.h"
2318
+ #include "rand.h"
2319
+ #include "noise.h"
2320
+ #include "matnn.h"
2321
+
2322
+ #if !defined(WP_ENABLE_CUDA) // only include in kernels for now
2323
+ #include "tile.h"
2324
+ #include "tile_reduce.h"
2325
+ #include "tile_scan.h"
2326
+ #include "tile_radix_sort.h"
2327
+ #endif //!defined(WP_ENABLE_CUDA)