warp_lang-1.7.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry listing for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/native/builtin.h ADDED
@@ -0,0 +1,1800 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #pragma once
19
+
20
+ // All built-in types and functions. To be compatible with runtime NVRTC compilation
21
+ // this header must be independently compilable (i.e.: without external SDK headers)
22
+ // to achieve this we redefine a subset of CRT functions (printf, pow, sin, cos, etc)
23
+
24
+ #include "crt.h"
25
+
26
+ #ifdef _WIN32
27
+ #define __restrict__ __restrict
28
+ #endif
29
+
30
+ #if !defined(__CUDACC__)
31
+ #define CUDA_CALLABLE
32
+ #define CUDA_CALLABLE_DEVICE
33
+ #else
34
+ #define CUDA_CALLABLE __host__ __device__
35
+ #define CUDA_CALLABLE_DEVICE __device__
36
+ #endif
37
+
38
+ #ifdef WP_VERIFY_FP
39
+ #define FP_CHECK 1
40
+ #define DO_IF_FPCHECK(X) {X}
41
+ #define DO_IF_NO_FPCHECK(X)
42
+ #else
43
+ #define FP_CHECK 0
44
+ #define DO_IF_FPCHECK(X)
45
+ #define DO_IF_NO_FPCHECK(X) {X}
46
+ #endif
47
+
48
+ #define RAD_TO_DEG 57.29577951308232087679
49
+ #define DEG_TO_RAD 0.01745329251994329577
50
+
51
// Provide a no-op __debugbreak() on CUDA device builds, where the MSVC
// intrinsic of the same name does not exist (kernels may still reference it).
#if defined(__CUDACC__) && !defined(_MSC_VER)
__device__ void __debugbreak() {}
#endif
54
+
55
+ namespace wp
56
+ {
57
+
58
+ // numeric types (used from generated kernels)
59
+ typedef float float32;
60
+ typedef double float64;
61
+
62
+ typedef int8_t int8;
63
+ typedef uint8_t uint8;
64
+
65
+ typedef int16_t int16;
66
+ typedef uint16_t uint16;
67
+
68
+ typedef int32_t int32;
69
+ typedef uint32_t uint32;
70
+
71
+ typedef int64_t int64;
72
+ typedef uint64_t uint64;
73
+
74
+
75
+ // matches Python string type for constant strings
76
+ typedef const char* str;
77
+
78
+
79
+
80
+ struct half;
81
+
82
+ CUDA_CALLABLE half float_to_half(float x);
83
+ CUDA_CALLABLE float half_to_float(half x);
84
+
85
+ struct half
86
+ {
87
+ CUDA_CALLABLE inline half() : u(0) {}
88
+
89
+ CUDA_CALLABLE inline half(float f)
90
+ {
91
+ *this = float_to_half(f);
92
+ }
93
+
94
+ unsigned short u;
95
+
96
+ CUDA_CALLABLE inline bool operator==(const half& h) const
97
+ {
98
+ // Use float32 to get IEEE 754 behavior in case of a NaN
99
+ return float32(h) == float32(*this);
100
+ }
101
+
102
+ CUDA_CALLABLE inline bool operator!=(const half& h) const
103
+ {
104
+ // Use float32 to get IEEE 754 behavior in case of a NaN
105
+ return float32(h) != float32(*this);
106
+ }
107
+ CUDA_CALLABLE inline bool operator>(const half& h) const { return half_to_float(*this) > half_to_float(h); }
108
+ CUDA_CALLABLE inline bool operator>=(const half& h) const { return half_to_float(*this) >= half_to_float(h); }
109
+ CUDA_CALLABLE inline bool operator<(const half& h) const { return half_to_float(*this) < half_to_float(h); }
110
+ CUDA_CALLABLE inline bool operator<=(const half& h) const { return half_to_float(*this) <= half_to_float(h); }
111
+
112
+ CUDA_CALLABLE inline bool operator!() const
113
+ {
114
+ return float32(*this) == 0;
115
+ }
116
+
117
+ CUDA_CALLABLE inline half operator*=(const half& h)
118
+ {
119
+ half prod = half(float32(*this) * float32(h));
120
+ this->u = prod.u;
121
+ return *this;
122
+ }
123
+
124
+ CUDA_CALLABLE inline half operator/=(const half& h)
125
+ {
126
+ half quot = half(float32(*this) / float32(h));
127
+ this->u = quot.u;
128
+ return *this;
129
+ }
130
+
131
+ CUDA_CALLABLE inline half operator+=(const half& h)
132
+ {
133
+ half sum = half(float32(*this) + float32(h));
134
+ this->u = sum.u;
135
+ return *this;
136
+ }
137
+
138
+ CUDA_CALLABLE inline half operator-=(const half& h)
139
+ {
140
+ half diff = half(float32(*this) - float32(h));
141
+ this->u = diff.u;
142
+ return *this;
143
+ }
144
+
145
+ CUDA_CALLABLE inline operator float32() const { return float32(half_to_float(*this)); }
146
+ CUDA_CALLABLE inline operator float64() const { return float64(half_to_float(*this)); }
147
+ CUDA_CALLABLE inline operator int8() const { return int8(half_to_float(*this)); }
148
+ CUDA_CALLABLE inline operator uint8() const { return uint8(half_to_float(*this)); }
149
+ CUDA_CALLABLE inline operator int16() const { return int16(half_to_float(*this)); }
150
+ CUDA_CALLABLE inline operator uint16() const { return uint16(half_to_float(*this)); }
151
+ CUDA_CALLABLE inline operator int32() const { return int32(half_to_float(*this)); }
152
+ CUDA_CALLABLE inline operator uint32() const { return uint32(half_to_float(*this)); }
153
+ CUDA_CALLABLE inline operator int64() const { return int64(half_to_float(*this)); }
154
+ CUDA_CALLABLE inline operator uint64() const { return uint64(half_to_float(*this)); }
155
+ };
156
+
157
+ static_assert(sizeof(half) == 2, "Size of half / float16 type must be 2-bytes");
158
+
159
+ typedef half float16;
160
+
161
+ #if defined(__CUDA_ARCH__)
162
+
163
+ CUDA_CALLABLE inline half float_to_half(float x)
164
+ {
165
+ half h;
166
+ asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(h.u) : "f"(x));
167
+ return h;
168
+ }
169
+
170
+ CUDA_CALLABLE inline float half_to_float(half x)
171
+ {
172
+ float val;
173
+ asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(x.u));
174
+ return val;
175
+ }
176
+
177
+ #elif defined(__clang__)
178
+
179
+ // _Float16 is Clang's native half-precision floating-point type
180
+ inline half float_to_half(float x)
181
+ {
182
+
183
+ _Float16 f16 = static_cast<_Float16>(x);
184
+ return *reinterpret_cast<half*>(&f16);
185
+ }
186
+
187
+ inline float half_to_float(half h)
188
+ {
189
+ _Float16 f16 = *reinterpret_cast<_Float16*>(&h);
190
+ return static_cast<float>(f16);
191
+ }
192
+
193
+ #else // Native C++ for Warp builtins outside of kernels
194
+
195
+ extern "C" WP_API uint16_t float_to_half_bits(float x);
196
+ extern "C" WP_API float half_bits_to_float(uint16_t u);
197
+
198
+ inline half float_to_half(float x)
199
+ {
200
+ half h;
201
+ h.u = float_to_half_bits(x);
202
+ return h;
203
+ }
204
+
205
+ inline float half_to_float(half h)
206
+ {
207
+ return half_bits_to_float(h.u);
208
+ }
209
+
210
+ #endif
211
+
212
+
213
+ // BAD operator implementations for fp16 arithmetic...
214
+
215
+ // negation:
216
+ inline CUDA_CALLABLE half operator - (half a)
217
+ {
218
+ return float_to_half( -half_to_float(a) );
219
+ }
220
+
221
+ inline CUDA_CALLABLE half operator + (half a,half b)
222
+ {
223
+ return float_to_half( half_to_float(a) + half_to_float(b) );
224
+ }
225
+
226
+ inline CUDA_CALLABLE half operator - (half a,half b)
227
+ {
228
+ return float_to_half( half_to_float(a) - half_to_float(b) );
229
+ }
230
+
231
+ inline CUDA_CALLABLE half operator * (half a,half b)
232
+ {
233
+ return float_to_half( half_to_float(a) * half_to_float(b) );
234
+ }
235
+
236
+ inline CUDA_CALLABLE half operator * (half a,float b)
237
+ {
238
+ return float_to_half( half_to_float(a) * b );
239
+ }
240
+
241
+ inline CUDA_CALLABLE half operator * (float a,half b)
242
+ {
243
+ return float_to_half( a * half_to_float(b) );
244
+ }
245
+
246
+ inline CUDA_CALLABLE half operator * (half a,double b)
247
+ {
248
+ return float_to_half( half_to_float(a) * b );
249
+ }
250
+
251
+ inline CUDA_CALLABLE half operator * (double a,half b)
252
+ {
253
+ return float_to_half( a * half_to_float(b) );
254
+ }
255
+
256
+ inline CUDA_CALLABLE half operator / (half a,half b)
257
+ {
258
+ return float_to_half( half_to_float(a) / half_to_float(b) );
259
+ }
260
+
261
+
262
+
263
+
264
+
265
+ template <typename T>
266
+ CUDA_CALLABLE float cast_float(T x) { return (float)(x); }
267
+
268
+ template <typename T>
269
+ CUDA_CALLABLE int cast_int(T x) { return (int)(x); }
270
+
271
+ template <typename T>
272
+ CUDA_CALLABLE void adj_cast_float(T x, T& adj_x, float adj_ret) { adj_x += T(adj_ret); }
273
+
274
+ template <typename T>
275
+ CUDA_CALLABLE void adj_cast_int(T x, T& adj_x, int adj_ret) { adj_x += adj_ret; }
276
+
277
+ template <typename T>
278
+ CUDA_CALLABLE inline void adj_int8(T, T&, int8) {}
279
+ template <typename T>
280
+ CUDA_CALLABLE inline void adj_uint8(T, T&, uint8) {}
281
+ template <typename T>
282
+ CUDA_CALLABLE inline void adj_int16(T, T&, int16) {}
283
+ template <typename T>
284
+ CUDA_CALLABLE inline void adj_uint16(T, T&, uint16) {}
285
+ template <typename T>
286
+ CUDA_CALLABLE inline void adj_int32(T, T&, int32) {}
287
+ template <typename T>
288
+ CUDA_CALLABLE inline void adj_uint32(T, T&, uint32) {}
289
+ template <typename T>
290
+ CUDA_CALLABLE inline void adj_int64(T, T&, int64) {}
291
+ template <typename T>
292
+ CUDA_CALLABLE inline void adj_uint64(T, T&, uint64) {}
293
+
294
+
295
+ template <typename T>
296
+ CUDA_CALLABLE inline void adj_float16(T x, T& adj_x, float16 adj_ret) { adj_x += T(adj_ret); }
297
+ template <typename T>
298
+ CUDA_CALLABLE inline void adj_float32(T x, T& adj_x, float32 adj_ret) { adj_x += T(adj_ret); }
299
+ template <typename T>
300
+ CUDA_CALLABLE inline void adj_float64(T x, T& adj_x, float64 adj_ret) { adj_x += T(adj_ret); }
301
+
302
+
303
// Epsilon constant referenced by generated code; zero here means no tolerance.
#define kEps 0.0f

