warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,185 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
# Debug switches read by test_rounding(): when enabled it prints the table of
# Warp results and/or an equivalent table computed with NumPy.
print_results = False
compare_to_numpy = False
25
+
26
+
27
# Applies each rounding-style operation to every element of ``x`` and writes
# the result to the matching output array (one thread per element).
@wp.kernel
def test_kernel(
    x: wp.array(dtype=float),
    x_round: wp.array(dtype=float),
    x_rint: wp.array(dtype=float),
    x_trunc: wp.array(dtype=float),
    x_cast: wp.array(dtype=float),
    x_floor: wp.array(dtype=float),
    x_ceil: wp.array(dtype=float),
    x_frac: wp.array(dtype=float),
):
    i = wp.tid()

    # Read the input once and reuse it for every operation.
    value = x[i]

    x_round[i] = wp.round(value)
    x_rint[i] = wp.rint(value)
    x_trunc[i] = wp.trunc(value)
    # float -> int -> float round trip
    x_cast[i] = float(int(value))
    x_floor[i] = wp.floor(value)
    x_ceil[i] = wp.ceil(value)
    x_frac[i] = wp.frac(value)
47
+
48
+
49
def test_rounding(test, device):
    """Check the rounding builtins against a hard-coded table of expected values.

    Launches ``test_kernel`` over sample values between -4.9 and 4.9, stacks the
    input with the seven kernel outputs (round, rint, trunc, int cast, floor,
    ceil, frac) column-wise and compares the table with ``golden`` using
    ``assert_np_equal(..., tol=1e-6)``.

    Args:
        test: Test-case instance passed by the harness (not used in this body).
        device: Device on which the arrays are allocated and the kernel is launched.
    """
    nx = np.array(
        [
            4.9, 4.5, 4.1,
            3.9, 3.5, 3.1,
            2.9, 2.5, 2.1,
            1.9, 1.5, 1.1,
            0.9, 0.5, 0.1,
            -0.1, -0.5, -0.9,
            -1.1, -1.5, -1.9,
            -2.1, -2.5, -2.9,
            -3.1, -3.5, -3.9,
            -4.1, -4.5, -4.9,
        ],
        dtype=np.float32,
    )

    x = wp.array(nx, device=device)
    N = len(x)

    # One output buffer per kernel output, in kernel-argument order:
    # round, rint, trunc, cast, floor, ceil, frac.
    outputs = [wp.empty(N, dtype=float, device=device) for _ in range(7)]

    wp.launch(kernel=test_kernel, dim=N, inputs=[x, *outputs], device=device)

    wp.synchronize()

    # Column 0 is the input; the remaining columns follow the kernel-argument order.
    tab = np.stack([nx, *(out.numpy().reshape(N) for out in outputs)], axis=1)

    golden = np.array(
        [
            [4.9, 5.0, 5.0, 4.0, 4.0, 4.0, 5.0, 0.9],
            [4.5, 5.0, 4.0, 4.0, 4.0, 4.0, 5.0, 0.5],
            [4.1, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 0.1],
            [3.9, 4.0, 4.0, 3.0, 3.0, 3.0, 4.0, 0.9],
            [3.5, 4.0, 4.0, 3.0, 3.0, 3.0, 4.0, 0.5],
            [3.1, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 0.1],
            [2.9, 3.0, 3.0, 2.0, 2.0, 2.0, 3.0, 0.9],
            [2.5, 3.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.5],
            [2.1, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.1],
            [1.9, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 0.9],
            [1.5, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 0.5],
            [1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 0.1],
            [0.9, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.9],
            [0.5, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.5],
            [0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.1],
            [-0.1, -0.0, -0.0, -0.0, 0.0, -1.0, -0.0, -0.1],
            [-0.5, -1.0, -0.0, -0.0, 0.0, -1.0, -0.0, -0.5],
            [-0.9, -1.0, -1.0, -0.0, 0.0, -1.0, -0.0, -0.9],
            [-1.1, -1.0, -1.0, -1.0, -1.0, -2.0, -1.0, -0.1],
            [-1.5, -2.0, -2.0, -1.0, -1.0, -2.0, -1.0, -0.5],
            [-1.9, -2.0, -2.0, -1.0, -1.0, -2.0, -1.0, -0.9],
            [-2.1, -2.0, -2.0, -2.0, -2.0, -3.0, -2.0, -0.1],
            [-2.5, -3.0, -2.0, -2.0, -2.0, -3.0, -2.0, -0.5],
            [-2.9, -3.0, -3.0, -2.0, -2.0, -3.0, -2.0, -0.9],
            [-3.1, -3.0, -3.0, -3.0, -3.0, -4.0, -3.0, -0.1],
            [-3.5, -4.0, -4.0, -3.0, -3.0, -4.0, -3.0, -0.5],
            [-3.9, -4.0, -4.0, -3.0, -3.0, -4.0, -3.0, -0.9],
            [-4.1, -4.0, -4.0, -4.0, -4.0, -5.0, -4.0, -0.1],
            [-4.5, -5.0, -4.0, -4.0, -4.0, -5.0, -4.0, -0.5],
            [-4.9, -5.0, -5.0, -4.0, -4.0, -5.0, -4.0, -0.9],
        ],
        dtype=np.float32,
    )

    assert_np_equal(tab, golden, tol=1e-6)

    separator = "----------------------------------------------"
    header = " %5s %5s %5s %5s %5s %5s %5s"

    if print_results:
        # Compact float formatting for the printed table.
        np.set_printoptions(formatter={"float": lambda value: "{:6.1f}".format(value).replace(".0", ".")})

        print(separator)
        print(header % ("x ", "round", "rint", "trunc", "cast", "floor", "ceil"))
        print(tab)
        print(separator)

    if compare_to_numpy:
        # Same table computed with NumPy, printed for side-by-side inspection only.
        reference = [
            np.round(nx),
            np.rint(nx),
            np.trunc(nx),
            np.fix(nx),
            np.floor(nx),
            np.ceil(nx),
            np.modf(nx)[0],
        ]

        tab = np.stack([nx, *reference], axis=1)
        print(header % ("x ", "round", "rint", "trunc", "fix", "floor", "ceil"))
        print(tab)
        print(separator)
