warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,146 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example jax_callable()
18
+ #
19
+ # Examples of calling annotated Python functions from JAX.
20
+ ###########################################################################
21
+
22
+ from functools import partial
23
+
24
+ import jax
25
+ import jax.numpy as jnp
26
+
27
+ import warp as wp
28
+ from warp.jax_experimental import jax_callable
29
+
30
+
31
# Elementwise scale of a float array: output[i] = a[i] * s.
@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    i = wp.tid()
    output[i] = a[i] * s
35
+
36
+
37
# Elementwise scale of a wp.vec2 array: each vector a[i] is multiplied by s.
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    i = wp.tid()
    v = a[i]
    output[i] = v * s
41
+
42
+
43
# The Python function that JAX will call.
# Its parameters carry Warp type annotations, exactly like a Warp kernel;
# jax_callable() presumably relies on them to map JAX arrays to Warp arrays.
def scale_func(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=wp.vec2),
    s: float,
    # outputs
    c: wp.array(dtype=float),
    d: wp.array(dtype=wp.vec2),
):
    """Scale ``a`` into ``c`` and ``b`` into ``d`` by the factor ``s``.

    ``c`` and ``d`` are the outputs; the examples wrap this function with
    ``jax_callable(scale_func, num_outputs=2)`` and receive them back.
    """
    # float array: c = a * s
    wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])
    # vec2 array: d = b * s
    wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])
56
+
57
+
58
# In-place accumulation into b: b[i] += a[i].
@wp.kernel
def accum_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
    i = wp.tid()
    b[i] += a[i]
62
+
63
+
64
def in_out_func(
    a: wp.array(dtype=float),  # input only
    b: wp.array(dtype=float),  # input and output
    c: wp.array(dtype=float),  # output only
):
    """Write ``c = a * 2.0``, then update ``b`` in place with ``b += a``."""
    n = a.size
    wp.launch(scale_kernel, dim=n, inputs=[a, 2.0], outputs=[c])
    # accum_kernel writes into `b`, which is why `b` is an input-output argument
    wp.launch(accum_kernel, dim=n, inputs=[a, b])
71
+
72
+
73
def example1():
    """Call ``scale_func`` from a jitted function that takes no arguments.

    All inputs, including the scalar ``s``, are created inside the traced function.
    """
    scale_fn = jax_callable(scale_func, num_outputs=2)

    @jax.jit
    def f():
        # inputs
        a = jnp.arange(10, dtype=jnp.float32)
        b = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # 5 x wp.vec2
        s = 2.0

        # each output has the same shape as its corresponding input
        c, d = scale_fn(a, b, s, output_dims={"c": a.shape, "d": b.shape})
        return c, d

    scaled, scaled_vec = f()
    print(scaled)
    print(scaled_vec)
93
+
94
+
95
def example2():
    """Call ``scale_func`` from a jitted function whose inputs are passed as arguments."""
    scale_fn = jax_callable(scale_func, num_outputs=2)

    # NOTE: scalar arguments must be static compile-time constants,
    # so `s` is declared static for jax.jit.
    @partial(jax.jit, static_argnames=["s"])
    def f(a, b, s):
        dims = {"c": a.shape, "d": b.shape}
        c, d = scale_fn(a, b, s, output_dims=dims)
        return c, d

    scalars = jnp.arange(10, dtype=jnp.float32)
    vectors = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # 5 x wp.vec2
    factor = 3.0

    scaled, scaled_vec = f(scalars, vectors, factor)
    print(scaled)
    print(scaled_vec)
