warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1096 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
# NumPy scalar types exercised by the arithmetic tests, grouped by kind.
# NOTE: ``np.byte`` / ``np.ubyte`` are NumPy's C ``char`` aliases; on common platforms they
# resolve to ``np.int8`` / ``np.uint8`` and therefore repeat earlier entries in their lists.
np_signed_int_types = [np.int8, np.int16, np.int32, np.int64, np.byte]
np_unsigned_int_types = [np.uint8, np.uint16, np.uint32, np.uint64, np.ubyte]

# Every integer type: signed first, then unsigned.
np_int_types = [*np_signed_int_types, *np_unsigned_int_types]

np_float_types = [np.float16, np.float32, np.float64]

# Every scalar type under test: integers first, then floats.
np_scalar_types = [*np_int_types, *np_float_types]
45
+
46
+ def randvals(rng, shape, dtype):
47
+ if dtype in np_float_types:
48
+ return rng.standard_normal(size=shape).astype(dtype)
49
+ elif dtype in [np.int8, np.uint8, np.byte, np.ubyte]:
50
+ return rng.integers(1, high=3, size=shape, dtype=dtype)
51
+ return rng.integers(1, high=5, size=shape, dtype=dtype)
52
+
53
+
54
# Memo of the kernels built by ``getkernel``, keyed by "<function name>_<suffix>".
kernel_cache: dict = {}
55
+
56
+
57
def getkernel(func, suffix=""):
    """Return the ``wp.Kernel`` for ``func``, building it only on the first request.

    The cache key is ``"<func.__name__>_<suffix>"``. It is used both as the lookup key in
    the module-level ``kernel_cache`` and as the kernel's own ``key``. Later calls with the
    same function name and suffix return the cached kernel, even if a different ``func``
    object is passed.

    Args:
        func: Python function to compile into a Warp kernel.
        suffix: String distinguishing specialisations of ``func`` (e.g. a dtype name).

    Returns:
        The cached or newly built ``wp.Kernel``.
    """
    key = "_".join((func.__name__, suffix))
    kernel = kernel_cache.get(key)
    if kernel is None:
        kernel = wp.Kernel(func=func, key=key)
        kernel_cache[key] = kernel
    return kernel
62
+
63
+
64
def get_select_kernel(dtype):
    """Return the cached kernel that copies ``input[index]`` into ``out[0]``.

    The gradient tests use it to reduce an output array to a single scalar loss for
    ``tape.backward``. The kernel is specialised on ``dtype`` and cached by ``getkernel``
    under the dtype's ``__name__``.
    """
    suffix = dtype.__name__

    def output_select_kernel_fn(
        input: wp.array(dtype=dtype),
        index: int,
        out: wp.array(dtype=dtype),
    ):
        # Single-element gather: the chosen entry becomes the scalar output.
        out[0] = input[index]

    return getkernel(output_select_kernel_fn, suffix=suffix)
73
+
74
+
75
def get_select_kernel2(dtype):
    """Return the cached kernel that copies ``input[index0, index1]`` into ``out[0]``.

    This is the 2-D counterpart of the single-index select kernel. It turns one element of
    a 2-D output array into a scalar loss for ``tape.backward``. The kernel is specialised
    on ``dtype`` and cached by ``getkernel`` under the dtype's ``__name__``.
    """
    suffix = dtype.__name__

    def output_select_kernel2_fn(
        input: wp.array(dtype=dtype, ndim=2),
        index0: int,
        index1: int,
        out: wp.array(dtype=dtype),
    ):
        # Gather one element of the 2-D input into the 1-element output.
        out[0] = input[index0, index1]

    return getkernel(output_select_kernel2_fn, suffix=suffix)
85
+
86
+
87
def test_arrays(test, device, dtype):
    """Round-trip a random (10, 5) NumPy array of ``dtype`` through ``wp.array``.

    Args:
        test: The unittest ``TestCase`` (not used directly; checks go through
            ``assert_np_equal``).
        device: Warp device on which the array is allocated.
        dtype: NumPy scalar type under test.
    """
    rng = np.random.default_rng(123)

    # Float types compare within a per-precision tolerance; all other types use tol=0.
    float_tolerances = {np.float16: 1.0e-3, np.float32: 1.0e-6, np.float64: 1.0e-8}
    tol = float_tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    host_values = randvals(rng, (10, 5), dtype)
    device_array = wp.array(host_values, dtype=wptype, requires_grad=True, device=device)

    assert_np_equal(device_array.numpy(), host_values, tol=tol)