172
+
173
+
174
class TestRounding(unittest.TestCase):
    """Empty container class; test methods are attached to it at module level."""
176
+
177
+
178
# Devices the rounding test is registered for.
devices = get_test_devices()

# Attach test_rounding to the TestRounding class for the devices above.
add_function_test(
    TestRounding,
    "test_rounding",
    test_rounding,
    devices=devices,
)


if __name__ == "__main__":
    # Clear the kernel cache before running the suite directly.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,196 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from functools import partial
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import *
23
+ from warp.utils import runlength_encode
24
+
25
+
26
def test_runlength_encode_int(test, device, n):
    """Compare ``runlength_encode`` on ``n`` sorted random integers with ``np.unique``.

    Args:
        test: Test-case instance used for the run-count assertion.
        device: Device on which the input and output arrays are allocated.
        n: Number of input values to generate.
    """
    rng = np.random.default_rng(123)

    # Sorting groups equal values together, so each distinct value forms one run.
    sorted_np = np.sort(rng.integers(-10, high=10, size=n, dtype=int))
    expected_values, expected_counts = np.unique(sorted_np, return_counts=True)

    values = wp.array(sorted_np, device=device, dtype=int)
    run_values = wp.empty_like(values)
    run_lengths = wp.empty_like(values)

    run_count = runlength_encode(values, run_values, run_lengths)

    test.assertEqual(run_count, len(expected_values))

    # Only the first run_count entries of the outputs are compared.
    assert_np_equal(run_values.numpy()[:run_count], expected_values[:run_count])
    assert_np_equal(run_lengths.numpy()[:run_count], expected_counts[:run_count])