116
+
117
+
118
def example3():
    """Demonstrate an input-output argument: ``b`` is both read and written by ``in_out_func``."""
    # `b` is listed in in_out_argnames, and the call below unpacks it as the first result,
    # followed by `c`. No output_dims are given here.
    # NOTE(review): output shapes presumably default from the inputs; confirm against jax_callable docs.
    jitted = jax.jit(jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"]))

    ones = jnp.ones(10, dtype=jnp.float32)
    ramp = jnp.arange(10, dtype=jnp.float32)

    b_out, c_out = jitted(ones, ramp)
    print(b_out)
    print(c_out)
131
+
132
+
133
def main():
    """Initialize Warp, load this module's kernels, then run each example in order."""
    wp.init()
    # NOTE(review): load_module() is called without a module argument, so it presumably
    # resolves the *calling* module; keep this call directly inside main().
    wp.load_module(device=wp.get_device())

    for example in (example1, example2, example3):
        print("\n===========================================================================")
        print(f"{example.__name__}:")
        example()
143
+
144
+
145
# Run every example when this file is executed as a script.
if __name__ == "__main__":
    main()
@@ -0,0 +1,132 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example register_ffi_callback()
18
+ #
19
+ # Examples of calling Python functions from JAX.
20
+ # Target functions must have the form func(inputs, outputs, attrs, ctx).
21
+ ###########################################################################
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
27
+ import warp as wp
28
+ from warp.jax import get_jax_device
29
+ from warp.jax_experimental import register_ffi_callback
30
+
31
+
32
# Elementwise scale of a float array: output[i] = a[i] * s.
@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    i = wp.tid()
    output[i] = a[i] * s
36
+
37
+
38
# Elementwise scale of a wp.vec2 array: each vector a[i] is multiplied by s.
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    i = wp.tid()
    v = a[i]
    output[i] = v * s
42
+
43
+
44
def example1():
    """Register a Python FFI callback that prints what JAX passes to it, then call it once."""

    # The Python function to call. It receives input buffers, output buffers,
    # an attribute mapping and a call context.
    def print_args(inputs, outputs, attrs, ctx):
        def buffer_to_string(b):
            return f"{b.dtype}{list(b.shape)} @{b.data:x}"

        for label, buffers in (("Inputs: ", inputs), ("Outputs: ", outputs)):
            print(label, ", ".join([buffer_to_string(b) for b in buffers]))
        print("Attributes: ", "".join([f"\n {k}: {str(v)}" for k, v in attrs.items()]))  # noqa: RUF010

    # expose the callback to JAX under the target name "print_args"
    register_ffi_callback("print_args", print_args)

    # declared result: a single int8 array of shape (1, 2, 3); its value is never used below
    result_type = jax.ShapeDtypeStruct((1, 2, 3), jnp.int8)
    print_call = jax.ffi.ffi_call("print_args", result_type)

    # two array inputs, plus keyword attributes of several kinds
    print_call(
        jnp.arange(16),
        jnp.arange(32.0).reshape((4, 8)),
        str_attr="hi",
        f32_attr=np.float32(4.2),
        dict_attr={"a": 1, "b": 6.4},
    )
68
+
69
+
70
def example2():
    """Register a callback that launches Warp kernels, then invoke it through ``jax.ffi.ffi_call``."""

    def warp_func(inputs, outputs, attrs, ctx):
        # input buffers, in call order
        a, b = inputs[0], inputs[1]

        # scalar attribute, passed as the `scale=` keyword of the call
        s = attrs["scale"]

        # output buffers, in the order of the declared result types
        c, d = outputs[0], outputs[1]

        # Wrap the stream supplied in the FFI context, on JAX's current device, so the
        # launches below run on it (presumably ordering them with the surrounding JAX work).
        device = wp.device_from_jax(get_jax_device())
        stream = wp.Stream(device, cuda_stream=ctx.stream)

        with wp.ScopedStream(stream):
            # array of scalars
            wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])

            # array of vec2
            # NOTE: the input shapes are from JAX arrays, we need to strip the inner dimension for vec2 arrays
            wp.launch(scale_vec_kernel, dim=b.shape[0], inputs=[b, s], outputs=[d])

    # expose the callback to JAX under the target name "warp_func"
    register_ffi_callback("warp_func", warp_func)

    n = 10
    a = jnp.arange(n, dtype=jnp.float32)
    b = jnp.arange(n, dtype=jnp.float32).reshape((n // 2, 2))  # array of wp.vec2
    s = 2.0

    # one float32 result per input, with a matching shape (the second is an array of wp.vec2)
    out_types = [jax.ShapeDtypeStruct(x.shape, jnp.float32) for x in (a, b)]
    warp_call = jax.ffi.ffi_call("warp_func", out_types)

    c, d = warp_call(a, b, scale=s)
    print(c)
    print(d)
117
+
118
+
119
def main():
    """Initialize Warp, load this module's kernels, then run each example in order."""
    wp.init()
    # NOTE(review): load_module() is called without a module argument, so it presumably
    # resolves the *calling* module; keep this call directly inside main().
    wp.load_module(device=wp.get_device())

    for example in (example1, example2):
        print("\n===========================================================================")
        print(f"{example.__name__}:")
        example()
129
+
130
+
131
# Run every example when this file is executed as a script.
if __name__ == "__main__":
    main()
@@ -0,0 +1,232 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example jax_kernel()
18
+ #
19
+ # Examples of calling a Warp kernel from JAX.
20
+ ###########################################################################
21
+
22
+ import math
23
+ from functools import partial
24
+
25
+ import jax
26
+ import jax.numpy as jnp
27
+
28
+ import warp as wp
29
+ from warp.jax_experimental import jax_kernel
30
+
31
+
32
@wp.kernel
def add_kernel(a: wp.array(dtype=int), b: wp.array(dtype=int), output: wp.array(dtype=int)):
    # Element-wise integer sum: each thread writes output[tid] = a[tid] + b[tid].
    # Used by example1 as a "two inputs, one output" kernel.
    tid = wp.tid()
    output[tid] = a[tid] + b[tid]
36
+
37
+
38
@wp.kernel
def sincos_kernel(angle: wp.array(dtype=float), sin_out: wp.array(dtype=float), cos_out: wp.array(dtype=float)):
    # Element-wise sine and cosine of `angle`, written to two separate output arrays.
    # Used by example2 as a "one input, two outputs" kernel.
    tid = wp.tid()
    sin_out[tid] = wp.sin(angle[tid])
    cos_out[tid] = wp.cos(angle[tid])
43
+
44
+
45
@wp.kernel
def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
    # Kernel with no inputs: element tid receives the diagonal 3x3 matrix
    # diag(d, 2d, 3d) with d = tid + 1 (off-diagonal entries are 0).
    tid = wp.tid()
    d = float(tid + 1)
    output[tid] = wp.mat33(d, 0.0, 0.0, 0.0, d * 2.0, 0.0, 0.0, 0.0, d * 3.0)
50
+
51
+
52
@wp.kernel
def matmul_kernel(
    a: wp.array2d(dtype=float),  # NxK
    b: wp.array2d(dtype=float),  # KxM
    c: wp.array2d(dtype=float),  # NxM
):
    # Naive dense matrix product c = a @ b: each thread computes one entry c[i, j]
    # as a dot product over the shared K dimension.
    # launch dims should be (N, M)
    i, j = wp.tid()
    N = a.shape[0]
    K = a.shape[1]
    M = b.shape[1]
    # Bounds check so threads outside (N, M) do nothing if the launch grid is larger.
    if i < N and j < M:
        s = wp.float32(0)
        for k in range(K):
            s += a[i, k] * b[k, j]
        c[i, j] = s
68
+
69
+
70
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    # Multiply each vec2 element of `a` by the scalar `s`.
    # On the JAX side the arrays are passed as (n, 2) float arrays (see example3/example4).
    tid = wp.tid()
    output[tid] = a[tid] * s
74
+
75
+
76
@wp.kernel
def in_out_kernel(
    a: wp.array(dtype=float),  # input only
    b: wp.array(dtype=float),  # input and output
    c: wp.array(dtype=float),  # output only
):
    # Demonstrates an in/out argument: `b` is read and updated in place (b += a),
    # while `c` is only written (c = 2 * a).
    tid = wp.tid()
    b[tid] += a[tid]
    c[tid] = 2.0 * a[tid]
85
+
86
+
87
def example1():
    """Add two int32 arrays with a Warp kernel called from jitted JAX code (two inputs, one output)."""
    jax_add = jax_kernel(add_kernel)

    @jax.jit
    def f():
        count = 10
        lhs = jnp.arange(count, dtype=jnp.int32)
        rhs = jnp.ones(count, dtype=jnp.int32)
        return jax_add(lhs, rhs)

    total = f()
    print(total)
99
+
100
+
101
def example2():
    """Compute sin and cos of evenly spaced angles with one Warp kernel (one input, two outputs)."""
    jax_sincos = jax_kernel(sincos_kernel, num_outputs=2)

    @jax.jit
    def f():
        n = 32
        # Pin float32 to match the kernel's `float` arrays; otherwise linspace's dtype
        # follows JAX's default float type, which becomes float64 with x64 enabled.
        a = jnp.linspace(0, 2 * math.pi, n, dtype=jnp.float32)
        return jax_sincos(a)

    s, c = f()
    print(s)
    print()
    print(c)
115
+
116
+
117
def example3():
    """Scale an array of wp.vec2 by a scalar fixed inside the jitted function."""
    jax_scale_vec = jax_kernel(scale_vec_kernel)

    @jax.jit
    def f():
        # 10 floats viewed as 5 rows of 2, i.e. an array of 5 vec2
        vecs = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))
        factor = 2.0
        return jax_scale_vec(vecs, factor)

    print(f())