101
+
102
+
103
def test_unary_ops(test, device, dtype, register_kernels=False):
    """Check unary ``+``, ``-``, ``wp.sign``, ``wp.abs`` and ``wp.step``, and their gradients.

    A kernel writes ``2 * op(x)`` for each of the five ops into rows 0-4 of a (5, 10) output.
    The forward values are compared against NumPy equivalents. For floating-point dtypes,
    every output element is then back-propagated on its own and checked against the
    analytic gradient.

    Args:
        test: The unittest ``TestCase`` (not used directly; checks go through
            ``assert_np_equal``).
        device: Warp device to run on.
        dtype: NumPy scalar type under test.
        register_kernels: If true, only build (and cache) the kernels, then return.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def check_unary(
        inputs: wp.array(dtype=wptype, ndim=2),
        outputs: wp.array(dtype=wptype, ndim=2),
    ):
        for i in range(10):
            i0 = inputs[0, i]
            i1 = inputs[1, i]
            i2 = inputs[2, i]
            i3 = inputs[3, i]
            i4 = inputs[4, i]

            # multiply outputs by 2 so we've got something to backpropagate:
            outputs[0, i] = wptype(2.0) * (+i0)
            outputs[1, i] = wptype(2.0) * (-i1)
            outputs[2, i] = wptype(2.0) * wp.sign(i2)
            outputs[3, i] = wptype(2.0) * wp.abs(i3)
            outputs[4, i] = wptype(2.0) * wp.step(i4)

    kernel = getkernel(check_unary, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel2(wptype)

    if register_kernels:
        return

    if dtype in np_float_types:
        inputs = wp.array(
            rng.standard_normal(size=(5, 10)).astype(dtype), dtype=wptype, requires_grad=True, device=device
        )
    else:
        # NOTE(review): rng.integers with a negative low bound raises for unsigned dtypes, so
        # this test can presumably only be registered for signed integer and float types.
        inputs = wp.array(
            rng.integers(-2, high=3, size=(5, 10), dtype=dtype), dtype=wptype, requires_grad=True, device=device
        )
    outputs = wp.zeros_like(inputs)

    wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)

    # The kernel only reads `inputs`, so one host copy serves every check below.
    inputs_np = inputs.numpy()
    outputs_np = outputs.numpy()
    assert_np_equal(outputs_np[0], 2 * inputs_np[0], tol=tol)
    assert_np_equal(outputs_np[1], -2 * inputs_np[1], tol=tol)
    # wp.sign maps 0 to +1, whereas np.sign maps it to 0.
    expected = 2 * np.sign(inputs_np[2])
    expected[expected == 0] = 2
    assert_np_equal(outputs_np[2], expected, tol=tol)
    assert_np_equal(outputs_np[3], 2 * np.abs(inputs_np[3]), tol=tol)
    # wp.step is 1 for x < 0 and 0 otherwise, i.e. 1 - heaviside(x, 1).
    assert_np_equal(outputs_np[4], 2 * (1 - np.heaviside(inputs_np[4], 1)), tol=tol)

    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    if dtype not in np_float_types:
        # Gradients are only checked for floating-point types.
        return

    # (output row, d(output[row, i]) / d(inputs[row, i]) as a function of inputs[row, i])
    grad_cases = (
        (0, lambda x: 2),  # 2 * (+x)
        (1, lambda x: -2),  # 2 * (-x)
        (2, lambda x: 0),  # 2 * sign(x): piecewise constant
        (3, lambda x: 2 * np.sign(x)),  # 2 * abs(x)
        (4, lambda x: 0),  # 2 * step(x): piecewise constant
    )

    for i in range(10):
        for row, expected_grad in grad_cases:
            tape = wp.Tape()
            with tape:
                wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[outputs, row, i], outputs=[out], device=device)

            tape.backward(loss=out)
            # Only the selected input element can receive a gradient.
            expected_grads = np.zeros_like(inputs_np)
            expected_grads[row, i] = expected_grad(inputs_np[row, i])
            assert_np_equal(tape.gradients[inputs].numpy(), expected_grads, tol=tol)
            # Reset accumulated gradients before the next case.
            tape.zero()
217
+
218
+
219
def test_nonzero(test, device, dtype, register_kernels=False):
    """Check ``wp.nonzero`` on ten random values and, for float types, that its gradient is zero.

    Inputs are drawn from the integers in [-2, 3) and cast to ``dtype``. The forward result
    is compared with ``inputs != 0``.

    Args:
        test: The unittest ``TestCase`` (not used directly; checks go through
            ``assert_np_equal``).
        device: Warp device to run on.
        dtype: NumPy scalar type under test.
        register_kernels: If true, only build (and cache) the kernels, then return.
    """
    rng = np.random.default_rng(123)
    tol = {np.float16: 5.0e-3, np.float32: 1.0e-6, np.float64: 1.0e-8}.get(dtype, 0)
    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def check_nonzero(
        inputs: wp.array(dtype=wptype),
        outputs: wp.array(dtype=wptype),
    ):
        for i in range(10):
            i0 = inputs[i]
            outputs[i] = wp.nonzero(i0)

    kernel = getkernel(check_nonzero, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    host_values = rng.integers(-2, high=3, size=10).astype(dtype)
    inputs = wp.array(host_values, dtype=wptype, requires_grad=True, device=device)
    outputs = wp.zeros_like(inputs)

    wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
    assert_np_equal(outputs.numpy(), inputs.numpy() != 0)

    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    if dtype not in np_float_types:
        # Gradients are only checked for floating-point types.
        return

    for i in range(10):
        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[outputs, i], outputs=[out], device=device)

        tape.backward(loss=out)
        # The expected gradient is zero everywhere.
        zero_grads = np.zeros_like(inputs.numpy())
        assert_np_equal(tape.gradients[inputs].numpy(), zero_grads, tol=tol)
        tape.zero()
263
+
264
+
265
+ def test_binary_ops(test, device, dtype, register_kernels=False):
266
+ rng = np.random.default_rng(123)
267
+
268
+ tol = {
269
+ np.float16: 5.0e-2,
270
+ np.float32: 1.0e-6,
271
+ np.float64: 1.0e-8,
272
+ }.get(dtype, 0)
273
+
274
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
275
+
276
+ def check_binary_ops(
277
+ in1: wp.array(dtype=wptype, ndim=2),
278
+ in2: wp.array(dtype=wptype, ndim=2),
279
+ outputs: wp.array(dtype=wptype, ndim=2),
280
+ ):
281
+ for i in range(10):
282
+ i0 = in1[0, i]
283
+ i1 = in1[1, i]
284
+ i2 = in1[2, i]
285
+ i3 = in1[3, i]
286
+ i4 = in1[4, i]
287
+ i5 = in1[5, i]
288
+ i6 = in1[6, i]
289
+ i7 = in1[7, i]
290
+
291
+ j0 = in2[0, i]
292
+ j1 = in2[1, i]
293
+ j2 = in2[2, i]
294
+ j3 = in2[3, i]
295
+ j4 = in2[4, i]
296
+ j5 = in2[5, i]
297
+ j6 = in2[6, i]
298
+ j7 = in2[7, i]
299
+
300
+ outputs[0, i] = wptype(2) * wp.mul(i0, j0)
301
+ outputs[1, i] = wptype(2) * wp.div(i1, j1)
302
+ outputs[2, i] = wptype(2) * wp.add(i2, j2)
303
+ outputs[3, i] = wptype(2) * wp.sub(i3, j3)
304
+ outputs[4, i] = wptype(2) * wp.mod(i4, j4)
305
+ outputs[5, i] = wptype(2) * wp.min(i5, j5)
306
+ outputs[6, i] = wptype(2) * wp.max(i6, j6)
307
+ outputs[7, i] = wptype(2) * wp.floordiv(i7, j7)
308
+
309
+ kernel = getkernel(check_binary_ops, suffix=dtype.__name__)
310
+ output_select_kernel = get_select_kernel2(wptype)
311
+
312
+ if register_kernels:
313
+ return
314
+
315
+ vals1 = randvals(rng, [8, 10], dtype)
316
+ if dtype in [np_unsigned_int_types]:
317
+ vals2 = vals1 + randvals(rng, [8, 10], dtype)
318
+ else:
319
+ vals2 = np.abs(randvals(rng, [8, 10], dtype))
320
+
321
+ in1 = wp.array(vals1, dtype=wptype, requires_grad=True, device=device)
322
+ in2 = wp.array(vals2, dtype=wptype, requires_grad=True, device=device)
323
+
324
+ outputs = wp.zeros_like(in1)
325
+
326
+ wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
327
+
328
+ assert_np_equal(outputs.numpy()[0], 2 * in1.numpy()[0] * in2.numpy()[0], tol=tol)
329
+ if dtype in np_float_types:
330
+ assert_np_equal(outputs.numpy()[1], 2 * in1.numpy()[1] / (in2.numpy()[1]), tol=tol)
331
+ else:
332
+ assert_np_equal(outputs.numpy()[1], 2 * (in1.numpy()[1] // (in2.numpy()[1])), tol=tol)
333
+ assert_np_equal(outputs.numpy()[2], 2 * (in1.numpy()[2] + (in2.numpy()[2])), tol=tol)
334
+ assert_np_equal(outputs.numpy()[3], 2 * (in1.numpy()[3] - (in2.numpy()[3])), tol=tol)
335
+
336
+ # ...so this is actually the desired behaviour right? Looks like wp.mod doesn't behave like
337
+ # python's % operator or np.mod()...
338
+ assert_np_equal(
339
+ outputs.numpy()[4],
340
+ 2
341
+ * (
342
+ (in1.numpy()[4])
343
+ - (in2.numpy()[4]) * np.sign(in1.numpy()[4]) * np.floor(np.abs(in1.numpy()[4]) / (in2.numpy()[4]))
344
+ ),
345
+ tol=tol,
346
+ )
347
+
348
+ assert_np_equal(outputs.numpy()[5], 2 * np.minimum(in1.numpy()[5], in2.numpy()[5]), tol=tol)
349
+ assert_np_equal(outputs.numpy()[6], 2 * np.maximum(in1.numpy()[6], in2.numpy()[6]), tol=tol)
350
+ assert_np_equal(outputs.numpy()[7], 2 * np.floor_divide(in1.numpy()[7], in2.numpy()[7]), tol=tol)
351
+
352
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
353
+ if dtype in np_float_types:
354
+ for i in range(10):
355
+ # multiplication:
356
+ tape = wp.Tape()
357
+ with tape:
358
+ wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
359
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 0, i], outputs=[out], device=device)
360
+
361
+ tape.backward(loss=out)
362
+ expected = np.zeros_like(in1.numpy())
363
+ expected[0, i] = 2.0 * in2.numpy()[0, i]
364
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=tol)
365
+ expected[0, i] = 2.0 * in1.numpy()[0, i]
366
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
367
+ tape.zero()
368
+
369
+ # division:
370
+ tape = wp.Tape()
371
+ with tape:
372
+ wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
373
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 1, i], outputs=[out], device=device)
374
+
375
+ tape.backward(loss=out)
376
+ expected = np.zeros_like(in1.numpy())
377
+ expected[1, i] = 2.0 / (in2.numpy()[1, i])
378
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=tol)
379
+ # y = x1/x2
380
+ # dy/dx2 = -x1/x2^2
381
+ expected[1, i] = (-2.0) * (in1.numpy()[1, i] / (in2.numpy()[1, i] ** 2))
382
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
383
+ tape.zero()
384
+
385
+ # addition:
386
+ tape = wp.Tape()
387
+ with tape:
388
+ wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
389
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 2, i], outputs=[out], device=device)
390
+
391
+ tape.backward(loss=out)
392
+ expected = np.zeros_like(in1.numpy())
393
+ expected[2, i] = 2.0
394
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=tol)
395
+ expected[2, i] = 2.0
396
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
397
+ tape.zero()
398
+
399
+ # subtraction:
400
+ tape = wp.Tape()
401
+ with tape:
402
+ wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
403
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 3, i], outputs=[out], device=device)
404
+
405
+ tape.backward(loss=out)
406
+ expected = np.zeros_like(in1.numpy())
407
+ expected[3, i] = 2.0
408
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=tol)
409
+ expected[3, i] = -2.0
410
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
411
+ tape.zero()
412
+
413
+ # modulus. unless at discontinuities,
414
+ # d/dx1( x1 % x2 ) == 1
415
+ # d/dx2( x1 % x2 ) == 0
416
+ tape = wp.Tape()
417
+ with tape:
418
+ wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
419
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 4, i], outputs=[out], device=device)
420
+
421
+ tape.backward(loss=out)
422
+ expected = np.zeros_like(in1.numpy())
423
+ expected[4, i] = 2.0
424
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=tol)
425
+ expected[4, i] = 0.0
426
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
427
+ tape.zero()
428
+
429
+ # min
430
+ tape = wp.Tape()
431
+ with tape:
432
+ wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
433
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 5, i], outputs=[out], device=device)
434
+
435
+ tape.backward(loss=out)
436
+ expected = np.zeros_like(in1.numpy())
437
+ expected[5, i] = 2.0 if (in1.numpy()[5, i] < in2.numpy()[5, i]) else 0.0
438
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=tol)
439
+ expected[5, i] = 2.0 if (in2.numpy()[5, i] < in1.numpy()[5, i]) else 0.0
440
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
441
+ tape.zero()
442
+
443
+ # max
444
+ tape = wp.Tape()
445
+ with tape:
446
+ wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
447
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 6, i], outputs=[out], device=device)
448
+
449
+ tape.backward(loss=out)
450
+ expected = np.zeros_like(in1.numpy())
451
+ expected[6, i] = 2.0 if (in1.numpy()[6, i] > in2.numpy()[6, i]) else 0.0
452
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=tol)
453
+ expected[6, i] = 2.0 if (in2.numpy()[6, i] > in1.numpy()[6, i]) else 0.0
454
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
455
+ tape.zero()
456
+
457
+ # floor_divide. Returns integers so gradient is zero
458
+ tape = wp.Tape()
459
+ with tape:
460
+ wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
461
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 7, i], outputs=[out], device=device)
462
+
463
+ tape.backward(loss=out)
464
+ expected = np.zeros_like(in1.numpy())
465
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=tol)
466
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
467
+ tape.zero()
468
+
469
+
470
+ def test_special_funcs(test, device, dtype, register_kernels=False):
471
+ rng = np.random.default_rng(123)
472
+
473
+ tol = {
474
+ np.float16: 1.0e-2,
475
+ np.float32: 1.0e-6,
476
+ np.float64: 1.0e-8,
477
+ }.get(dtype, 0)
478
+
479
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
480
+
481
+ def check_special_funcs(
482
+ inputs: wp.array(dtype=wptype, ndim=2),
483
+ outputs: wp.array(dtype=wptype, ndim=2),
484
+ ):
485
+ # multiply outputs by 2 so we've got something to backpropagate:
486
+ for i in range(10):
487
+ outputs[0, i] = wptype(2) * wp.log(inputs[0, i])
488
+ outputs[1, i] = wptype(2) * wp.log2(inputs[1, i])
489
+ outputs[2, i] = wptype(2) * wp.log10(inputs[2, i])
490
+ outputs[3, i] = wptype(2) * wp.exp(inputs[3, i])
491
+ outputs[4, i] = wptype(2) * wp.atan(inputs[4, i])
492
+ outputs[5, i] = wptype(2) * wp.sin(inputs[5, i])
493
+ outputs[6, i] = wptype(2) * wp.cos(inputs[6, i])
494
+ outputs[7, i] = wptype(2) * wp.sqrt(inputs[7, i])
495
+ outputs[8, i] = wptype(2) * wp.tan(inputs[8, i])
496
+ outputs[9, i] = wptype(2) * wp.sinh(inputs[9, i])
497
+ outputs[10, i] = wptype(2) * wp.cosh(inputs[10, i])
498
+ outputs[11, i] = wptype(2) * wp.tanh(inputs[11, i])
499
+ outputs[12, i] = wptype(2) * wp.acos(inputs[12, i])
500
+ outputs[13, i] = wptype(2) * wp.asin(inputs[13, i])
501
+ outputs[14, i] = wptype(2) * wp.cbrt(inputs[14, i])
502
+
503
+ kernel = getkernel(check_special_funcs, suffix=dtype.__name__)
504
+ output_select_kernel = get_select_kernel2(wptype)
505
+
506
+ if register_kernels:
507
+ return
508
+
509
+ invals = rng.normal(size=(15, 10)).astype(dtype)
510
+ invals[[0, 1, 2, 7, 14]] = 0.1 + np.abs(invals[[0, 1, 2, 7, 14]])
511
+ invals[12] = np.clip(invals[12], -0.9, 0.9)
512
+ invals[13] = np.clip(invals[13], -0.9, 0.9)
513
+ inputs = wp.array(invals, dtype=wptype, requires_grad=True, device=device)
514
+ outputs = wp.zeros_like(inputs)
515
+
516
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
517
+
518
+ assert_np_equal(outputs.numpy()[0], 2 * np.log(inputs.numpy()[0]), tol=tol)
519
+ assert_np_equal(outputs.numpy()[1], 2 * np.log2(inputs.numpy()[1]), tol=tol)
520
+ assert_np_equal(outputs.numpy()[2], 2 * np.log10(inputs.numpy()[2]), tol=tol)
521
+ assert_np_equal(outputs.numpy()[3], 2 * np.exp(inputs.numpy()[3]), tol=tol)
522
+ assert_np_equal(outputs.numpy()[4], 2 * np.arctan(inputs.numpy()[4]), tol=tol)
523
+ assert_np_equal(outputs.numpy()[5], 2 * np.sin(inputs.numpy()[5]), tol=tol)
524
+ assert_np_equal(outputs.numpy()[6], 2 * np.cos(inputs.numpy()[6]), tol=tol)
525
+ assert_np_equal(outputs.numpy()[7], 2 * np.sqrt(inputs.numpy()[7]), tol=tol)
526
+ assert_np_equal(outputs.numpy()[8], 2 * np.tan(inputs.numpy()[8]), tol=tol)
527
+ assert_np_equal(outputs.numpy()[9], 2 * np.sinh(inputs.numpy()[9]), tol=tol)
528
+ assert_np_equal(outputs.numpy()[10], 2 * np.cosh(inputs.numpy()[10]), tol=tol)
529
+ assert_np_equal(outputs.numpy()[11], 2 * np.tanh(inputs.numpy()[11]), tol=tol)
530
+ assert_np_equal(outputs.numpy()[12], 2 * np.arccos(inputs.numpy()[12]), tol=tol)
531
+ assert_np_equal(outputs.numpy()[13], 2 * np.arcsin(inputs.numpy()[13]), tol=tol)
532
+ assert_np_equal(outputs.numpy()[14], 2 * np.cbrt(inputs.numpy()[14]), tol=tol)
533
+
534
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
535
+ if dtype in np_float_types:
536
+ for i in range(10):
537
+ # log:
538
+ tape = wp.Tape()
539
+ with tape:
540
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
541
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 0, i], outputs=[out], device=device)
542
+
543
+ tape.backward(loss=out)
544
+ expected = np.zeros_like(inputs.numpy())
545
+ expected[0, i] = 2.0 / inputs.numpy()[0, i]
546
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
547
+ tape.zero()
548
+
549
+ # log2:
550
+ tape = wp.Tape()
551
+ with tape:
552
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
553
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 1, i], outputs=[out], device=device)
554
+
555
+ tape.backward(loss=out)
556
+ expected = np.zeros_like(inputs.numpy())
557
+ expected[1, i] = 2.0 / (inputs.numpy()[1, i] * np.log(2.0))
558
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
559
+ tape.zero()
560
+
561
+ # log10:
562
+ tape = wp.Tape()
563
+ with tape:
564
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
565
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 2, i], outputs=[out], device=device)
566
+
567
+ tape.backward(loss=out)
568
+ expected = np.zeros_like(inputs.numpy())
569
+ expected[2, i] = 2.0 / (inputs.numpy()[2, i] * np.log(10.0))
570
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
571
+ tape.zero()
572
+
573
+ # exp:
574
+ tape = wp.Tape()
575
+ with tape:
576
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
577
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 3, i], outputs=[out], device=device)
578
+
579
+ tape.backward(loss=out)
580
+ expected = np.zeros_like(inputs.numpy())
581
+ expected[3, i] = outputs.numpy()[3, i]
582
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
583
+ tape.zero()
584
+
585
+ # arctan:
586
+ # looks like the autodiff formula in warp was wrong? Was (1 + x^2) rather than
587
+ # 1/(1 + x^2)
588
+ tape = wp.Tape()
589
+ with tape:
590
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
591
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 4, i], outputs=[out], device=device)
592
+
593
+ tape.backward(loss=out)
594
+ expected = np.zeros_like(inputs.numpy())
595
+ expected[4, i] = 2.0 / (inputs.numpy()[4, i] ** 2 + 1)
596
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
597
+ tape.zero()
598
+
599
+ # sin:
600
+ tape = wp.Tape()
601
+ with tape:
602
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
603
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 5, i], outputs=[out], device=device)
604
+
605
+ tape.backward(loss=out)
606
+ expected = np.zeros_like(inputs.numpy())
607
+ expected[5, i] = np.cos(inputs.numpy()[5, i]) * 2
608
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
609
+ tape.zero()
610
+
611
+ # cos:
612
+ tape = wp.Tape()
613
+ with tape:
614
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
615
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 6, i], outputs=[out], device=device)
616
+
617
+ tape.backward(loss=out)
618
+ expected = np.zeros_like(inputs.numpy())
619
+ expected[6, i] = -np.sin(inputs.numpy()[6, i]) * 2.0
620
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
621
+ tape.zero()
622
+
623
+ # sqrt:
624
+ tape = wp.Tape()
625
+ with tape:
626
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
627
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 7, i], outputs=[out], device=device)
628
+
629
+ tape.backward(loss=out)
630
+ expected = np.zeros_like(inputs.numpy())
631
+ expected[7, i] = 1.0 / (np.sqrt(inputs.numpy()[7, i]))
632
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
633
+ tape.zero()
634
+
635
+ # tan:
636
+ # looks like there was a bug in autodiff formula here too - gradient was zero if cos(x) > 0
637
+ # (should have been "if(cosx != 0)")
638
+ tape = wp.Tape()
639
+ with tape:
640
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
641
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 8, i], outputs=[out], device=device)
642
+
643
+ tape.backward(loss=out)
644
+ expected = np.zeros_like(inputs.numpy())
645
+ expected[8, i] = 2.0 / (np.cos(inputs.numpy()[8, i]) ** 2)
646
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=200 * tol)
647
+ tape.zero()
648
+
649
+ # sinh:
650
+ tape = wp.Tape()
651
+ with tape:
652
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
653
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 9, i], outputs=[out], device=device)
654
+
655
+ tape.backward(loss=out)
656
+ expected = np.zeros_like(inputs.numpy())
657
+ expected[9, i] = 2.0 * np.cosh(inputs.numpy()[9, i])
658
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
659
+ tape.zero()
660
+
661
+ # cosh:
662
+ tape = wp.Tape()
663
+ with tape:
664
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
665
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 10, i], outputs=[out], device=device)
666
+
667
+ tape.backward(loss=out)
668
+ expected = np.zeros_like(inputs.numpy())
669
+ expected[10, i] = 2.0 * np.sinh(inputs.numpy()[10, i])
670
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
671
+ tape.zero()
672
+
673
+ # tanh:
674
+ tape = wp.Tape()
675
+ with tape:
676
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
677
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 11, i], outputs=[out], device=device)
678
+
679
+ tape.backward(loss=out)
680
+ expected = np.zeros_like(inputs.numpy())
681
+ expected[11, i] = 2.0 / (np.cosh(inputs.numpy()[11, i]) ** 2)
682
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
683
+ tape.zero()
684
+
685
+ # arccos:
686
+ tape = wp.Tape()
687
+ with tape:
688
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
689
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 12, i], outputs=[out], device=device)
690
+
691
+ tape.backward(loss=out)
692
+ expected = np.zeros_like(inputs.numpy())
693
+ expected[12, i] = -2.0 / np.sqrt(1 - inputs.numpy()[12, i] ** 2)
694
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
695
+ tape.zero()
696
+
697
+ # arcsin:
698
+ tape = wp.Tape()
699
+ with tape:
700
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
701
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 13, i], outputs=[out], device=device)
702
+
703
+ tape.backward(loss=out)
704
+ expected = np.zeros_like(inputs.numpy())
705
+ expected[13, i] = 2.0 / np.sqrt(1 - inputs.numpy()[13, i] ** 2)
706
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=6 * tol)
707
+ tape.zero()
708
+
709
+ # cbrt:
710
+ tape = wp.Tape()
711
+ with tape:
712
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
713
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 14, i], outputs=[out], device=device)
714
+
715
+ tape.backward(loss=out)
716
+ expected = np.zeros_like(inputs.numpy())
717
+ cbrt = np.cbrt(inputs.numpy()[14, i], dtype=np.dtype(dtype))
718
+ expected[14, i] = (2.0 / 3.0) * (1.0 / (cbrt * cbrt))
719
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
720
+ tape.zero()
721
+
722
+
723
+ def test_special_funcs_2arg(test, device, dtype, register_kernels=False):
724
+ rng = np.random.default_rng(123)
725
+
726
+ tol = {
727
+ np.float16: 1.0e-2,
728
+ np.float32: 1.0e-6,
729
+ np.float64: 1.0e-8,
730
+ }.get(dtype, 0)
731
+
732
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
733
+
734
+ def check_special_funcs_2arg(
735
+ in1: wp.array(dtype=wptype, ndim=2),
736
+ in2: wp.array(dtype=wptype, ndim=2),
737
+ outputs: wp.array(dtype=wptype, ndim=2),
738
+ ):
739
+ # multiply outputs by 2 so we've got something to backpropagate:
740
+ for i in range(10):
741
+ outputs[0, i] = wptype(2) * wp.pow(in1[0, i], in2[0, i])
742
+ outputs[1, i] = wptype(2) * wp.atan2(in1[1, i], in2[1, i])
743
+
744
+ kernel = getkernel(check_special_funcs_2arg, suffix=dtype.__name__)
745
+ output_select_kernel = get_select_kernel2(wptype)
746
+
747
+ if register_kernels:
748
+ return
749
+
750
+ in1 = wp.array(np.abs(randvals(rng, [2, 10], dtype)), dtype=wptype, requires_grad=True, device=device)
751
+ in2 = wp.array(randvals(rng, [2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
752
+ outputs = wp.zeros_like(in1)
753
+
754
+ wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
755
+
756
+ assert_np_equal(outputs.numpy()[0], 2.0 * np.power(in1.numpy()[0], in2.numpy()[0]), tol=tol)
757
+ assert_np_equal(outputs.numpy()[1], 2.0 * np.arctan2(in1.numpy()[1], in2.numpy()[1]), tol=tol)
758
+
759
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
760
+ if dtype in np_float_types:
761
+ for i in range(10):
762
+ # pow:
763
+ tape = wp.Tape()
764
+ with tape:
765
+ wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
766
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 0, i], outputs=[out], device=device)
767
+ tape.backward(loss=out)
768
+ expected = np.zeros_like(in1.numpy())
769
+ expected[0, i] = 2.0 * in2.numpy()[0, i] * np.power(in1.numpy()[0, i], in2.numpy()[0, i] - 1)
770
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=5 * tol)
771
+ expected[0, i] = 2.0 * np.power(in1.numpy()[0, i], in2.numpy()[0, i]) * np.log(in1.numpy()[0, i])
772
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
773
+ tape.zero()
774
+
775
+ # atan2:
776
+ tape = wp.Tape()
777
+ with tape:
778
+ wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
779
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 1, i], outputs=[out], device=device)
780
+
781
+ tape.backward(loss=out)
782
+ expected = np.zeros_like(in1.numpy())
783
+ expected[1, i] = 2.0 * in2.numpy()[1, i] / (in1.numpy()[1, i] ** 2 + in2.numpy()[1, i] ** 2)
784
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=tol)
785
+ expected[1, i] = -2.0 * in1.numpy()[1, i] / (in1.numpy()[1, i] ** 2 + in2.numpy()[1, i] ** 2)
786
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
787
+ tape.zero()
788
+
789
+
790
+ def test_float_to_int(test, device, dtype, register_kernels=False):
791
+ rng = np.random.default_rng(123)
792
+
793
+ tol = {
794
+ np.float16: 5.0e-3,
795
+ np.float32: 1.0e-6,
796
+ np.float64: 1.0e-8,
797
+ }.get(dtype, 0)
798
+
799
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
800
+
801
+ def check_float_to_int(
802
+ inputs: wp.array(dtype=wptype, ndim=2),
803
+ outputs: wp.array(dtype=wptype, ndim=2),
804
+ ):
805
+ for i in range(10):
806
+ outputs[0, i] = wp.round(inputs[0, i])
807
+ outputs[1, i] = wp.rint(inputs[1, i])
808
+ outputs[2, i] = wp.trunc(inputs[2, i])
809
+ outputs[3, i] = wp.floor(inputs[3, i])
810
+ outputs[4, i] = wp.ceil(inputs[4, i])
811
+ outputs[5, i] = wp.frac(inputs[5, i])
812
+
813
+ kernel = getkernel(check_float_to_int, suffix=dtype.__name__)
814
+ output_select_kernel = get_select_kernel2(wptype)
815
+
816
+ if register_kernels:
817
+ return
818
+
819
+ inputs = wp.array(rng.standard_normal(size=(6, 10)).astype(dtype), dtype=wptype, requires_grad=True, device=device)
820
+ outputs = wp.zeros_like(inputs)
821
+
822
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
823
+
824
+ assert_np_equal(outputs.numpy()[0], np.round(inputs.numpy()[0]))
825
+ assert_np_equal(outputs.numpy()[1], np.rint(inputs.numpy()[1]))
826
+ assert_np_equal(outputs.numpy()[2], np.trunc(inputs.numpy()[2]))
827
+ assert_np_equal(outputs.numpy()[3], np.floor(inputs.numpy()[3]))
828
+ assert_np_equal(outputs.numpy()[4], np.ceil(inputs.numpy()[4]))
829
+ assert_np_equal(outputs.numpy()[5], np.modf(inputs.numpy()[5])[0])
830
+
831
+ # all the gradients should be zero as these functions are piecewise constant:
832
+
833
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
834
+ for i in range(10):
835
+ for j in range(5):
836
+ tape = wp.Tape()
837
+ with tape:
838
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
839
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, j, i], outputs=[out], device=device)
840
+
841
+ tape.backward(loss=out)
842
+ assert_np_equal(tape.gradients[inputs].numpy(), np.zeros_like(inputs.numpy()), tol=tol)
843
+ tape.zero()
844
+
845
+
846
+ def test_interp(test, device, dtype, register_kernels=False):
847
+ rng = np.random.default_rng(123)
848
+
849
+ tol = {
850
+ np.float16: 1.0e-2,
851
+ np.float32: 5.0e-6,
852
+ np.float64: 1.0e-8,
853
+ }.get(dtype, 0)
854
+
855
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
856
+
857
+ def check_interp(
858
+ in1: wp.array(dtype=wptype, ndim=2),
859
+ in2: wp.array(dtype=wptype, ndim=2),
860
+ in3: wp.array(dtype=wptype, ndim=2),
861
+ outputs: wp.array(dtype=wptype, ndim=2),
862
+ ):
863
+ # multiply outputs by 2 so we've got something to backpropagate:
864
+ for i in range(10):
865
+ outputs[0, i] = wptype(2) * wp.smoothstep(in1[0, i], in2[0, i], in3[0, i])
866
+ outputs[1, i] = wptype(2) * wp.lerp(in1[1, i], in2[1, i], in3[1, i])
867
+
868
+ kernel = getkernel(check_interp, suffix=dtype.__name__)
869
+ output_select_kernel = get_select_kernel2(wptype)
870
+
871
+ if register_kernels:
872
+ return
873
+
874
+ e0 = randvals(rng, [2, 10], dtype)
875
+ e1 = e0 + randvals(rng, [2, 10], dtype) + 0.1
876
+ in1 = wp.array(e0, dtype=wptype, requires_grad=True, device=device)
877
+ in2 = wp.array(e1, dtype=wptype, requires_grad=True, device=device)
878
+ in3 = wp.array(randvals(rng, [2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
879
+
880
+ outputs = wp.zeros_like(in1)
881
+
882
+ wp.launch(kernel, dim=1, inputs=[in1, in2, in3], outputs=[outputs], device=device)
883
+
884
+ edge0 = in1.numpy()[0]
885
+ edge1 = in2.numpy()[0]
886
+ t_smoothstep = in3.numpy()[0]
887
+ x = np.clip((t_smoothstep - edge0) / (edge1 - edge0), 0, 1)
888
+ smoothstep_expected = 2.0 * x * x * (3 - 2 * x)
889
+
890
+ assert_np_equal(outputs.numpy()[0], smoothstep_expected, tol=tol)
891
+
892
+ a = in1.numpy()[1]
893
+ b = in2.numpy()[1]
894
+ t = in3.numpy()[1]
895
+ assert_np_equal(outputs.numpy()[1], 2.0 * (a * (1 - t) + b * t), tol=tol)
896
+
897
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
898
+ if dtype in np_float_types:
899
+ for i in range(10):
900
+ tape = wp.Tape()
901
+ with tape:
902
+ wp.launch(kernel, dim=1, inputs=[in1, in2, in3], outputs=[outputs], device=device)
903
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 0, i], outputs=[out], device=device)
904
+ tape.backward(loss=out)
905
+
906
+ # e0 = in1
907
+ # e1 = in2
908
+ # t = in3
909
+
910
+ # x = clamp((t - e0) / (e1 - e0), 0,1)
911
+ # dx/dt = 1 / (e1 - e0) if e0 < t < e1 else 0
912
+
913
+ # y = x * x * (3 - 2 * x)
914
+
915
+ # y = 3 * x * x - 2 * x * x * x
916
+ # dy/dx = 6 * ( x - x^2 )
917
+ dydx = 6 * x * (1 - x)
918
+
919
+ # dy/in1 = dy/dx dx/de0 de0/din1
920
+ dxde0 = (t_smoothstep - edge1) / ((edge1 - edge0) ** 2)
921
+ dxde0[x == 0] = 0
922
+ dxde0[x == 1] = 0
923
+
924
+ expected_grads = np.zeros_like(in1.numpy())
925
+ expected_grads[0, i] = 2.0 * dydx[i] * dxde0[i]
926
+ assert_np_equal(tape.gradients[in1].numpy(), expected_grads, tol=tol)
927
+
928
+ # dy/in2 = dy/dx dx/de1 de1/din2
929
+ dxde1 = (edge0 - t_smoothstep) / ((edge1 - edge0) ** 2)
930
+ dxde1[x == 0] = 0
931
+ dxde1[x == 1] = 0
932
+
933
+ expected_grads = np.zeros_like(in1.numpy())
934
+ expected_grads[0, i] = 2.0 * dydx[i] * dxde1[i]
935
+ assert_np_equal(tape.gradients[in2].numpy(), expected_grads, tol=tol)
936
+
937
+ # dy/in3 = dy/dx dx/dt dt/din3
938
+ dxdt = 1.0 / (edge1 - edge0)
939
+ dxdt[x == 0] = 0
940
+ dxdt[x == 1] = 0
941
+
942
+ expected_grads = np.zeros_like(in1.numpy())
943
+ expected_grads[0, i] = 2.0 * dydx[i] * dxdt[i]
944
+ assert_np_equal(tape.gradients[in3].numpy(), expected_grads, tol=tol)
945
+ tape.zero()
946
+
947
+ tape = wp.Tape()
948
+ with tape:
949
+ wp.launch(kernel, dim=1, inputs=[in1, in2, in3], outputs=[outputs], device=device)
950
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 1, i], outputs=[out], device=device)
951
+ tape.backward(loss=out)
952
+
953
+ # y = a*(1-t) + b*t
954
+ # a = in1
955
+ # b = in2
956
+ # t = in3
957
+
958
+ # y = in1*( 1 - in3 ) + in2*in3
959
+
960
+ # dy/din1 = (1-in3)
961
+ expected_grads = np.zeros_like(in1.numpy())
962
+ expected_grads[1, i] = 2.0 * (1 - in3.numpy()[1, i])
963
+ assert_np_equal(tape.gradients[in1].numpy(), expected_grads, tol=tol)
964
+
965
+ # dy/din2 = in3
966
+ expected_grads = np.zeros_like(in1.numpy())
967
+ expected_grads[1, i] = 2.0 * in3.numpy()[1, i]
968
+ assert_np_equal(tape.gradients[in2].numpy(), expected_grads, tol=tol)
969
+
970
+ # dy/din3 = 8*in2 - 1.5*4*in1
971
+ expected_grads = np.zeros_like(in1.numpy())
972
+ expected_grads[1, i] = 2.0 * (in2.numpy()[1, i] - in1.numpy()[1, i])
973
+ assert_np_equal(tape.gradients[in3].numpy(), expected_grads, tol=tol)
974
+ tape.zero()
975
+
976
+
977
+ def test_clamp(test, device, dtype, register_kernels=False):
978
+ rng = np.random.default_rng(123)
979
+
980
+ tol = {
981
+ np.float16: 5.0e-3,
982
+ np.float32: 1.0e-6,
983
+ np.float64: 1.0e-6,
984
+ }.get(dtype, 0)
985
+
986
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
987
+
988
+ def check_clamp(
989
+ in1: wp.array(dtype=wptype),
990
+ in2: wp.array(dtype=wptype),
991
+ in3: wp.array(dtype=wptype),
992
+ outputs: wp.array(dtype=wptype),
993
+ ):
994
+ for i in range(100):
995
+ # multiply output by 2 so we've got something to backpropagate:
996
+ outputs[i] = wptype(2) * wp.clamp(in1[i], in2[i], in3[i])
997
+
998
+ kernel = getkernel(check_clamp, suffix=dtype.__name__)
999
+ output_select_kernel = get_select_kernel(wptype)
1000
+
1001
+ if register_kernels:
1002
+ return
1003
+
1004
+ in1 = wp.array(randvals(rng, [100], dtype), dtype=wptype, requires_grad=True, device=device)
1005
+ starts = randvals(rng, [100], dtype)
1006
+ diffs = np.abs(randvals(rng, [100], dtype))
1007
+ in2 = wp.array(starts, dtype=wptype, requires_grad=True, device=device)
1008
+ in3 = wp.array(starts + diffs, dtype=wptype, requires_grad=True, device=device)
1009
+ outputs = wp.zeros_like(in1)
1010
+
1011
+ wp.launch(kernel, dim=1, inputs=[in1, in2, in3], outputs=[outputs], device=device)
1012
+
1013
+ assert_np_equal(2 * np.clip(in1.numpy(), in2.numpy(), in3.numpy()), outputs.numpy(), tol=tol)
1014
+
1015
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1016
+ if dtype in np_float_types:
1017
+ for i in range(100):
1018
+ tape = wp.Tape()
1019
+ with tape:
1020
+ wp.launch(kernel, dim=1, inputs=[in1, in2, in3], outputs=[outputs], device=device)
1021
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, i], outputs=[out], device=device)
1022
+
1023
+ tape.backward(loss=out)
1024
+ t = in1.numpy()[i]
1025
+ lower = in2.numpy()[i]
1026
+ upper = in3.numpy()[i]
1027
+ expected = np.zeros_like(in1.numpy())
1028
+ if t < lower:
1029
+ expected[i] = 2.0
1030
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
1031
+ expected[i] = 0.0
1032
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=tol)
1033
+ assert_np_equal(tape.gradients[in3].numpy(), expected, tol=tol)
1034
+ elif t > upper:
1035
+ expected[i] = 2.0
1036
+ assert_np_equal(tape.gradients[in3].numpy(), expected, tol=tol)
1037
+ expected[i] = 0.0
1038
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=tol)
1039
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
1040
+ else:
1041
+ expected[i] = 2.0
1042
+ assert_np_equal(tape.gradients[in1].numpy(), expected, tol=tol)
1043
+ expected[i] = 0.0
1044
+ assert_np_equal(tape.gradients[in2].numpy(), expected, tol=tol)
1045
+ assert_np_equal(tape.gradients[in3].numpy(), expected, tol=tol)
1046
+
1047
+ tape.zero()
1048
+
1049
+
1050
+ devices = get_test_devices()
1051
+
1052
+
1053
+ class TestArithmetic(unittest.TestCase):
1054
+ pass
1055
+
1056
+
1057
+ # these unary ops only make sense for signed values:
1058
+ for dtype in np_signed_int_types + np_float_types:
1059
+ add_function_test_register_kernel(
1060
+ TestArithmetic, f"test_unary_ops_{dtype.__name__}", test_unary_ops, devices=devices, dtype=dtype
1061
+ )
1062
+
1063
+ for dtype in np_float_types:
1064
+ add_function_test_register_kernel(
1065
+ TestArithmetic, f"test_special_funcs_{dtype.__name__}", test_special_funcs, devices=devices, dtype=dtype
1066
+ )
1067
+ add_function_test_register_kernel(
1068
+ TestArithmetic,
1069
+ f"test_special_funcs_2arg_{dtype.__name__}",
1070
+ test_special_funcs_2arg,
1071
+ devices=devices,
1072
+ dtype=dtype,
1073
+ )
1074
+ add_function_test_register_kernel(
1075
+ TestArithmetic, f"test_interp_{dtype.__name__}", test_interp, devices=devices, dtype=dtype
1076
+ )
1077
+ add_function_test_register_kernel(
1078
+ TestArithmetic, f"test_float_to_int_{dtype.__name__}", test_float_to_int, devices=devices, dtype=dtype
1079
+ )
1080
+
1081
+ for dtype in np_scalar_types:
1082
+ add_function_test_register_kernel(
1083
+ TestArithmetic, f"test_clamp_{dtype.__name__}", test_clamp, devices=devices, dtype=dtype
1084
+ )
1085
+ add_function_test_register_kernel(
1086
+ TestArithmetic, f"test_nonzero_{dtype.__name__}", test_nonzero, devices=devices, dtype=dtype
1087
+ )
1088
+ add_function_test(TestArithmetic, f"test_arrays_{dtype.__name__}", test_arrays, devices=devices, dtype=dtype)
1089
+ add_function_test_register_kernel(
1090
+ TestArithmetic, f"test_binary_ops_{dtype.__name__}", test_binary_ops, devices=devices, dtype=dtype
1091
+ )
1092
+
1093
+
1094
+ if __name__ == "__main__":
1095
+ wp.clear_kernel_cache()
1096
+ unittest.main(verbosity=2, failfast=False)