43
+
44
+
45
def test_runlength_encode_error_insufficient_storage(test, device):
    """Check that undersized output arrays are rejected with a ``RuntimeError``.

    Two cases are exercised on ``device``: ``run_values`` too small, then
    ``run_lengths`` too small. Both cases allocate on the device under test;
    previously the second case always allocated on "cpu", so it was never
    exercised on other devices.

    Args:
        test: Test-case instance used for ``assertRaisesRegex``.
        device: Device on which all three arrays are allocated.
    """
    expected_message = r"Output array storage sizes must be at least equal to value_count$"

    # (run_values size, run_lengths size): each case under-allocates exactly one output.
    for run_values_size, run_lengths_size in ((1, 123), (123, 1)):
        values = wp.zeros(123, dtype=int, device=device)
        run_values = wp.empty(run_values_size, dtype=int, device=device)
        run_lengths = wp.empty(run_lengths_size, dtype=int, device=device)
        with test.assertRaisesRegex(RuntimeError, expected_message):
            runlength_encode(values, run_values, run_lengths)
63
+
64
+
65
def test_runlength_encode_error_dtypes_mismatch(test, device):
    """Check that a ``run_values`` dtype different from ``values`` raises ``RuntimeError``.

    Args:
        test: Test-case instance used for ``assertRaisesRegex``.
        device: Device on which the arrays are allocated.
    """
    pattern = r"values and run_values data types do not match$"

    src = wp.zeros(123, dtype=int, device=device)
    # Deliberately float while the input is int.
    out_values = wp.empty(123, dtype=float, device=device)
    out_lengths = wp.empty_like(src, device=device)

    with test.assertRaisesRegex(RuntimeError, pattern):
        runlength_encode(src, out_values, out_lengths)
74
+
75
+
76
def test_runlength_encode_error_run_length_unsupported_dtype(test, device):
    """Check that a float ``run_lengths`` array raises ``RuntimeError``.

    Args:
        test: Test-case instance used for ``assertRaisesRegex``.
        device: Device on which the arrays are allocated.
    """
    pattern = r"run_lengths array must be of type int32$"

    src = wp.zeros(123, dtype=int, device=device)
    out_values = wp.empty(123, dtype=int, device=device)
    # Deliberately float to trigger the error.
    out_lengths = wp.empty(123, dtype=float, device=device)

    with test.assertRaisesRegex(RuntimeError, pattern):
        runlength_encode(src, out_values, out_lengths)
85
+
86
+
87
def test_runlength_encode_error_run_count_unsupported_dtype(test, device):
    """Check that a float ``run_count`` array raises ``RuntimeError``.

    Args:
        test: Test-case instance used for ``assertRaisesRegex``.
        device: Device on which the arrays are allocated.
    """
    pattern = r"run_count array must be of type int32$"

    src = wp.zeros(123, dtype=int, device=device)
    out_values = wp.empty_like(src, device=device)
    out_lengths = wp.empty_like(src, device=device)
    # Deliberately float to trigger the error.
    count_out = wp.empty(shape=(1,), dtype=float, device=device)

    with test.assertRaisesRegex(RuntimeError, pattern):
        runlength_encode(src, out_values, out_lengths, run_count=count_out)
97
+
98
+
99
def test_runlength_encode_error_unsupported_dtype(test, device):
    """Check that float input values raise ``RuntimeError`` ("Unsupported data type").

    Args:
        test: Test-case instance used for ``assertRaisesRegex``.
        device: Device on which the arrays are allocated.
    """
    pattern = r"Unsupported data type$"

    # values and run_values are both float here; run_lengths stays int.
    src = wp.zeros(123, dtype=float, device=device)
    out_values = wp.empty(123, dtype=float, device=device)
    out_lengths = wp.empty(123, dtype=int, device=device)

    with test.assertRaisesRegex(RuntimeError, pattern):
        runlength_encode(src, out_values, out_lengths)