// basic ops for integer types
// Declares the full scalar builtin set for one integer type T in a single
// macro so every integer width gets an identical surface. Notes:
//  - integer ops have no useful derivative, so every adj_* body is empty;
//  - sqrt(T) is defined to return 0 for integers;
//  - isfinite/isnan/isinf convert through double, so for integer inputs
//    they report finite / not-NaN / not-inf.
// (Comments cannot appear inside the macro itself: a '\'-continued '//'
// line would splice the next line into the comment.)
#define DECLARE_INT_OPS(T) \
inline CUDA_CALLABLE T mul(T a, T b) { return a*b; } \
inline CUDA_CALLABLE T div(T a, T b) { return a/b; } \
inline CUDA_CALLABLE T add(T a, T b) { return a+b; } \
inline CUDA_CALLABLE T sub(T a, T b) { return a-b; } \
inline CUDA_CALLABLE T mod(T a, T b) { return a%b; } \
inline CUDA_CALLABLE T min(T a, T b) { return a<b?a:b; } \
inline CUDA_CALLABLE T max(T a, T b) { return a>b?a:b; } \
inline CUDA_CALLABLE T clamp(T x, T a, T b) { return min(max(a, x), b); } \
inline CUDA_CALLABLE T floordiv(T a, T b) { return a/b; } \
inline CUDA_CALLABLE T nonzero(T x) { return x == T(0) ? T(0) : T(1); } \
inline CUDA_CALLABLE T sqrt(T x) { return 0; } \
inline CUDA_CALLABLE T bit_and(T a, T b) { return a&b; } \
inline CUDA_CALLABLE T bit_or(T a, T b) { return a|b; } \
inline CUDA_CALLABLE T bit_xor(T a, T b) { return a^b; } \
inline CUDA_CALLABLE T lshift(T a, T b) { return a<<b; } \
inline CUDA_CALLABLE T rshift(T a, T b) { return a>>b; } \
inline CUDA_CALLABLE T invert(T x) { return ~x; } \
inline CUDA_CALLABLE bool isfinite(T x) { return ::isfinite(double(x)); } \
inline CUDA_CALLABLE bool isnan(T x) { return ::isnan(double(x)); } \
inline CUDA_CALLABLE bool isinf(T x) { return ::isinf(double(x)); } \
inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_min(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_max(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_abs(T x, T adj_x, T& adj_ret) { } \
inline CUDA_CALLABLE void adj_sign(T x, T adj_x, T& adj_ret) { } \
inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_floordiv(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_step(T x, T& adj_x, T adj_ret) { } \
inline CUDA_CALLABLE void adj_nonzero(T x, T& adj_x, T adj_ret) { } \
inline CUDA_CALLABLE void adj_sqrt(T x, T adj_x, T& adj_ret) { } \
inline CUDA_CALLABLE void adj_bit_and(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_bit_or(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_bit_xor(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_lshift(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_rshift(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
inline CUDA_CALLABLE void adj_invert(T x, T adj_x, T& adj_ret) { } \
inline CUDA_CALLABLE void adj_isnan(const T&, T&, bool) { } \
inline CUDA_CALLABLE void adj_isinf(const T&, T&, bool) { } \
inline CUDA_CALLABLE void adj_isfinite(const T&, T&, bool) { }
350
+
351
+ inline CUDA_CALLABLE int8 abs(int8 x) { return ::abs(x); }
352
+ inline CUDA_CALLABLE int16 abs(int16 x) { return ::abs(x); }
353
+ inline CUDA_CALLABLE int32 abs(int32 x) { return ::abs(x); }
354
+ inline CUDA_CALLABLE int64 abs(int64 x) { return ::llabs(x); }
355
+ inline CUDA_CALLABLE uint8 abs(uint8 x) { return x; }
356
+ inline CUDA_CALLABLE uint16 abs(uint16 x) { return x; }
357
+ inline CUDA_CALLABLE uint32 abs(uint32 x) { return x; }
358
+ inline CUDA_CALLABLE uint64 abs(uint64 x) { return x; }
359
+
360
+ DECLARE_INT_OPS(int8)
361
+ DECLARE_INT_OPS(int16)
362
+ DECLARE_INT_OPS(int32)
363
+ DECLARE_INT_OPS(int64)
364
+ DECLARE_INT_OPS(uint8)
365
+ DECLARE_INT_OPS(uint16)
366
+ DECLARE_INT_OPS(uint32)
367
+ DECLARE_INT_OPS(uint64)
368
+
369
+
370
+ inline CUDA_CALLABLE int8 step(int8 x) { return x < 0 ? 1 : 0; }
371
+ inline CUDA_CALLABLE int16 step(int16 x) { return x < 0 ? 1 : 0; }
372
+ inline CUDA_CALLABLE int32 step(int32 x) { return x < 0 ? 1 : 0; }
373
+ inline CUDA_CALLABLE int64 step(int64 x) { return x < 0 ? 1 : 0; }
374
+ inline CUDA_CALLABLE uint8 step(uint8 x) { return 0; }
375
+ inline CUDA_CALLABLE uint16 step(uint16 x) { return 0; }
376
+ inline CUDA_CALLABLE uint32 step(uint32 x) { return 0; }
377
+ inline CUDA_CALLABLE uint64 step(uint64 x) { return 0; }
378
+
379
+
380
+ inline CUDA_CALLABLE int8 sign(int8 x) { return x < 0 ? -1 : 1; }
381
+ inline CUDA_CALLABLE int8 sign(int16 x) { return x < 0 ? -1 : 1; }
382
+ inline CUDA_CALLABLE int8 sign(int32 x) { return x < 0 ? -1 : 1; }
383
+ inline CUDA_CALLABLE int8 sign(int64 x) { return x < 0 ? -1 : 1; }
384
+ inline CUDA_CALLABLE uint8 sign(uint8 x) { return 1; }
385
+ inline CUDA_CALLABLE uint16 sign(uint16 x) { return 1; }
386
+ inline CUDA_CALLABLE uint32 sign(uint32 x) { return 1; }
387
+ inline CUDA_CALLABLE uint64 sign(uint64 x) { return 1; }
388
+
389
+
390
+ // Catch-all for non-float, non-integer types
391
+ template<typename T>
392
+ inline bool CUDA_CALLABLE isfinite(const T&)
393
+ {
394
+ return true;
395
+ }
396
+
397
+ inline bool CUDA_CALLABLE isfinite(half x)
398
+ {
399
+ return ::isfinite(float(x));
400
+ }
401
+ inline bool CUDA_CALLABLE isfinite(float x)
402
+ {
403
+ return ::isfinite(x);
404
+ }
405
+ inline bool CUDA_CALLABLE isfinite(double x)
406
+ {
407
+ return ::isfinite(x);
408
+ }
409
+
410
+ inline bool CUDA_CALLABLE isnan(half x)
411
+ {
412
+ return ::isnan(float(x));
413
+ }
414
+ inline bool CUDA_CALLABLE isnan(float x)
415
+ {
416
+ return ::isnan(x);
417
+ }
418
+ inline bool CUDA_CALLABLE isnan(double x)
419
+ {
420
+ return ::isnan(x);
421
+ }
422
+
423
+ inline bool CUDA_CALLABLE isinf(half x)
424
+ {
425
+ return ::isinf(float(x));
426
+ }
427
+ inline bool CUDA_CALLABLE isinf(float x)
428
+ {
429
+ return ::isinf(x);
430
+ }
431
+ inline bool CUDA_CALLABLE isinf(double x)
432
+ {
433
+ return ::isinf(x);
434
+ }
435
+
436
+ template<typename T>
437
+ inline CUDA_CALLABLE void print(const T&)
438
+ {
439
+ printf("<type without print implementation>\n");
440
+ }
441
+
442
+ inline CUDA_CALLABLE void print(float16 f)
443
+ {
444
+ printf("%g\n", half_to_float(f));
445
+ }
446
+
447
+ inline CUDA_CALLABLE void print(float f)
448
+ {
449
+ printf("%g\n", f);
450
+ }
451
+
452
+ inline CUDA_CALLABLE void print(double f)
453
+ {
454
+ printf("%g\n", f);
455
+ }
456
+
457
+
458
+ // basic ops for float types
459
+ #define DECLARE_FLOAT_OPS(T) \
460
+ inline CUDA_CALLABLE T mul(T a, T b) { return a*b; } \
461
+ inline CUDA_CALLABLE T add(T a, T b) { return a+b; } \
462
+ inline CUDA_CALLABLE T sub(T a, T b) { return a-b; } \
463
+ inline CUDA_CALLABLE T min(T a, T b) { return a<b?a:b; } \
464
+ inline CUDA_CALLABLE T max(T a, T b) { return a>b?a:b; } \
465
+ inline CUDA_CALLABLE T sign(T x) { return x < T(0) ? -1 : 1; } \
466
+ inline CUDA_CALLABLE T step(T x) { return x < T(0) ? T(1) : T(0); }\
467
+ inline CUDA_CALLABLE T nonzero(T x) { return x == T(0) ? T(0) : T(1); }\
468
+ inline CUDA_CALLABLE T clamp(T x, T a, T b) { return min(max(a, x), b); }\
469
+ inline CUDA_CALLABLE void adj_abs(T x, T& adj_x, T adj_ret) \
470
+ {\
471
+ if (x < T(0))\
472
+ adj_x -= adj_ret;\
473
+ else\
474
+ adj_x += adj_ret;\
475
+ }\
476
+ inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += b*adj_ret; adj_b += a*adj_ret; } \
477
+ inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += adj_ret; adj_b += adj_ret; } \
478
+ inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { adj_a += adj_ret; adj_b -= adj_ret; } \
479
+ inline CUDA_CALLABLE void adj_min(T a, T b, T& adj_a, T& adj_b, T adj_ret) \
480
+ { \
481
+ if (a < b) \
482
+ adj_a += adj_ret; \
483
+ else \
484
+ adj_b += adj_ret; \
485
+ } \
486
+ inline CUDA_CALLABLE void adj_max(T a, T b, T& adj_a, T& adj_b, T adj_ret) \
487
+ { \
488
+ if (a > b) \
489
+ adj_a += adj_ret; \
490
+ else \
491
+ adj_b += adj_ret; \
492
+ } \
493
+ inline CUDA_CALLABLE void adj_floordiv(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
494
+ inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret){ adj_a += adj_ret; }\
495
+ inline CUDA_CALLABLE void adj_sign(T x, T adj_x, T& adj_ret) { }\
496
+ inline CUDA_CALLABLE void adj_step(T x, T& adj_x, T adj_ret) { }\
497
+ inline CUDA_CALLABLE void adj_nonzero(T x, T& adj_x, T adj_ret) { }\
498
+ inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b, T adj_ret)\
499
+ {\
500
+ if (x < a)\
501
+ adj_a += adj_ret;\
502
+ else if (x > b)\
503
+ adj_b += adj_ret;\
504
+ else\
505
+ adj_x += adj_ret;\
506
+ }\
507
+ inline CUDA_CALLABLE T div(T a, T b)\
508
+ {\
509
+ DO_IF_FPCHECK(\
510
+ if (!isfinite(a) || !isfinite(b) || b == T(0))\
511
+ {\
512
+ printf("%s:%d div(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));\
513
+ assert(0);\
514
+ })\
515
+ return a/b;\
516
+ }\
517
+ inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
518
+ {\
519
+ adj_a += adj_ret/b;\
520
+ adj_b -= adj_ret*(ret)/b;\
521
+ DO_IF_FPCHECK(\
522
+ if (!isfinite(adj_a) || !isfinite(adj_b))\
523
+ {\
524
+ printf("%s:%d - adj_div(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
525
+ assert(0);\
526
+ })\
527
+ }\
528
+ inline CUDA_CALLABLE void adj_isnan(const T&, T&, bool) { }\
529
+ inline CUDA_CALLABLE void adj_isinf(const T&, T&, bool) { }\
530
+ inline CUDA_CALLABLE void adj_isfinite(const T&, T&, bool) { }
531
+
532
+ DECLARE_FLOAT_OPS(float16)
533
+ DECLARE_FLOAT_OPS(float32)
534
+ DECLARE_FLOAT_OPS(float64)
535
+
536
+
537
+
538
+ // basic ops for float types
539
+ inline CUDA_CALLABLE float16 mod(float16 a, float16 b)
540
+ {
541
+ #if FP_CHECK
542
+ if (!isfinite(a) || !isfinite(b) || float(b) == 0.0f)
543
+ {
544
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));
545
+ assert(0);
546
+ }
547
+ #endif
548
+ return fmodf(float(a), float(b));
549
+ }
550
+
551
+ inline CUDA_CALLABLE float32 mod(float32 a, float32 b)
552
+ {
553
+ #if FP_CHECK
554
+ if (!isfinite(a) || !isfinite(b) || b == 0.0f)
555
+ {
556
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
557
+ assert(0);
558
+ }
559
+ #endif
560
+ return fmodf(a, b);
561
+ }
562
+
563
+ inline CUDA_CALLABLE double mod(double a, double b)
564
+ {
565
+ #if FP_CHECK
566
+ if (!isfinite(a) || !isfinite(b) || b == 0.0f)
567
+ {
568
+ printf("%s:%d mod(%f, %f)\n", __FILE__, __LINE__, a, b);
569
+ assert(0);
570
+ }
571
+ #endif
572
+ return fmod(a, b);
573
+ }
574
+
575
+ inline CUDA_CALLABLE half log(half a)
576
+ {
577
+ #if FP_CHECK
578
+ if (!isfinite(a) || float(a) < 0.0f)
579
+ {
580
+ printf("%s:%d log(%f)\n", __FILE__, __LINE__, float(a));
581
+ assert(0);
582
+ }
583
+ #endif
584
+ return ::logf(a);
585
+ }
586
+
587
+ inline CUDA_CALLABLE float log(float a)
588
+ {
589
+ #if FP_CHECK
590
+ if (!isfinite(a) || a < 0.0f)
591
+ {
592
+ printf("%s:%d log(%f)\n", __FILE__, __LINE__, a);
593
+ assert(0);
594
+ }
595
+ #endif
596
+ return ::logf(a);
597
+ }
598
+
599
+ inline CUDA_CALLABLE double log(double a)
600
+ {
601
+ #if FP_CHECK
602
+ if (!isfinite(a) || a < 0.0)
603
+ {
604
+ printf("%s:%d log(%f)\n", __FILE__, __LINE__, a);
605
+ assert(0);
606
+ }
607
+ #endif
608
+ return ::log(a);
609
+ }
610
+
611
+ inline CUDA_CALLABLE half log2(half a)
612
+ {
613
+ #if FP_CHECK
614
+ if (!isfinite(a) || float(a) < 0.0f)
615
+ {
616
+ printf("%s:%d log2(%f)\n", __FILE__, __LINE__, float(a));
617
+ assert(0);
618
+ }
619
+ #endif
620
+
621
+ return ::log2f(float(a));
622
+ }
623
+
624
+ inline CUDA_CALLABLE float log2(float a)
625
+ {
626
+ #if FP_CHECK
627
+ if (!isfinite(a) || a < 0.0f)
628
+ {
629
+ printf("%s:%d log2(%f)\n", __FILE__, __LINE__, a);
630
+ assert(0);
631
+ }
632
+ #endif
633
+
634
+ return ::log2f(a);
635
+ }
636
+
637
+ inline CUDA_CALLABLE double log2(double a)
638
+ {
639
+ #if FP_CHECK
640
+ if (!isfinite(a) || a < 0.0)
641
+ {
642
+ printf("%s:%d log2(%f)\n", __FILE__, __LINE__, a);
643
+ assert(0);
644
+ }
645
+ #endif
646
+
647
+ return ::log2(a);
648
+ }
649
+
650
+ inline CUDA_CALLABLE half log10(half a)
651
+ {
652
+ #if FP_CHECK
653
+ if (!isfinite(a) || float(a) < 0.0f)
654
+ {
655
+ printf("%s:%d log10(%f)\n", __FILE__, __LINE__, float(a));
656
+ assert(0);
657
+ }
658
+ #endif
659
+
660
+ return ::log10f(float(a));
661
+ }
662
+
663
+ inline CUDA_CALLABLE float log10(float a)
664
+ {
665
+ #if FP_CHECK
666
+ if (!isfinite(a) || a < 0.0f)
667
+ {
668
+ printf("%s:%d log10(%f)\n", __FILE__, __LINE__, a);
669
+ assert(0);
670
+ }
671
+ #endif
672
+
673
+ return ::log10f(a);
674
+ }
675
+
676
+ inline CUDA_CALLABLE double log10(double a)
677
+ {
678
+ #if FP_CHECK
679
+ if (!isfinite(a) || a < 0.0)
680
+ {
681
+ printf("%s:%d log10(%f)\n", __FILE__, __LINE__, a);
682
+ assert(0);
683
+ }
684
+ #endif
685
+
686
+ return ::log10(a);
687
+ }
688
+
689
+ inline CUDA_CALLABLE half exp(half a)
690
+ {
691
+ half result = ::expf(float(a));
692
+ #if FP_CHECK
693
+ if (!isfinite(a) || !isfinite(result))
694
+ {
695
+ printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, float(a), float(result));
696
+ assert(0);
697
+ }
698
+ #endif
699
+ return result;
700
+ }
701
+ inline CUDA_CALLABLE float exp(float a)
702
+ {
703
+ float result = ::expf(a);
704
+ #if FP_CHECK
705
+ if (!isfinite(a) || !isfinite(result))
706
+ {
707
+ printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, a, result);
708
+ assert(0);
709
+ }
710
+ #endif
711
+ return result;
712
+ }
713
+ inline CUDA_CALLABLE double exp(double a)
714
+ {
715
+ double result = ::exp(a);
716
+ #if FP_CHECK
717
+ if (!isfinite(a) || !isfinite(result))
718
+ {
719
+ printf("%s:%d exp(%f) = %f\n", __FILE__, __LINE__, a, result);
720
+ assert(0);
721
+ }
722
+ #endif
723
+ return result;
724
+ }
725
+
726
+ inline CUDA_CALLABLE half pow(half a, half b)
727
+ {
728
+ float result = ::powf(float(a), float(b));
729
+ #if FP_CHECK
730
+ if (!isfinite(float(a)) || !isfinite(float(b)) || !isfinite(result))
731
+ {
732
+ printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, float(a), float(b), result);
733
+ assert(0);
734
+ }
735
+ #endif
736
+ return result;
737
+ }
738
+
739
+ inline CUDA_CALLABLE float pow(float a, float b)
740
+ {
741
+ float result = ::powf(a, b);
742
+ #if FP_CHECK
743
+ if (!isfinite(a) || !isfinite(b) || !isfinite(result))
744
+ {
745
+ printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, a, b, result);
746
+ assert(0);
747
+ }
748
+ #endif
749
+ return result;
750
+ }
751
+
752
+ inline CUDA_CALLABLE double pow(double a, double b)
753
+ {
754
+ double result = ::pow(a, b);
755
+ #if FP_CHECK
756
+ if (!isfinite(a) || !isfinite(b) || !isfinite(result))
757
+ {
758
+ printf("%s:%d pow(%f, %f) = %f\n", __FILE__, __LINE__, a, b, result);
759
+ assert(0);
760
+ }
761
+ #endif
762
+ return result;
763
+ }
764
+
765
+ inline CUDA_CALLABLE half floordiv(half a, half b)
766
+ {
767
+ #if FP_CHECK
768
+ if (!isfinite(a) || !isfinite(b) || float(b) == 0.0f)
769
+ {
770
+ printf("%s:%d floordiv(%f, %f)\n", __FILE__, __LINE__, float(a), float(b));
771
+ assert(0);
772
+ }
773
+ #endif
774
+ return floorf(float(a/b));
775
+ }
776
+ inline CUDA_CALLABLE float floordiv(float a, float b)
777
+ {
778
+ #if FP_CHECK
779
+ if (!isfinite(a) || !isfinite(b) || b == 0.0f)
780
+ {
781
+ printf("%s:%d floordiv(%f, %f)\n", __FILE__, __LINE__, a, b);
782
+ assert(0);
783
+ }
784
+ #endif
785
+ return floorf(a/b);
786
+ }
787
+ inline CUDA_CALLABLE double floordiv(double a, double b)
788
+ {
789
+ #if FP_CHECK
790
+ if (!isfinite(a) || !isfinite(b) || b == 0.0)
791
+ {
792
+ printf("%s:%d floordiv(%f, %f)\n", __FILE__, __LINE__, a, b);
793
+ assert(0);
794
+ }
795
+ #endif
796
+ return ::floor(a/b);
797
+ }
798
+
799
+ inline CUDA_CALLABLE float leaky_min(float a, float b, float r) { return min(a, b); }
800
+ inline CUDA_CALLABLE float leaky_max(float a, float b, float r) { return max(a, b); }
801
+
802
+ inline CUDA_CALLABLE half abs(half x) { return ::fabsf(float(x)); }
803
+ inline CUDA_CALLABLE float abs(float x) { return ::fabsf(x); }
804
+ inline CUDA_CALLABLE double abs(double x) { return ::fabs(x); }
805
+
806
+ inline CUDA_CALLABLE float acos(float x){ return ::acosf(min(max(x, -1.0f), 1.0f)); }
807
+ inline CUDA_CALLABLE float asin(float x){ return ::asinf(min(max(x, -1.0f), 1.0f)); }
808
+ inline CUDA_CALLABLE float atan(float x) { return ::atanf(x); }
809
+ inline CUDA_CALLABLE float atan2(float y, float x) { return ::atan2f(y, x); }
810
+ inline CUDA_CALLABLE float sin(float x) { return ::sinf(x); }
811
+ inline CUDA_CALLABLE float cos(float x) { return ::cosf(x); }
812
+
813
+ inline CUDA_CALLABLE double acos(double x){ return ::acos(min(max(x, -1.0), 1.0)); }
814
+ inline CUDA_CALLABLE double asin(double x){ return ::asin(min(max(x, -1.0), 1.0)); }
815
+ inline CUDA_CALLABLE double atan(double x) { return ::atan(x); }
816
+ inline CUDA_CALLABLE double atan2(double y, double x) { return ::atan2(y, x); }
817
+ inline CUDA_CALLABLE double sin(double x) { return ::sin(x); }
818
+ inline CUDA_CALLABLE double cos(double x) { return ::cos(x); }
819
+
820
+ inline CUDA_CALLABLE half acos(half x){ return ::acosf(min(max(float(x), -1.0f), 1.0f)); }
821
+ inline CUDA_CALLABLE half asin(half x){ return ::asinf(min(max(float(x), -1.0f), 1.0f)); }
822
+ inline CUDA_CALLABLE half atan(half x) { return ::atanf(float(x)); }
823
+ inline CUDA_CALLABLE half atan2(half y, half x) { return ::atan2f(float(y), float(x)); }
824
+ inline CUDA_CALLABLE half sin(half x) { return ::sinf(float(x)); }
825
+ inline CUDA_CALLABLE half cos(half x) { return ::cosf(float(x)); }
826
+
827
+
828
+ inline CUDA_CALLABLE float sqrt(float x)
829
+ {
830
+ #if FP_CHECK
831
+ if (x < 0.0f)
832
+ {
833
+ printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, x);
834
+ assert(0);
835
+ }
836
+ #endif
837
+ return ::sqrtf(x);
838
+ }
839
+ inline CUDA_CALLABLE double sqrt(double x)
840
+ {
841
+ #if FP_CHECK
842
+ if (x < 0.0)
843
+ {
844
+ printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, x);
845
+ assert(0);
846
+ }
847
+ #endif
848
+ return ::sqrt(x);
849
+ }
850
+ inline CUDA_CALLABLE half sqrt(half x)
851
+ {
852
+ #if FP_CHECK
853
+ if (float(x) < 0.0f)
854
+ {
855
+ printf("%s:%d sqrt(%f)\n", __FILE__, __LINE__, float(x));
856
+ assert(0);
857
+ }
858
+ #endif
859
+ return ::sqrtf(float(x));
860
+ }
861
+
862
+ inline CUDA_CALLABLE float cbrt(float x) { return ::cbrtf(x); }
863
+ inline CUDA_CALLABLE double cbrt(double x) { return ::cbrt(x); }
864
+ inline CUDA_CALLABLE half cbrt(half x) { return ::cbrtf(float(x)); }
865
+
866
+ inline CUDA_CALLABLE float tan(float x) { return ::tanf(x); }
867
+ inline CUDA_CALLABLE float sinh(float x) { return ::sinhf(x);}
868
+ inline CUDA_CALLABLE float cosh(float x) { return ::coshf(x);}
869
+ inline CUDA_CALLABLE float tanh(float x) { return ::tanhf(x);}
870
+ inline CUDA_CALLABLE float degrees(float x) { return x * RAD_TO_DEG;}
871
+ inline CUDA_CALLABLE float radians(float x) { return x * DEG_TO_RAD;}
872
+
873
+ inline CUDA_CALLABLE double tan(double x) { return ::tan(x); }
874
+ inline CUDA_CALLABLE double sinh(double x) { return ::sinh(x);}
875
+ inline CUDA_CALLABLE double cosh(double x) { return ::cosh(x);}
876
+ inline CUDA_CALLABLE double tanh(double x) { return ::tanh(x);}
877
+ inline CUDA_CALLABLE double degrees(double x) { return x * RAD_TO_DEG;}
878
+ inline CUDA_CALLABLE double radians(double x) { return x * DEG_TO_RAD;}
879
+
880
+ inline CUDA_CALLABLE half tan(half x) { return ::tanf(float(x)); }
881
+ inline CUDA_CALLABLE half sinh(half x) { return ::sinhf(float(x));}
882
+ inline CUDA_CALLABLE half cosh(half x) { return ::coshf(float(x));}
883
+ inline CUDA_CALLABLE half tanh(half x) { return ::tanhf(float(x));}
884
+ inline CUDA_CALLABLE half degrees(half x) { return x * RAD_TO_DEG;}
885
+ inline CUDA_CALLABLE half radians(half x) { return x * DEG_TO_RAD;}
886
+
887
+ inline CUDA_CALLABLE float round(float x) { return ::roundf(x); }
888
+ inline CUDA_CALLABLE float rint(float x) { return ::rintf(x); }
889
+ inline CUDA_CALLABLE float trunc(float x) { return ::truncf(x); }
890
+ inline CUDA_CALLABLE float floor(float x) { return ::floorf(x); }
891
+ inline CUDA_CALLABLE float ceil(float x) { return ::ceilf(x); }
892
+ inline CUDA_CALLABLE float frac(float x) { return x - trunc(x); }
893
+
894
+ inline CUDA_CALLABLE double round(double x) { return ::round(x); }
895
+ inline CUDA_CALLABLE double rint(double x) { return ::rint(x); }
896
+ inline CUDA_CALLABLE double trunc(double x) { return ::trunc(x); }
897
+ inline CUDA_CALLABLE double floor(double x) { return ::floor(x); }
898
+ inline CUDA_CALLABLE double ceil(double x) { return ::ceil(x); }
899
+ inline CUDA_CALLABLE double frac(double x) { return x - trunc(x); }
900
+
901
+ inline CUDA_CALLABLE half round(half x) { return ::roundf(float(x)); }
902
+ inline CUDA_CALLABLE half rint(half x) { return ::rintf(float(x)); }
903
+ inline CUDA_CALLABLE half trunc(half x) { return ::truncf(float(x)); }
904
+ inline CUDA_CALLABLE half floor(half x) { return ::floorf(float(x)); }
905
+ inline CUDA_CALLABLE half ceil(half x) { return ::ceilf(float(x)); }
906
+ inline CUDA_CALLABLE half frac(half x) { return float(x) - trunc(float(x)); }
907
+
908
+ #define DECLARE_ADJOINTS(T)\
909
+ inline CUDA_CALLABLE void adj_log(T a, T& adj_a, T adj_ret)\
910
+ {\
911
+ adj_a += (T(1)/a)*adj_ret;\
912
+ DO_IF_FPCHECK(if (!isfinite(adj_a))\
913
+ {\
914
+ printf("%s:%d - adj_log(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
915
+ assert(0);\
916
+ })\
917
+ }\
918
+ inline CUDA_CALLABLE void adj_log2(T a, T& adj_a, T adj_ret)\
919
+ { \
920
+ adj_a += (T(1)/a)*(T(1)/log(T(2)))*adj_ret; \
921
+ DO_IF_FPCHECK(if (!isfinite(adj_a))\
922
+ {\
923
+ printf("%s:%d - adj_log2(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
924
+ assert(0);\
925
+ }) \
926
+ }\
927
+ inline CUDA_CALLABLE void adj_log10(T a, T& adj_a, T adj_ret)\
928
+ {\
929
+ adj_a += (T(1)/a)*(T(1)/log(T(10)))*adj_ret; \
930
+ DO_IF_FPCHECK(if (!isfinite(adj_a))\
931
+ {\
932
+ printf("%s:%d - adj_log10(%f, %f, %f)\n", __FILE__, __LINE__, float(a), float(adj_a), float(adj_ret));\
933
+ assert(0);\
934
+ })\
935
+ }\
936
+ inline CUDA_CALLABLE void adj_exp(T a, T ret, T& adj_a, T adj_ret) { adj_a += ret*adj_ret; }\
937
+ inline CUDA_CALLABLE void adj_pow(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
938
+ { \
939
+ adj_a += b*pow(a, b-T(1))*adj_ret;\
940
+ adj_b += log(a)*ret*adj_ret;\
941
+ DO_IF_FPCHECK(if (!isfinite(adj_a) || !isfinite(adj_b))\
942
+ {\
943
+ printf("%s:%d - adj_pow(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
944
+ assert(0);\
945
+ })\
946
+ }\
947
+ inline CUDA_CALLABLE void adj_leaky_min(T a, T b, T r, T& adj_a, T& adj_b, T& adj_r, T adj_ret)\
948
+ {\
949
+ if (a < b)\
950
+ adj_a += adj_ret;\
951
+ else\
952
+ {\
953
+ adj_a += r*adj_ret;\
954
+ adj_b += adj_ret;\
955
+ }\
956
+ }\
957
+ inline CUDA_CALLABLE void adj_leaky_max(T a, T b, T r, T& adj_a, T& adj_b, T& adj_r, T adj_ret)\
958
+ {\
959
+ if (a > b)\
960
+ adj_a += adj_ret;\
961
+ else\
962
+ {\
963
+ adj_a += r*adj_ret;\
964
+ adj_b += adj_ret;\
965
+ }\
966
+ }\
967
+ inline CUDA_CALLABLE void adj_acos(T x, T& adj_x, T adj_ret)\
968
+ {\
969
+ T d = sqrt(T(1)-x*x);\
970
+ DO_IF_FPCHECK(adj_x -= (T(1)/d)*adj_ret;\
971
+ if (!isfinite(d) || !isfinite(adj_x))\
972
+ {\
973
+ printf("%s:%d - adj_acos(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret)); \
974
+ assert(0);\
975
+ })\
976
+ DO_IF_NO_FPCHECK(if (d > T(0))\
977
+ adj_x -= (T(1)/d)*adj_ret;)\
978
+ }\
979
+ inline CUDA_CALLABLE void adj_asin(T x, T& adj_x, T adj_ret)\
980
+ {\
981
+ T d = sqrt(T(1)-x*x);\
982
+ DO_IF_FPCHECK(adj_x += (T(1)/d)*adj_ret;\
983
+ if (!isfinite(d) || !isfinite(adj_x))\
984
+ {\
985
+ printf("%s:%d - adj_asin(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret)); \
986
+ assert(0);\
987
+ })\
988
+ DO_IF_NO_FPCHECK(if (d > T(0))\
989
+ adj_x += (T(1)/d)*adj_ret;)\
990
+ }\
991
+ inline CUDA_CALLABLE void adj_tan(T x, T& adj_x, T adj_ret)\
992
+ {\
993
+ T cos_x = cos(x);\
994
+ DO_IF_FPCHECK(adj_x += (T(1)/(cos_x*cos_x))*adj_ret;\
995
+ if (!isfinite(adj_x) || cos_x == T(0))\
996
+ {\
997
+ printf("%s:%d - adj_tan(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
998
+ assert(0);\
999
+ })\
1000
+ DO_IF_NO_FPCHECK(if (cos_x != T(0))\
1001
+ adj_x += (T(1)/(cos_x*cos_x))*adj_ret;)\
1002
+ }\
1003
+ inline CUDA_CALLABLE void adj_atan(T x, T& adj_x, T adj_ret)\
1004
+ {\
1005
+ adj_x += adj_ret /(x*x + T(1));\
1006
+ }\
1007
+ inline CUDA_CALLABLE void adj_atan2(T y, T x, T& adj_y, T& adj_x, T adj_ret)\
1008
+ {\
1009
+ T d = x*x + y*y;\
1010
+ DO_IF_FPCHECK(adj_x -= y/d*adj_ret;\
1011
+ adj_y += x/d*adj_ret;\
1012
+ if (!isfinite(adj_x) || !isfinite(adj_y) || d == T(0))\
1013
+ {\
1014
+ printf("%s:%d - adj_atan2(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(y), float(x), float(adj_y), float(adj_x), float(adj_ret));\
1015
+ assert(0);\
1016
+ })\
1017
+ DO_IF_NO_FPCHECK(if (d > T(0))\
1018
+ {\
1019
+ adj_x -= (y/d)*adj_ret;\
1020
+ adj_y += (x/d)*adj_ret;\
1021
+ })\
1022
+ }\
1023
+ inline CUDA_CALLABLE void adj_sin(T x, T& adj_x, T adj_ret)\
1024
+ {\
1025
+ adj_x += cos(x)*adj_ret;\
1026
+ }\
1027
+ inline CUDA_CALLABLE void adj_cos(T x, T& adj_x, T adj_ret)\
1028
+ {\
1029
+ adj_x -= sin(x)*adj_ret;\
1030
+ }\
1031
+ inline CUDA_CALLABLE void adj_sinh(T x, T& adj_x, T adj_ret)\
1032
+ {\
1033
+ adj_x += cosh(x)*adj_ret;\
1034
+ }\
1035
+ inline CUDA_CALLABLE void adj_cosh(T x, T& adj_x, T adj_ret)\
1036
+ {\
1037
+ adj_x += sinh(x)*adj_ret;\
1038
+ }\
1039
+ inline CUDA_CALLABLE void adj_tanh(T x, T ret, T& adj_x, T adj_ret)\
1040
+ {\
1041
+ adj_x += (T(1) - ret*ret)*adj_ret;\
1042
+ }\
1043
+ inline CUDA_CALLABLE void adj_sqrt(T x, T ret, T& adj_x, T adj_ret)\
1044
+ {\
1045
+ adj_x += T(0.5)*(T(1)/ret)*adj_ret;\
1046
+ DO_IF_FPCHECK(if (!isfinite(adj_x))\
1047
+ {\
1048
+ printf("%s:%d - adj_sqrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
1049
+ assert(0);\
1050
+ })\
1051
+ }\
1052
+ inline CUDA_CALLABLE void adj_cbrt(T x, T ret, T& adj_x, T adj_ret)\
1053
+ {\
1054
+ adj_x += (T(1)/T(3))*(T(1)/(ret*ret))*adj_ret;\
1055
+ DO_IF_FPCHECK(if (!isfinite(adj_x))\
1056
+ {\
1057
+ printf("%s:%d - adj_cbrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
1058
+ assert(0);\
1059
+ })\
1060
+ }\
1061
+ inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
1062
+ {\
1063
+ adj_x += RAD_TO_DEG * adj_ret;\
1064
+ }\
1065
+ inline CUDA_CALLABLE void adj_radians(T x, T& adj_x, T adj_ret)\
1066
+ {\
1067
+ adj_x += DEG_TO_RAD * adj_ret;\
1068
+ }\
1069
+ inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
1070
+ inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
1071
+ inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
1072
+ inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
1073
+ inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
1074
+ inline CUDA_CALLABLE void adj_frac(T x, T& adj_x, T adj_ret){ }
1075
+
1076
+ DECLARE_ADJOINTS(float16)
1077
+ DECLARE_ADJOINTS(float32)
1078
+ DECLARE_ADJOINTS(float64)
1079
+
1080
+ template <typename C, typename T>
1081
+ CUDA_CALLABLE inline T select(const C& cond, const T& a, const T& b)
1082
+ {
1083
+ // The double NOT operator !! casts to bool without compiler warnings.
1084
+ return (!!cond) ? b : a;
1085
+ }
1086
+
1087
+ template <typename C, typename T>
1088
+ CUDA_CALLABLE inline void adj_select(const C& cond, const T& a, const T& b, C& adj_cond, T& adj_a, T& adj_b, const T& adj_ret)
1089
+ {
1090
+ // The double NOT operator !! casts to bool without compiler warnings.
1091
+ if (!!cond)
1092
+ adj_b += adj_ret;
1093
+ else
1094
+ adj_a += adj_ret;
1095
+ }
1096
+
1097
+ template <typename C, typename T>
1098
+ CUDA_CALLABLE inline T where(const C& cond, const T& a, const T& b)
1099
+ {
1100
+ // The double NOT operator !! casts to bool without compiler warnings.
1101
+ return (!!cond) ? a : b;
1102
+ }
1103
+
1104
+ template <typename C, typename T>
1105
+ CUDA_CALLABLE inline void adj_where(const C& cond, const T& a, const T& b, C& adj_cond, T& adj_a, T& adj_b, const T& adj_ret)
1106
+ {
1107
+ // The double NOT operator !! casts to bool without compiler warnings.
1108
+ if (!!cond)
1109
+ adj_a += adj_ret;
1110
+ else
1111
+ adj_b += adj_ret;
1112
+ }
1113
+
1114
+ template <typename T>
1115
+ CUDA_CALLABLE inline T copy(const T& src)
1116
+ {
1117
+ return src;
1118
+ }
1119
+
1120
+ template <typename T>
1121
+ CUDA_CALLABLE inline void adj_copy(const T& src, T& adj_src, T& adj_dest)
1122
+ {
1123
+ adj_src += adj_dest;
1124
+ adj_dest = T{};
1125
+ }
1126
+
1127
+ template <typename T>
1128
+ CUDA_CALLABLE inline void assign(T& dest, const T& src)
1129
+ {
1130
+ dest = src;
1131
+ }
1132
+
1133
+ template <typename T>
1134
+ CUDA_CALLABLE inline void adj_assign(T& dest, const T& src, T& adj_dest, T& adj_src)
1135
+ {
1136
+ // this is generally a non-differentiable operation since it violates SSA,
1137
+ // except in read-modify-write statements which are reversible through backpropagation
1138
+ adj_src = adj_dest;
1139
+ adj_dest = T{};
1140
+ }
1141
+
1142
+
1143
+ // some helpful operator overloads (just for C++ use, these are not adjointed)
1144
+
1145
+ template <typename T>
1146
+ CUDA_CALLABLE inline T& operator += (T& a, const T& b) { a = add(a, b); return a; }
1147
+
1148
+ template <typename T>
1149
+ CUDA_CALLABLE inline T& operator -= (T& a, const T& b) { a = sub(a, b); return a; }
1150
+
1151
+ template <typename T>
1152
+ CUDA_CALLABLE inline T operator+(const T& a, const T& b) { return add(a, b); }
1153
+
1154
+ template <typename T>
1155
+ CUDA_CALLABLE inline T operator-(const T& a, const T& b) { return sub(a, b); }
1156
+
1157
+ template <typename T>
1158
+ CUDA_CALLABLE inline T pos(const T& x) { return x; }
1159
+ template <typename T>
1160
+ CUDA_CALLABLE inline void adj_pos(const T& x, T& adj_x, const T& adj_ret) { adj_x += T(adj_ret); }
1161
+
1162
+ // unary negation implemented as negative multiply, not sure the fp implications of this
1163
+ // may be better as 0.0 - x?
1164
+ template <typename T>
1165
+ CUDA_CALLABLE inline T neg(const T& x) { return T(0.0) - x; }
1166
+ template <typename T>
1167
+ CUDA_CALLABLE inline void adj_neg(const T& x, T& adj_x, const T& adj_ret) { adj_x += T(-adj_ret); }
1168
+
1169
+ // unary boolean negation
1170
+ template <typename T>
1171
+ CUDA_CALLABLE inline bool unot(const T& b) { return !b; }
1172
+ template <typename T>
1173
+ CUDA_CALLABLE inline void adj_unot(const T& b, T& adj_b, const bool& adj_ret) { }
1174
+
1175
+ const int LAUNCH_MAX_DIMS = 4; // should match types.py
1176
+
1177
+ struct launch_bounds_t
1178
+ {
1179
+ int shape[LAUNCH_MAX_DIMS]; // size of each dimension
1180
+ int ndim; // number of valid dimension
1181
+ size_t size; // total number of threads
1182
+ };
1183
+
1184
+ // represents coordinate in the launch grid
1185
+ struct launch_coord_t
1186
+ {
1187
+ int i;
1188
+ int j;
1189
+ int k;
1190
+ int l;
1191
+ };
1192
+
1193
+ // unravels a linear thread index to the corresponding launch grid coord (up to 4d)
1194
+ inline CUDA_CALLABLE launch_coord_t launch_coord(size_t linear, const launch_bounds_t& bounds)
1195
+ {
1196
+ launch_coord_t coord = {0, 0, 0, 0};
1197
+
1198
+ if (bounds.ndim > 3)
1199
+ {
1200
+ coord.l = linear%bounds.shape[3];
1201
+ linear /= bounds.shape[3];
1202
+ }
1203
+
1204
+ if (bounds.ndim > 2)
1205
+ {
1206
+ coord.k = linear%bounds.shape[2];
1207
+ linear /= bounds.shape[2];
1208
+ }
1209
+
1210
+ if (bounds.ndim > 1)
1211
+ {
1212
+ coord.j = linear%bounds.shape[1];
1213
+ linear /= bounds.shape[1];
1214
+ }
1215
+
1216
+ if (bounds.ndim > 0)
1217
+ {
1218
+ coord.i = linear;
1219
+ }
1220
+
1221
+ return coord;
1222
+ }
1223
+
1224
+ inline CUDA_CALLABLE int tid(size_t index, const launch_bounds_t& bounds)
1225
+ {
1226
+ // For the 1-D tid() we need to warn the user if we're about to provide a truncated index
1227
+ // Only do this in _DEBUG when called from device to avoid excessive register allocation
1228
+ #if defined(_DEBUG) || !defined(__CUDA_ARCH__)
1229
+ if (index > 2147483647) {
1230
+ printf("Warp warning: tid() is returning an overflowed int\n");
1231
+ }
1232
+ #endif
1233
+
1234
+ launch_coord_t c = launch_coord(index, bounds);
1235
+ return static_cast<int>(c.i);
1236
+ }
1237
+
1238
+ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& bounds)
1239
+ {
1240
+ launch_coord_t c = launch_coord(index, bounds);
1241
+ i = c.i;
1242
+ j = c.j;
1243
+ }
1244
+
1245
+ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& bounds)
1246
+ {
1247
+ launch_coord_t c = launch_coord(index, bounds);
1248
+ i = c.i;
1249
+ j = c.j;
1250
+ k = c.k;
1251
+ }
1252
+
1253
+ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& bounds)
1254
+ {
1255
+ launch_coord_t c = launch_coord(index, bounds);
1256
+ i = c.i;
1257
+ j = c.j;
1258
+ k = c.k;
1259
+ l = c.l;
1260
+ }
1261
+
1262
+ template<typename T>
1263
+ inline CUDA_CALLABLE T atomic_add(T* buf, T value)
1264
+ {
1265
+ #if !defined(__CUDA_ARCH__)
1266
+ T old = buf[0];
1267
+ buf[0] += value;
1268
+ return old;
1269
+ #else
1270
+ return atomicAdd(buf, value);
1271
+ #endif
1272
+ }
1273
+
1274
// float16 specialization of atomic_add(): returns the previous value.
template<>
inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
{
#if !defined(__CUDA_ARCH__)
    // Host path: plain (non-atomic) read-modify-write.
    float16 old = buf[0];
    buf[0] += value;
    return old;
#elif defined(__clang__) // CUDA compiled by Clang
    // Clang accepts the __half atomicAdd intrinsic directly; reinterpret the
    // raw 16-bit storage between float16 and __half.
    __half r = atomicAdd(reinterpret_cast<__half*>(buf), *reinterpret_cast<__half*>(&value));
    return *reinterpret_cast<float16*>(&r);
#else // CUDA compiled by NVRTC
    //return atomicAdd(buf, value);

    /* Define __PTR for atomicAdd prototypes below, undef after done */
#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
#define __PTR   "l"
#else
#define __PTR   "r"
#endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/

    half r = 0.0;

#if __CUDA_ARCH__ >= 700

    // Inline PTX: non-flushing f16 atomic add; r.u receives the previous
    // raw half bits from memory.
    asm volatile ("{ atom.add.noftz.f16 %0,[%1],%2; }\n"
                  : "=h"(r.u)
                  : __PTR(buf), "h"(value.u)
                  : "memory");
#endif

    // NOTE(review): when __CUDA_ARCH__ < 700 no add is performed and the
    // zero-initialized r is returned — looks like a silent no-op on
    // pre-Volta targets; confirm this is intentional.
    return r;

#undef __PTR

#endif // CUDA compiled by NVRTC

}
1311
+
1312
// emulate atomic float max with atomicCAS()
// Returns the last value observed at *address before our (possible) write.
inline CUDA_CALLABLE float atomic_max(float* address, float val)
{
#if defined(__CUDA_ARCH__)
    // View the float storage as int so 32-bit atomicCAS can operate on it.
    int *address_as_int = (int*)address;
    int old = *address_as_int, assumed;

    // Keep attempting the swap while val still beats the stored value; a
    // failed CAS refreshes `old` with the current contents, so the loop
    // exits as soon as the slot holds something >= val.
    while (val > __int_as_float(old))
    {
        assumed = old;
        old = atomicCAS(address_as_int, assumed,
                        __float_as_int(val));
    }

    return __int_as_float(old);

#else
    // Host path: plain (non-atomic) read-modify-write.
    float old = *address;
    *address = max(old, val);
    return old;
#endif
}
1334
+
1335
// emulate atomic float min with atomicCAS()
// Returns the last value observed at *address before our (possible) write.
inline CUDA_CALLABLE float atomic_min(float* address, float val)
{
#if defined(__CUDA_ARCH__)
    // View the float storage as int so 32-bit atomicCAS can operate on it.
    int *address_as_int = (int*)address;
    int old = *address_as_int, assumed;

    // Retry while val is still smaller than the stored value; a failed CAS
    // refreshes `old` with the current contents.
    while (val < __int_as_float(old))
    {
        assumed = old;
        old = atomicCAS(address_as_int, assumed,
                        __float_as_int(val));
    }

    return __int_as_float(old);

#else
    // Host path: plain (non-atomic) read-modify-write.
    float old = *address;
    *address = min(old, val);
    return old;
#endif
}
1357
+
1358
// float64 specialization of atomic_add(): returns the previous value.
template<>
inline CUDA_CALLABLE float64 atomic_add(float64* buf, float64 value)
{
#if !defined(__CUDA_ARCH__)
    // Host path: plain (non-atomic) read-modify-write.
    float64 old = buf[0];
    buf[0] += value;
    return old;
#elif defined(__clang__) // CUDA compiled by Clang
    // Clang accepts the double atomicAdd intrinsic directly.
    return atomicAdd(buf, value);
#else // CUDA compiled by NVRTC

    /* Define __PTR for atomicAdd prototypes below, undef after done */
#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
#define __PTR   "l"
#else
#define __PTR   "r"
#endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/

    double r = 0.0;

#if __CUDA_ARCH__ >= 600

    // Inline PTX f64 atomic add; r receives the previous value from memory.
    asm volatile ("{ atom.add.f64 %0,[%1],%2; }\n"
                  : "=d"(r)
                  : __PTR(buf), "d"(value)
                  : "memory");
#endif

    // NOTE(review): when __CUDA_ARCH__ < 600 no add is performed and 0.0 is
    // returned — confirm pre-Pascal targets are out of scope.
    return r;

#undef __PTR

#endif // CUDA compiled by NVRTC

}
1393
+
1394
// emulate atomic double max with atomicCAS()
// Returns the last value observed at *address before our (possible) write.
inline CUDA_CALLABLE double atomic_max(double* address, double val)
{
#if defined(__CUDA_ARCH__)
    // View the double storage as unsigned long long so 64-bit atomicCAS can
    // operate on it.
    unsigned long long int *address_as_ull = (unsigned long long int*)address;
    unsigned long long int old = *address_as_ull, assumed;

    // Retry while val still beats the stored value; a failed CAS refreshes
    // `old` with the current contents.
    while (val > __longlong_as_double(old))
    {
        assumed = old;
        old = atomicCAS(address_as_ull, assumed,
                        __double_as_longlong(val));
    }

    return __longlong_as_double(old);

#else
    // Host path: plain (non-atomic) read-modify-write.
    double old = *address;
    *address = max(old, val);
    return old;
#endif
}
1416
+
1417
// emulate atomic double min with atomicCAS()
// Returns the last value observed at *address before our (possible) write.
inline CUDA_CALLABLE double atomic_min(double* address, double val)
{
#if defined(__CUDA_ARCH__)
    // View the double storage as unsigned long long so 64-bit atomicCAS can
    // operate on it.
    unsigned long long int *address_as_ull = (unsigned long long int*)address;
    unsigned long long int old = *address_as_ull, assumed;

    // Retry while val is still smaller than the stored value; a failed CAS
    // refreshes `old` with the current contents.
    while (val < __longlong_as_double(old))
    {
        assumed = old;
        old = atomicCAS(address_as_ull, assumed,
                        __double_as_longlong(val));
    }

    return __longlong_as_double(old);

#else
    // Host path: plain (non-atomic) read-modify-write.
    double old = *address;
    *address = min(old, val);
    return old;
#endif
}
1439
+
1440
+ inline CUDA_CALLABLE int atomic_max(int* address, int val)
1441
+ {
1442
+ #if defined(__CUDA_ARCH__)
1443
+ return atomicMax(address, val);
1444
+
1445
+ #else
1446
+ int old = *address;
1447
+ *address = max(old, val);
1448
+ return old;
1449
+ #endif
1450
+ }
1451
+
1452
+ // atomic int min
1453
+ inline CUDA_CALLABLE int atomic_min(int* address, int val)
1454
+ {
1455
+ #if defined(__CUDA_ARCH__)
1456
+ return atomicMin(address, val);
1457
+
1458
+ #else
1459
+ int old = *address;
1460
+ *address = min(old, val);
1461
+ return old;
1462
+ #endif
1463
+ }
1464
+
1465
+ // default behavior for adjoint of atomic min/max operation that accumulates gradients for all elements matching the min/max value
1466
+ template <typename T>
1467
+ CUDA_CALLABLE inline void adj_atomic_minmax(T *addr, T *adj_addr, const T &value, T &adj_value)
1468
+ {
1469
+ if (value == *addr)
1470
+ adj_value += *adj_addr;
1471
+ }
1472
+
1473
+ // for integral types we do not accumulate gradients
1474
+ CUDA_CALLABLE inline void adj_atomic_minmax(int8* buf, int8* adj_buf, const int8 &value, int8 &adj_value) { }
1475
+ CUDA_CALLABLE inline void adj_atomic_minmax(uint8* buf, uint8* adj_buf, const uint8 &value, uint8 &adj_value) { }
1476
+ CUDA_CALLABLE inline void adj_atomic_minmax(int16* buf, int16* adj_buf, const int16 &value, int16 &adj_value) { }
1477
+ CUDA_CALLABLE inline void adj_atomic_minmax(uint16* buf, uint16* adj_buf, const uint16 &value, uint16 &adj_value) { }
1478
+ CUDA_CALLABLE inline void adj_atomic_minmax(int32* buf, int32* adj_buf, const int32 &value, int32 &adj_value) { }
1479
+ CUDA_CALLABLE inline void adj_atomic_minmax(uint32* buf, uint32* adj_buf, const uint32 &value, uint32 &adj_value) { }
1480
+ CUDA_CALLABLE inline void adj_atomic_minmax(int64* buf, int64* adj_buf, const int64 &value, int64 &adj_value) { }
1481
+ CUDA_CALLABLE inline void adj_atomic_minmax(uint64* buf, uint64* adj_buf, const uint64 &value, uint64 &adj_value) { }
1482
+ CUDA_CALLABLE inline void adj_atomic_minmax(bool* buf, bool* adj_buf, const bool &value, bool &adj_value) { }
1483
+
1484
+
1485
+ } // namespace wp
1486
+
1487
+
1488
+ // bool and printf are defined outside of the wp namespace in crt.h, hence
1489
+ // their adjoint counterparts are also defined in the global namespace.
1490
+ template <typename T>
1491
+ CUDA_CALLABLE inline void adj_bool(T, T&, bool) {}
1492
+ inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
1493
+
1494
+
1495
+ #include "vec.h"
1496
+ #include "mat.h"
1497
+ #include "quat.h"
1498
+ #include "spatial.h"
1499
+ #include "intersect.h"
1500
+ #include "intersect_adj.h"
1501
+
1502
+ //--------------
1503
+ namespace wp
1504
+ {
1505
+
1506
+
1507
+ // dot for scalar types just to make some templates compile for scalar/vector
1508
+ inline CUDA_CALLABLE float dot(float a, float b) { return mul(a, b); }
1509
+ inline CUDA_CALLABLE void adj_dot(float a, float b, float& adj_a, float& adj_b, float adj_ret) { adj_mul(a, b, adj_a, adj_b, adj_ret); }
1510
+ inline CUDA_CALLABLE float tensordot(float a, float b) { return mul(a, b); }
1511
+
1512
+
1513
// Declares smoothstep()/lerp() and their adjoints for a scalar type T.
//  - smoothstep: cubic Hermite interpolation of x between edge0 and edge1,
//    clamped so the result stays in [0, 1].
//  - adj_smoothstep: returns early (zero gradient) when x lies in a clamped
//    region, i.e. (x-edge0)/(edge1-edge0) <= 0 or >= 1; otherwise accumulates
//    the analytic partials w.r.t. edge0, edge1 and x.
//  - lerp: a*(1-t) + b*t with the standard linear adjoints.
// (Comments kept outside the macro body: '//' inside a line-continuation
// macro would swallow the trailing backslash.)
#define DECLARE_INTERP_FUNCS(T) \
CUDA_CALLABLE inline T smoothstep(T edge0, T edge1, T x)\
{\
    x = clamp((x - edge0) / (edge1 - edge0), T(0), T(1));\
    return x * x * (T(3) - T(2) * x);\
}\
CUDA_CALLABLE inline void adj_smoothstep(T edge0, T edge1, T x, T& adj_edge0, T& adj_edge1, T& adj_x, T adj_ret)\
{\
    T ab = edge0 - edge1;\
    T ax = edge0 - x;\
    T bx = edge1 - x;\
    T xb = x - edge1;\
    \
    if (bx / ab >= T(0) || ax / ab <= T(0))\
    {\
        return;\
    }\
    \
    T ab3 = ab * ab * ab;\
    T ab4 = ab3 * ab;\
    adj_edge0 += adj_ret * ((T(6) * ax * bx * bx) / ab4);\
    adj_edge1 += adj_ret * ((T(6) * ax * ax * xb) / ab4);\
    adj_x += adj_ret * ((T(6) * ax * bx ) / ab3);\
}\
CUDA_CALLABLE inline T lerp(const T& a, const T& b, T t)\
{\
    return a*(T(1)-t) + b*t;\
}\
CUDA_CALLABLE inline void adj_lerp(const T& a, const T& b, T t, T& adj_a, T& adj_b, T& adj_t, const T& adj_ret)\
{\
    adj_a += adj_ret*(T(1)-t);\
    adj_b += adj_ret*t;\
    adj_t += b*adj_ret - a*adj_ret;\
}

// Instantiate the interpolation functions for all floating-point widths.
DECLARE_INTERP_FUNCS(float16)
DECLARE_INTERP_FUNCS(float32)
DECLARE_INTERP_FUNCS(float64)
1551
+
1552
+ inline CUDA_CALLABLE void print(const str s)
1553
+ {
1554
+ printf("%s\n", s);
1555
+ }
1556
+
1557
+ inline CUDA_CALLABLE void print(signed char i)
1558
+ {
1559
+ printf("%d\n", i);
1560
+ }
1561
+
1562
+ inline CUDA_CALLABLE void print(short i)
1563
+ {
1564
+ printf("%d\n", i);
1565
+ }
1566
+
1567
+ inline CUDA_CALLABLE void print(int i)
1568
+ {
1569
+ printf("%d\n", i);
1570
+ }
1571
+
1572
+ inline CUDA_CALLABLE void print(long i)
1573
+ {
1574
+ printf("%ld\n", i);
1575
+ }
1576
+
1577
+ inline CUDA_CALLABLE void print(long long i)
1578
+ {
1579
+ printf("%lld\n", i);
1580
+ }
1581
+
1582
+ inline CUDA_CALLABLE void print(unsigned char i)
1583
+ {
1584
+ printf("%u\n", i);
1585
+ }
1586
+
1587
+ inline CUDA_CALLABLE void print(unsigned short i)
1588
+ {
1589
+ printf("%u\n", i);
1590
+ }
1591
+
1592
+ inline CUDA_CALLABLE void print(unsigned int i)
1593
+ {
1594
+ printf("%u\n", i);
1595
+ }
1596
+
1597
+ inline CUDA_CALLABLE void print(unsigned long i)
1598
+ {
1599
+ printf("%lu\n", i);
1600
+ }
1601
+
1602
+ inline CUDA_CALLABLE void print(unsigned long long i)
1603
+ {
1604
+ printf("%llu\n", i);
1605
+ }
1606
+
1607
+ inline CUDA_CALLABLE void print(bool b)
1608
+ {
1609
+ printf(b ? "True\n" : "False\n");
1610
+ }
1611
+
1612
+ template<unsigned Length, typename Type>
1613
+ inline CUDA_CALLABLE void print(vec_t<Length, Type> v)
1614
+ {
1615
+ for( unsigned i=0; i < Length; ++i )
1616
+ {
1617
+ printf("%g ", float(v[i]));
1618
+ }
1619
+ printf("\n");
1620
+ }
1621
+
1622
+ template<typename Type>
1623
+ inline CUDA_CALLABLE void print(quat_t<Type> i)
1624
+ {
1625
+ printf("%g %g %g %g\n", float(i.x), float(i.y), float(i.z), float(i.w));
1626
+ }
1627
+
1628
+ template<unsigned Rows,unsigned Cols,typename Type>
1629
+ inline CUDA_CALLABLE void print(const mat_t<Rows,Cols,Type> &m)
1630
+ {
1631
+ for( unsigned i=0; i< Rows; ++i )
1632
+ {
1633
+ for( unsigned j=0; j< Cols; ++j )
1634
+ {
1635
+ printf("%g ",float(m.data[i][j]));
1636
+ }
1637
+ printf("\n");
1638
+ }
1639
+ }
1640
+
1641
+ template<typename Type>
1642
+ inline CUDA_CALLABLE void print(transform_t<Type> t)
1643
+ {
1644
+ printf("(%g %g %g) (%g %g %g %g)\n", float(t.p[0]), float(t.p[1]), float(t.p[2]), float(t.q.x), float(t.q.y), float(t.q.z), float(t.q.w));
1645
+ }
1646
+
1647
+ template<typename T>
1648
+ inline CUDA_CALLABLE void adj_print(const T& x, const T& adj_x)
1649
+ {
1650
+ printf("adj: <type without print implementation>\n");
1651
+ }
1652
+
1653
+ // note: adj_print() only prints the adjoint value, since the value itself gets printed in replay print()
1654
+ inline CUDA_CALLABLE void adj_print(half x, half adj_x) { printf("adj: %g\n", half_to_float(adj_x)); }
1655
+ inline CUDA_CALLABLE void adj_print(float x, float adj_x) { printf("adj: %g\n", adj_x); }
1656
+ inline CUDA_CALLABLE void adj_print(double x, double adj_x) { printf("adj: %g\n", adj_x); }
1657
+
1658
+ inline CUDA_CALLABLE void adj_print(signed char x, signed char adj_x) { printf("adj: %d\n", adj_x); }
1659
+ inline CUDA_CALLABLE void adj_print(short x, short adj_x) { printf("adj: %d\n", adj_x); }
1660
+ inline CUDA_CALLABLE void adj_print(int x, int adj_x) { printf("adj: %d\n", adj_x); }
1661
+ inline CUDA_CALLABLE void adj_print(long x, long adj_x) { printf("adj: %ld\n", adj_x); }
1662
+ inline CUDA_CALLABLE void adj_print(long long x, long long adj_x) { printf("adj: %lld\n", adj_x); }
1663
+
1664
+ inline CUDA_CALLABLE void adj_print(unsigned char x, unsigned char adj_x) { printf("adj: %u\n", adj_x); }
1665
+ inline CUDA_CALLABLE void adj_print(unsigned short x, unsigned short adj_x) { printf("adj: %u\n", adj_x); }
1666
+ inline CUDA_CALLABLE void adj_print(unsigned x, unsigned adj_x) { printf("adj: %u\n", adj_x); }
1667
+ inline CUDA_CALLABLE void adj_print(unsigned long x, unsigned long adj_x) { printf("adj: %lu\n", adj_x); }
1668
+ inline CUDA_CALLABLE void adj_print(unsigned long long x, unsigned long long adj_x) { printf("adj: %llu\n", adj_x); }
1669
+
1670
+ inline CUDA_CALLABLE void adj_print(bool x, bool adj_x) { printf("adj: %s\n", (adj_x ? "True" : "False")); }
1671
+
1672
+ template<unsigned Length, typename Type>
1673
+ inline CUDA_CALLABLE void adj_print(const vec_t<Length, Type>& v, const vec_t<Length, Type>& adj_v)
1674
+ {
1675
+ printf("adj:");
1676
+ for (unsigned i = 0; i < Length; i++)
1677
+ printf(" %g", float(adj_v[i]));
1678
+ printf("\n");
1679
+ }
1680
+
1681
+ template<unsigned Rows, unsigned Cols, typename Type>
1682
+ inline CUDA_CALLABLE void adj_print(const mat_t<Rows, Cols, Type>& m, const mat_t<Rows, Cols, Type>& adj_m)
1683
+ {
1684
+ for (unsigned i = 0; i < Rows; i++)
1685
+ {
1686
+ if (i == 0)
1687
+ printf("adj:");
1688
+ else
1689
+ printf(" ");
1690
+ for (unsigned j = 0; j < Cols; j++)
1691
+ printf(" %g", float(adj_m.data[i][j]));
1692
+ printf("\n");
1693
+ }
1694
+ }
1695
+
1696
+ template<typename Type>
1697
+ inline CUDA_CALLABLE void adj_print(const quat_t<Type>& q, const quat_t<Type>& adj_q)
1698
+ {
1699
+ printf("adj: %g %g %g %g\n", float(adj_q.x), float(adj_q.y), float(adj_q.z), float(adj_q.w));
1700
+ }
1701
+
1702
+ template<typename Type>
1703
+ inline CUDA_CALLABLE void adj_print(const transform_t<Type>& t, const transform_t<Type>& adj_t)
1704
+ {
1705
+ printf("adj: (%g %g %g) (%g %g %g %g)\n",
1706
+ float(adj_t.p[0]), float(adj_t.p[1]), float(adj_t.p[2]),
1707
+ float(adj_t.q.x), float(adj_t.q.y), float(adj_t.q.z), float(adj_t.q.w));
1708
+ }
1709
+
1710
+ inline CUDA_CALLABLE void adj_print(str t, str& adj_t)
1711
+ {
1712
+ printf("adj: %s\n", t);
1713
+ }
1714
+
1715
+ template <typename T>
1716
+ inline CUDA_CALLABLE void expect_eq(const T& actual, const T& expected)
1717
+ {
1718
+ if (!(actual == expected))
1719
+ {
1720
+ printf("Error, expect_eq() failed:\n");
1721
+ printf("\t Expected: "); print(expected);
1722
+ printf("\t Actual: "); print(actual);
1723
+ }
1724
+ }
1725
+
1726
+ template <typename T>
1727
+ inline CUDA_CALLABLE void adj_expect_eq(const T& a, const T& b, T& adj_a, T& adj_b)
1728
+ {
1729
+ // nop
1730
+ }
1731
+
1732
+ template <typename T>
1733
+ inline CUDA_CALLABLE void expect_neq(const T& actual, const T& expected)
1734
+ {
1735
+ if (actual == expected)
1736
+ {
1737
+ printf("Error, expect_neq() failed:\n");
1738
+ printf("\t Expected: "); print(expected);
1739
+ printf("\t Actual: "); print(actual);
1740
+ }
1741
+ }
1742
+
1743
+ template <typename T>
1744
+ inline CUDA_CALLABLE void adj_expect_neq(const T& a, const T& b, T& adj_a, T& adj_b)
1745
+ {
1746
+ // nop
1747
+ }
1748
+
1749
+ template <typename T>
1750
+ inline CUDA_CALLABLE void expect_near(const T& actual, const T& expected, const T& tolerance)
1751
+ {
1752
+ if (abs(actual - expected) > tolerance)
1753
+ {
1754
+ printf("Error, expect_near() failed with tolerance "); print(tolerance);
1755
+ printf("\t Expected: "); print(expected);
1756
+ printf("\t Actual: "); print(actual);
1757
+ }
1758
+ }
1759
+
1760
+ inline CUDA_CALLABLE void expect_near(const vec3& actual, const vec3& expected, const float& tolerance)
1761
+ {
1762
+ const float diff = max(max(abs(actual[0] - expected[0]), abs(actual[1] - expected[1])), abs(actual[2] - expected[2]));
1763
+ if (diff > tolerance)
1764
+ {
1765
+ printf("Error, expect_near() failed with tolerance "); print(tolerance);
1766
+ printf("\t Expected: "); print(expected);
1767
+ printf("\t Actual: "); print(actual);
1768
+ }
1769
+ }
1770
+
1771
+ template <typename T>
1772
+ inline CUDA_CALLABLE void adj_expect_near(const T& actual, const T& expected, const T& tolerance, T& adj_actual, T& adj_expected, T& adj_tolerance)
1773
+ {
1774
+ // nop
1775
+ }
1776
+
1777
+ inline CUDA_CALLABLE void adj_expect_near(const vec3& actual, const vec3& expected, float tolerance, vec3& adj_actual, vec3& adj_expected, float adj_tolerance)
1778
+ {
1779
+ // nop
1780
+ }
1781
+
1782
+
1783
+ } // namespace wp
1784
+
1785
+ // include array.h so we have the print, isfinite functions for the inner array types defined
1786
+ #include "array.h"
1787
+ #include "mesh.h"
1788
+ #include "bvh.h"
1789
+ #include "svd.h"
1790
+ #include "hashgrid.h"
1791
+ #include "volume.h"
1792
+ #include "range.h"
1793
+ #include "rand.h"
1794
+ #include "noise.h"
1795
+ #include "matnn.h"
1796
+
1797
+ #if !defined(WP_ENABLE_CUDA) // only include in kernels for now
1798
+ #include "tile.h"
1799
+ #include "tile_reduce.h"
1800
+ #endif //!defined(WP_ENABLE_CUDA)