129
+
130
+
131
def example4():
    """Scale an array of wp.vec2 by a scalar passed as a static jit argument."""
    jax_scale_vec = jax_kernel(scale_vec_kernel)

    # NOTE: scalar arguments must be static compile-time constants
    @partial(jax.jit, static_argnames=["s"])
    def f(a, s):
        return jax_scale_vec(a, s)

    vecs = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # array of vec2
    scaled = f(vecs, 3.0)
    print(scaled)
145
+
146
+
147
def example5():
    """Matrix multiply using default launch dimensions fixed when the kernel is wrapped."""
    N, M, K = 3, 4, 2
    jax_matmul = jax_kernel(matmul_kernel, launch_dims=(N, M))

    @jax.jit
    def f():
        lhs = jnp.full((N, K), 2, dtype=jnp.float32)
        rhs = jnp.full((K, M), 3, dtype=jnp.float32)
        # No launch_dims here, so the (N, M) given to jax_kernel is used.
        return jax_matmul(lhs, rhs)

    product = f()
    print(product)
162
+
163
+
164
def example6():
    """Matrix multiply with launch dimensions supplied per call rather than at wrap time."""
    jax_matmul = jax_kernel(matmul_kernel)

    @jax.jit
    def f():
        products = []
        # Two products with different output shapes: first (3, 4), then (4, 3).
        for rows, cols, inner in ((3, 4, 2), (4, 3, 2)):
            lhs = jnp.full((rows, inner), 2, dtype=jnp.float32)
            rhs = jnp.full((inner, cols), 3, dtype=jnp.float32)
            # use custom launch dims
            products.append(jax_matmul(lhs, rhs, launch_dims=(rows, cols)))
        return tuple(products)

    r1, r2 = f()
    print(r1)
    print()
    print(r2)
190
+
191
+
192
def example7():
    """Call a kernel with no inputs; the launch dimension alone determines the output size."""
    jax_diagonal = jax_kernel(diagonal_kernel)

    @jax.jit
    def f():
        return jax_diagonal(launch_dims=4)

    matrices = f()
    print(matrices)
202
+
203
+
204
def example8():
    """Use an input-output kernel argument: `b` is passed in and returned updated, alongside `c`."""
    jax_func = jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"])
    f = jax.jit(jax_func)

    ones = jnp.ones(10, dtype=jnp.float32)
    ramp = jnp.arange(10, dtype=jnp.float32)

    # Results come back as (updated b, c).
    updated_b, c = f(ones, ramp)
    print(updated_b)
    print(c)
217
+
218
+
219
def main():
    """Initialize Warp, load kernels on the default device, and run all eight examples in order."""
    wp.init()

    # Load Warp kernels (module argument omitted) on the current default device.
    current_device = wp.get_device()
    wp.load_module(device=current_device)

    demos = (example1, example2, example3, example4, example5, example6, example7, example8)
    for demo in demos:
        # Separator banner, then the example name, then its output.
        print("\n===========================================================================")
        print(f"{demo.__name__}:")
        demo()


if __name__ == "__main__":
    main()