108
+
109
+
110
# Devices passed to add_function_test() for the device-parameterized tests below.
devices = get_test_devices()
111
+
112
+
113
class TestRunlengthEncode(unittest.TestCase):
    """Device-mismatch tests for ``runlength_encode``; both are skipped without CUDA."""

    @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
    def test_runlength_encode_error_devices_mismatch(self):
        """Output arrays on a different device than the input must be rejected."""
        pattern = r"Array storage devices do not match$"

        # (run_values device, run_lengths device); the input values are always on the CPU.
        device_pairs = (
            ("cuda:0", "cuda:0"),
            ("cpu", "cuda:0"),
            ("cuda:0", "cpu"),
        )

        for run_values_device, run_lengths_device in device_pairs:
            values = wp.zeros(123, dtype=int, device="cpu")
            run_values = wp.empty_like(values, device=run_values_device)
            run_lengths = wp.empty_like(values, device=run_lengths_device)
            with self.assertRaisesRegex(RuntimeError, pattern):
                runlength_encode(values, run_values, run_lengths)

    @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
    def test_runlength_encode_error_run_count_device_mismatch(self):
        """A ``run_count`` array on a different device than the other arrays must be rejected."""
        pattern = r"run_count storage device does not match other arrays$"

        values = wp.zeros(123, dtype=int, device="cpu")
        run_values = wp.empty_like(values, device="cpu")
        run_lengths = wp.empty_like(values, device="cpu")
        # Only run_count is placed on the GPU.
        run_count = wp.empty(shape=(1,), dtype=int, device="cuda:0")

        with self.assertRaisesRegex(RuntimeError, pattern):
            runlength_encode(values, run_values, run_lengths, run_count=run_count)
154
+
155
+
156
# Attach the device-parameterized test functions to TestRunlengthEncode.
# The int test is registered twice: once with 100 values and once with none.
for _test_name, _test_func in (
    ("test_runlength_encode_int", partial(test_runlength_encode_int, n=100)),
    ("test_runlength_encode_empty", partial(test_runlength_encode_int, n=0)),
    (
        "test_runlength_encode_error_insufficient_storage",
        test_runlength_encode_error_insufficient_storage,
    ),
    (
        "test_runlength_encode_error_dtypes_mismatch",
        test_runlength_encode_error_dtypes_mismatch,
    ),
    (
        "test_runlength_encode_error_run_length_unsupported_dtype",
        test_runlength_encode_error_run_length_unsupported_dtype,
    ),
    (
        "test_runlength_encode_error_run_count_unsupported_dtype",
        test_runlength_encode_error_run_count_unsupported_dtype,
    ),
    (
        "test_runlength_encode_error_unsupported_dtype",
        test_runlength_encode_error_unsupported_dtype,
    ),
):
    add_function_test(TestRunlengthEncode, _test_name, _test_func, devices=devices)

# Do not leave the loop variables behind in the module namespace.
del _test_name, _test_func


if __name__ == "__main__":
    # Clear the kernel cache before running the suite directly.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,105 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
# NumPy scalar types exercised by the tests in this module.
#
# ``np.byte`` and ``np.ubyte`` are aliases of ``np.int8`` and ``np.uint8``, so
# they are not listed separately: a duplicate entry would only register the
# same test a second time under an identical name.
np_signed_int_types = [np.int8, np.int16, np.int32, np.int64]

np_unsigned_int_types = [np.uint8, np.uint16, np.uint32, np.uint64]

np_int_types = np_signed_int_types + np_unsigned_int_types

np_float_types = [np.float16, np.float32, np.float64]

# Every scalar type: signed integers, unsigned integers, then floats.
np_scalar_types = np_int_types + np_float_types
44
+
45
+
46
def test_py_arithmetic_ops(test, device, dtype):
    """Exercise Python-scope arithmetic operators on a Warp scalar type.

    Args:
        test: Test case object providing ``assertAlmostEqual``.
        device: Not used by this test.
        dtype: NumPy scalar type, mapped to the corresponding Warp type
            through ``wp.types.np_dtype_to_warp_type``.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def reference(value):
        # Non-integer types compare against the plain Python value.
        if wptype not in wp.types.int_types:
            return value

        # Cast to the correct integer type to simulate wrapping.
        return wptype._type_(value).value

    def check(actual, value):
        test.assertAlmostEqual(actual, reference(value))

    # Unary operators, addition, subtraction and modulo with an operand of 1.
    one = wptype(1)
    check(+one, 1)
    check(-one, -1)
    check(one + wptype(5), 6)
    check(one - wptype(5), -4)
    check(one % wptype(2), 1)

    # Multiplication (both operand orders), division and modulo with 2.
    two = wptype(2)
    check(two * wptype(2), 4)
    check(wptype(2) * two, 4)
    check(two / wptype(2), 1)
    check(wptype(24) / two, 12)
    check(two % wptype(2), 0)
69
+
70
+
71
def test_py_math_ops(test, device, dtype):
    """Call a few Warp math built-ins from Python scope on a scalar type.

    ``wp.abs`` is checked for every type; ``wp.sin`` and ``wp.radians`` are
    checked only when ``dtype`` is one of ``np_float_types``.

    Args:
        test: Test case object providing ``assertAlmostEqual``.
        device: Not used by this test.
        dtype: NumPy scalar type, mapped to the corresponding Warp type
            through ``wp.types.np_dtype_to_warp_type``.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    a = wptype(1)
    test.assertAlmostEqual(wp.abs(a), 1)

    if dtype in np_float_types:
        # Reference values are sin(1) and pi / 180.  The reduced ``places``
        # presumably accommodate the lower-precision float types
        # (NOTE(review): confirm).
        test.assertAlmostEqual(wp.sin(a), 0.84147098480789650488, places=3)
        test.assertAlmostEqual(wp.radians(a), 0.01745329251994329577, places=5)
87
+
88
+
89
devices = get_test_devices()


class TestScalarOps(unittest.TestCase):
    """Holder for the scalar-op tests attached below via ``add_function_test``."""


# Register one arithmetic test and one math test per scalar type; the type's
# name is appended so every registered test has a distinct name.
for dtype in np_scalar_types:
    for _prefix, _func in (
        ("test_py_arithmetic_ops", test_py_arithmetic_ops),
        ("test_py_math_ops", test_py_math_ops),
    ):
        add_function_test(TestScalarOps, f"{_prefix}_{dtype.__name__}", _func, devices=None, dtype=dtype)


if __name__ == "__main__":
    # Clear the Warp kernel cache before handing control to unittest.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)
@@ -0,0 +1,108 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from dataclasses import dataclass
18
+ from typing import Any
19
+
20
+ import numpy as np
21
+
22
+ import warp as wp
23
+ from warp.tests.unittest_utils import *
24
+
25
+
26
@dataclass
class TestData:
    """One ``smoothstep`` test case: inputs, expected output, expected gradients.

    ``a``, ``b`` and ``t`` are the three inputs and ``expected`` is the
    expected result.  The ``expected_adj_*`` fields hold the expected
    gradients with respect to each input; they default to ``None``, meaning
    "not provided".
    """

    # This is a data holder, not a test case.  The "Test" name prefix would
    # otherwise make pytest try to collect it and warn about its __init__.
    __test__ = False

    a: Any
    b: Any
    t: float
    expected: Any
    expected_adj_a: Any = None
    expected_adj_b: Any = None
    expected_adj_t: float = None

    def check_backwards(self):
        """Return ``True`` only when all three expected gradients are provided."""
        expected_adjoints = (self.expected_adj_a, self.expected_adj_b, self.expected_adj_t)
        # Compare against None explicitly: a gradient of 0.0 is a valid value.
        return all(adjoint is not None for adjoint in expected_adjoints)
38
+
39
+
40
# Test cases keyed by Warp scalar type.  Each case lists the three inputs, the
# expected result and the expected gradient with respect to each input.
TEST_DATA = {
    wp.float32: (
        # t lies halfway between a and b.
        TestData(
            a=1.0,
            b=2.0,
            t=1.5,
            expected=0.5,
            expected_adj_a=-0.75,
            expected_adj_b=-0.75,
            expected_adj_t=1.5,
        ),
        # Negative a and t, with t a quarter of the way from a to b.
        TestData(
            a=-1.0,
            b=2.0,
            t=-0.25,
            expected=0.15625,
            expected_adj_a=-0.28125,
            expected_adj_b=-0.09375,
            expected_adj_t=0.375,
        ),
        # t far above b: the result is 1 and all gradients are zero.
        TestData(
            a=0.0,
            b=1.0,
            t=9.9,
            expected=1.0,
            expected_adj_a=0.0,
            expected_adj_b=0.0,
            expected_adj_t=0.0,
        ),
        # t far below a: the result is 0 and all gradients are zero.
        TestData(
            a=0.0,
            b=1.0,
            t=-9.9,
            expected=0.0,
            expected_adj_a=0.0,
            expected_adj_b=0.0,
            expected_adj_t=0.0,
        ),
    ),
}
56
+
57
+
58
def test_smoothstep(test, device):
    """Check ``wp.smoothstep`` values and gradients for every ``TEST_DATA`` case.

    A kernel is built per data type.  Each case is launched with
    single-element arrays under a ``wp.Tape``, the output is compared with
    the expected value, and, when the case provides all expected gradients,
    the tape is run backwards and the gradients of ``a``, ``b`` and ``t`` are
    compared as well.

    Args:
        test: Test case object (not referenced directly in this function).
        device: Device on which the arrays are created and the kernel is launched.
    """

    def make_kernel_fn(data_type):
        # The annotations capture ``data_type`` so that one kernel function is
        # produced per tested type.
        def fn(
            a: wp.array(dtype=data_type),
            b: wp.array(dtype=data_type),
            t: wp.array(dtype=float),
            out: wp.array(dtype=data_type),
        ):
            out[0] = wp.smoothstep(a[0], b[0], t[0])

        return fn

    for data_type, cases in TEST_DATA.items():
        kernel = wp.Kernel(
            func=make_kernel_fn(data_type),
            key=f"test_smoothstep{data_type.__name__}_kernel",
        )

        for case in cases:
            a = wp.array([case.a], dtype=data_type, device=device, requires_grad=True)
            b = wp.array([case.b], dtype=data_type, device=device, requires_grad=True)
            t = wp.array([case.t], dtype=float, device=device, requires_grad=True)

            zeros = [0] * wp.types.type_length(data_type)
            out = wp.array(zeros, dtype=data_type, device=device, requires_grad=True)

            # Record the launch so that gradients can be propagated afterwards.
            with wp.Tape() as tape:
                wp.launch(kernel, dim=1, inputs=[a, b, t, out], device=device)

            assert_np_equal(out.numpy(), np.array([case.expected]), tol=1e-6)

            if not case.check_backwards():
                continue

            tape.backward(out)

            gradient_checks = (
                (a, case.expected_adj_a),
                (b, case.expected_adj_b),
                (t, case.expected_adj_t),
            )
            for array, expected_gradient in gradient_checks:
                assert_np_equal(tape.gradients[array].numpy(), np.array([expected_gradient]), tol=1e-6)
94
+
95
+
96
devices = get_test_devices()


class TestSmoothstep(unittest.TestCase):
    """Holder for the smoothstep test attached below via ``add_function_test``."""


# Register the smoothstep test on the class for the devices collected above.
add_function_test(
    TestSmoothstep,
    "test_smoothstep",
    test_smoothstep,
    devices=devices,
)


if __name__ == "__main__":
    # Clear the Warp kernel cache before handing control to unittest.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)