warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,901 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import gc # Added for garbage collection tests
17
+ import unittest
18
+ from typing import Any
19
+
20
+ import numpy as np
21
+
22
+ import warp as wp
23
+ from warp.fem import Sample as StructFromAnotherModule
24
+ from warp.tests.unittest_utils import *
25
+
26
+
27
@wp.struct
class Model:
    """Parameters passed to the step kernels alongside the input/output states."""

    # Field order matters: kernels construct this struct positionally
    # as Model(dt, gravity, m).
    dt: float  # scales the velocity and position increments in the step kernels
    gravity: wp.vec3  # divided by m[i] in the step kernels
    m: wp.array(dtype=float)  # indexed per thread; presumably per-element masses
32
+
33
+
34
@wp.struct
class State:
    """Pair of vec3 arrays read and written by the step kernels."""

    # Field order matters: kernels construct this struct positionally as State(x, v).
    x: wp.array(dtype=wp.vec3)  # advanced as x + v * dt; presumably positions
    v: wp.array(dtype=wp.vec3)  # advanced as v + gravity / m * dt; presumably velocities
38
+
39
+
40
# One integration step per thread: update the velocity, then the position.
@wp.kernel
def kernel_step(state_in: State, state_out: State, model: Model):
    i = wp.tid()

    # v_out = v_in + gravity / m * dt
    state_out.v[i] = state_in.v[i] + model.gravity / model.m[i] * model.dt
    # x_out = x_in + v_out * dt; reads back the velocity stored on the previous line
    state_out.x[i] = state_in.x[i] + state_out.v[i] * model.dt
46
+
47
+
48
# Same update as kernel_step, but routed through structs constructed inside the kernel.
@wp.kernel
def kernel_step_with_copy(state_in: State, state_out: State, model: Model):
    i = wp.tid()

    # Build a Model in-kernel (positional: dt, gravity, m) whose gravity field
    # already holds the full velocity increment gravity / m * dt.
    model_rescaled = Model(1.0, model.gravity / model.m[i] * model.dt, model.m)

    # Build a State from the output arrays; the host tests read the results through
    # state_out, so these writes are expected to land in the same arrays.
    state_out_copy = State(state_out.x, state_out.v)
    state_out_copy.v[i] = state_in.v[i] + model_rescaled.gravity
    # Uses the original model.dt here, not the 1.0 stored in model_rescaled.
    state_out_copy.x[i] = state_in.x[i] + state_out_copy.v[i] * model.dt
57
+
58
+
59
def test_step(test, device):
    """Launch both step kernels and compare the output positions with NumPy.

    Args:
        test: The running ``unittest.TestCase`` (passed to ``CheckOutput``).
        device: Warp device on which arrays are allocated and kernels launched.
    """
    rng = np.random.default_rng(123)

    dim = 5

    dt = 0.01
    gravity = np.array([0, 0, -9.81])

    m = np.ones(dim)

    m_model = wp.array(m, dtype=float, device=device)

    model = Model()
    model.m = m_model
    model.dt = dt
    model.gravity = wp.vec3(0, 0, -9.81)

    x = rng.normal(size=(dim, 3))
    v = rng.normal(size=(dim, 3))

    # Reference result: v' = v + g / m * dt, then x' = x + v' * dt.
    x_expected = x + (v + gravity / m[:, None] * dt) * dt

    x_in = wp.array(x, dtype=wp.vec3, device=device)
    v_in = wp.array(v, dtype=wp.vec3, device=device)

    state_in = State()
    state_in.x = x_in
    state_in.v = v_in

    state_out = State()
    state_out.x = wp.empty_like(x_in)
    state_out.v = wp.empty_like(v_in)

    for step_kernel in [kernel_step, kernel_step_with_copy]:
        # Both kernels write into the same output arrays. Clear them first so
        # that values left behind by the previous kernel cannot make a kernel
        # that wrote nothing appear to pass.
        state_out.x.zero_()
        state_out.v.zero_()

        with CheckOutput(test):
            wp.launch(step_kernel, dim=dim, inputs=[state_in, state_out, model], device=device)

        assert_np_equal(state_out.x.numpy(), x_expected, tol=1e-6)
97
+
98
+
99
# Accumulates the squared length of every x[i] into loss[0].
@wp.kernel
def kernel_loss(x: wp.array(dtype=wp.vec3), loss: wp.array(dtype=float)):
    i = wp.tid()
    # atomic_add because every thread adds into the same element
    wp.atomic_add(loss, 0, x[i][0] * x[i][0] + x[i][1] * x[i][1] + x[i][2] * x[i][2])
103
+
104
+
105
def test_step_grad(test, device):
    """Check gradients propagated through struct arguments of the step kernels.

    For each step kernel, a tape records the step followed by ``kernel_loss``.
    The gradients on the state and model arrays are compared with values
    derived analytically in NumPy, and ``tape.zero()`` is then checked to
    clear them.

    Args:
        test: The running ``unittest.TestCase``.
        device: Warp device on which arrays are allocated and kernels launched.
    """
    rng = np.random.default_rng(123)

    dim = 5

    dt = 0.01
    gravity = np.array([0, 0, -9.81])

    # Offset keeps every m[i] away from zero, since the kernels divide by it.
    m = rng.random(size=dim) + 0.1

    m_model = wp.array(m, dtype=float, device=device, requires_grad=True)

    model = Model()
    model.m = m_model
    model.dt = dt
    model.gravity = wp.vec3(0, 0, -9.81)

    x = rng.normal(size=(dim, 3))
    v = rng.normal(size=(dim, 3))

    x_in = wp.array(x, dtype=wp.vec3, device=device, requires_grad=True)
    v_in = wp.array(v, dtype=wp.vec3, device=device, requires_grad=True)

    state_in = State()
    state_in.x = x_in
    state_in.v = v_in

    state_out = State()
    state_out.x = wp.empty_like(x_in, requires_grad=True)
    state_out.v = wp.empty_like(v_in, requires_grad=True)

    # kernel_loss accumulates into loss[0] with atomic_add, so the buffer must
    # start from a defined value rather than uninitialized memory.
    loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)

    for step_kernel in [kernel_step, kernel_step_with_copy]:
        tape = wp.Tape()

        with tape:
            wp.launch(step_kernel, dim=dim, inputs=[state_in, state_out, model], device=device)
            wp.launch(kernel_loss, dim=dim, inputs=[state_out.x, loss], device=device)

        tape.backward(loss)

        # loss = sum |x_out|^2  =>  dl/dx_out = 2 * x_out
        dl_dx = 2 * state_out.x.numpy()
        # x_out = x_in + v_out * dt  =>  dl/dv_out = dl/dx_out * dt
        dl_dv = dl_dx * dt

        # v_out = v_in + g / m * dt  =>  dv_out/dm = -g * dt / m^2
        dv_dm = -gravity * dt / m[:, None] ** 2
        dl_dm = (dl_dv * dv_dm).sum(-1)

        assert_np_equal(state_out.x.grad.numpy(), dl_dx, tol=1e-6)
        assert_np_equal(state_in.x.grad.numpy(), dl_dx, tol=1e-6)
        assert_np_equal(state_out.v.grad.numpy(), dl_dv, tol=1e-6)
        assert_np_equal(state_in.v.grad.numpy(), dl_dv, tol=1e-6)
        assert_np_equal(model.m.grad.numpy(), dl_dm, tol=1e-6)

        tape.zero()

        # TestCase assertions instead of bare asserts, so the checks still run
        # under ``python -O``.
        test.assertEqual(state_out.x.grad.numpy().sum(), 0.0)
        test.assertEqual(state_in.x.grad.numpy().sum(), 0.0)
        test.assertEqual(state_out.v.grad.numpy().sum(), 0.0)
        test.assertEqual(state_in.v.grad.numpy().sum(), 0.0)
        test.assertEqual(model.m.grad.numpy().sum(), 0.0)
166
+
167
+
168
@wp.struct
class Empty:
    """Warp struct with no fields; used as the argument type of test_empty."""

    pass
171
+
172
+
173
# Kernel that accepts a field-less struct; the body only queries the thread index.
@wp.kernel
def test_empty(input: Empty):
    tid = wp.tid()
176
+
177
+
178
@wp.struct
class Uninitialized:
    """Warp struct with a single int array field; argument type of test_uninitialized."""

    # NOTE(review): the name suggests instances are passed without assigning
    # `data` -- confirm against the launch site, which is not visible here.
    data: wp.array(dtype=int)
181
+
182
+
183
# Kernel that accepts an Uninitialized struct; it never touches input.data.
@wp.kernel
def test_uninitialized(input: Uninitialized):
    tid = wp.tid()
186
+
187
+
188
@wp.struct
class Baz:
    """Innermost struct of the Foo -> Bar -> Baz nesting chain."""

    # Field order matters: kernels construct this struct positionally as Baz(data, z).
    data: wp.array(dtype=int)  # read and written per thread by kernel_nested_struct
    z: wp.vec3
192
+
193
+
194
@wp.struct
class Bar:
    """Middle struct of the Foo -> Bar -> Baz nesting chain."""

    # Field order matters: kernels construct this struct positionally as Bar(baz, y).
    baz: Baz  # nested struct member
    y: float
198
+
199
+
200
@wp.struct
class Foo:
    """Outermost struct of the Foo -> Bar -> Baz nesting chain."""

    # Field order matters: kernels construct this struct positionally as Foo(bar, x).
    bar: Bar  # nested struct member
    x: int
204
+
205
+
206
# Reads scalar, vector and array members through up to three levels of struct
# nesting and accumulates them into foo.bar.baz.data[tid].
@wp.kernel
def kernel_nested_struct(foo: Foo):
    tid = wp.tid()
    # data[tid] += x + int(y * 100) + int(|z|^2) + 2 * tid
    foo.bar.baz.data[tid] = (
        foo.bar.baz.data[tid] + foo.x + int(foo.bar.y * 100.0) + int(wp.length_sq(foo.bar.baz.z)) + tid * 2
    )
212
+
213
+
214
def test_nested_struct(test, device):
    """Populate a three-level nested struct on the host and use it in a kernel.

    Args:
        test: The running ``unittest.TestCase``.
        device: Warp device on which arrays are allocated and the kernel launched.
    """
    dim = 3

    foo = Foo()
    foo.bar = Bar()
    foo.bar.baz = Baz()
    foo.bar.baz.data = wp.zeros(dim, dtype=int, device=device)
    foo.bar.baz.z = wp.vec3(1, 2, 3)
    foo.bar.y = 1.23
    foo.x = 123

    # Verify that struct attributes are instances of their original class.
    # TestCase assertions are used so the checks still run under ``python -O``.
    test.assertIsInstance(foo, Foo.cls)
    test.assertIsInstance(foo.bar, Bar.cls)
    test.assertIsInstance(foo.bar.baz, Baz.cls)

    wp.launch(kernel_nested_struct, dim=dim, inputs=[foo], device=device)

    # Per thread: 0 + 123 (x) + 123 (int(1.23 * 100)) + 14 (|z|^2) + 2 * tid
    assert_array_equal(
        foo.bar.baz.data,
        wp.array((260, 262, 264), dtype=int, device=device),
    )
236
+
237
+
238
@wp.struct
class MatStruct:
    """Warp struct wrapping a single 4x4 matrix; instantiated inside kernel_nested_mat."""

    m: wp.mat44
241
+
242
+
243
# Exercises component assignment on a matrix held in a struct and on a matrix
# stored in an array, using both m[i, j] and m[i][j] indexing forms.
@wp.kernel
def kernel_nested_mat(out: wp.array(dtype=wp.mat44)):
    s = MatStruct()

    # write components of the struct's matrix member
    s.m[0, 0] = 2.0
    s.m[1, 2] = 3.0
    s.m[2][1] = 5.0

    out[0] = s.m

    # write components of the array element directly
    out[0][2, 2] = 6.0
    out[0][1][1] = 7.0

    # read one component of the array element and store it into another
    out[0][3, 3] = out[0][0][0]
257
+
258
+
259
def test_nested_mat(test, device):
    """Run kernel_nested_mat and check the matrix entries it is expected to set.

    Args:
        test: The running ``unittest.TestCase`` (unused here).
        device: Warp device on which the array is allocated and the kernel launched.
    """
    result = wp.array([wp.mat44()], dtype=wp.mat44, device=device)
    wp.launch(kernel_nested_mat, dim=1, outputs=[result], device=device)
    wp.synchronize()

    mat = result.numpy()[0]

    # (row, col) -> value written by the kernel, in the order the kernel writes them
    expected_entries = (
        ((0, 0), 2.0),
        ((1, 2), 3.0),
        ((2, 1), 5.0),
        ((2, 2), 6.0),
        ((1, 1), 7.0),
        ((3, 3), 2.0),
    )
    for (row, col), value in expected_entries:
        assert_np_equal(mat[row, col], value)
271
+
272
+
273
def test_assign_view(test, device):
    """Expect a codegen error when a 2-D array is indexed with a single index.

    Args:
        test: The running ``unittest.TestCase``.
        device: Warp device on which the array is allocated and the kernel launched.
    """

    @wp.kernel
    def kernel_assign_view(out: wp.array2d(dtype=wp.mat44)):
        out[0][2, 2] = 6.0

    mats = wp.array([[wp.mat44()]], dtype=wp.mat44, device=device)

    error_type = wp._src.codegen.WarpCodegenError
    error_pattern = r"Incorrect number of indices specified for array indexing"

    # The error is expected to surface from the launch call.
    with test.assertRaisesRegex(error_type, error_pattern):
        wp.launch(kernel_assign_view, dim=[1, 1], outputs=[mats], device=device)
285
+
286
+
287
def test_struct_attribute_error(test, device):
    """Expect an AttributeError when a kernel reads a non-existent struct field.

    Args:
        test: The running ``unittest.TestCase``.
        device: Warp device on which the kernel is launched.
    """

    @wp.kernel
    def kernel(foo: Foo):
        _ = foo.nonexisting

    error_pattern = r"`nonexisting` is not an attribute of 'foo' \([\w.]+\.Foo\)$"

    # The error is expected to surface from the launch call.
    with test.assertRaisesRegex(AttributeError, error_pattern):
        wp.launch(kernel, dim=1, inputs=[Foo()], device=device)
299
+
300
+
301
def test_struct_inheritance_error(test, device):
    """Expect a RuntimeError when a Warp struct derives from another Warp struct.

    Args:
        test: The running ``unittest.TestCase``.
        device: Unused; kept for the common ``(test, device)`` test signature.
    """
    error_pattern = r"Warp structs must be defined as base classes$"

    # Both definitions stay inside the context manager, as in the original layout.
    with test.assertRaisesRegex(RuntimeError, error_pattern):

        @wp.struct
        class Parent:
            x: int

        @wp.struct
        class Child(Parent):
            y: int
311
+
312
+
313
# Builds the Foo -> Bar -> Baz chain inside the kernel with positional constructors
# and checks every level can be read back.
@wp.kernel
def test_struct_instantiate(data: wp.array(dtype=int)):
    baz = Baz(data, wp.vec3(0.0, 0.0, 26.0))
    bar = Bar(baz, 25.0)
    foo = Foo(bar, 24)

    wp.expect_eq(foo.x, 24)
    wp.expect_eq(foo.bar.y, 25.0)
    wp.expect_eq(foo.bar.baz.z[2], 26.0)
    # requires the caller to pass an array whose first element is 1
    # (the launch site is not visible in this part of the file)
    wp.expect_eq(foo.bar.baz.data[0], 1)
323
+
324
+
325
@wp.struct
class MathThings:
    """Struct of vector and matrix fields checked by check_math_conversions."""

    # vec3 fields
    v1: wp.vec3
    v2: wp.vec3
    v3: wp.vec3
    # mat22 fields
    m1: wp.mat22
    m2: wp.mat22
    m3: wp.mat22
    m4: wp.mat22
    m5: wp.mat22
    m6: wp.mat22
336
+
337
+
338
# Verifies inside a kernel that every MathThings field holds the expected
# vector or matrix value.
@wp.kernel
def check_math_conversions(s: MathThings):
    wp.expect_eq(s.v1, wp.vec3(1.0, 2.0, 3.0))
    wp.expect_eq(s.v2, wp.vec3(10.0, 20.0, 30.0))
    wp.expect_eq(s.v3, wp.vec3(100.0, 200.0, 300.0))
    wp.expect_eq(s.m1, wp.mat22(1.0, 2.0, 3.0, 4.0))
    wp.expect_eq(s.m2, wp.mat22(10.0, 20.0, 30.0, 40.0))
    wp.expect_eq(s.m3, wp.mat22(100.0, 200.0, 300.0, 400.0))
    # m4..m6 are expected to hold the same values as m1..m3
    wp.expect_eq(s.m4, wp.mat22(1.0, 2.0, 3.0, 4.0))
    wp.expect_eq(s.m5, wp.mat22(10.0, 20.0, 30.0, 40.0))
    wp.expect_eq(s.m6, wp.mat22(100.0, 200.0, 300.0, 400.0))
349
+
350
+
351
def test_struct_math_conversions(test, device):
    """Assign assorted containers to vector/matrix members and verify them in a kernel."""
    s = MathThings()

    # Member name -> container to assign; insertion order matches member order.
    field_values = {
        # vectors from tuple, list and NumPy array
        "v1": (1, 2, 3),
        "v2": [10, 20, 30],
        "v3": np.array([100, 200, 300]),
        # matrices from 2D containers
        "m1": ((1, 2), (3, 4)),
        "m2": [[10, 20], [30, 40]],
        "m3": np.array([[100, 200], [300, 400]]),
        # matrices from flat 1D containers
        "m4": (1, 2, 3, 4),
        "m5": [10, 20, 30, 40],
        "m6": np.array([100, 200, 300, 400]),
    }
    for field_name, container in field_values.items():
        setattr(s, field_name, container)

    wp.launch(check_math_conversions, dim=1, inputs=[s], device=device)
368
+
369
+
370
# Single-member struct; constructed and returned by the GetTestData function.
@wp.struct
class TestData:
    value: wp.int32
373
+
374
+
375
@wp.func
def GetTestData(value: wp.int32):
    # Wrap twice the input in a TestData struct and hand it back to the caller.
    doubled = value * 2
    return TestData(doubled)
378
+
379
+
380
@wp.kernel
def test_return_struct(data: wp.array(dtype=wp.int32)):
    i = wp.tid()

    # Store the member of the struct returned by the function...
    returned = GetTestData(i)
    data[i] = returned.value

    # ...and check that it holds twice the thread index.
    wp.expect_eq(data[i], i * 2)
386
+
387
+
388
# Two-int struct constructed and returned by test_return_func.
@wp.struct
class ReturnStruct:
    a: int
    b: int
392
+
393
+
394
@wp.func
def test_return_func():
    # Return a struct built with a=1, b=2.
    return ReturnStruct(1, 2)
398
+
399
+
400
@wp.kernel
def test_return():
    # The struct returned by the function must carry the values it was built with.
    returned = test_return_func()
    wp.expect_eq(returned.a, 1)
    wp.expect_eq(returned.b, 2)
405
+
406
+
407
# Nested member type used by DefaultAttribStruct (field `s`).
@wp.struct
class DefaultAttribNested:
    f: float
410
+
411
+
412
# Struct mixing scalar, vector, matrix, array and nested-struct members.
# check_default_attributes_func expects each member of a freshly constructed
# instance to be zero (or, for the array, to have a first dimension of 0).
@wp.struct
class DefaultAttribStruct:
    i: int
    d: wp.float64
    v: wp.vec3
    m: wp.mat22
    a: wp.array(dtype=wp.int32)
    s: DefaultAttribNested
420
+
421
+
422
@wp.func
def check_default_attributes_func(data: DefaultAttribStruct):
    # Expected default value for each kind of member.
    zero_int = wp.int32(0)
    zero_double = wp.float64(0)
    zero_vec = wp.vec3(0.0, 0.0, 0.0)
    zero_mat = wp.mat22(0.0, 0.0, 0.0, 0.0)
    zero_float = wp.float32(0.0)

    wp.expect_eq(data.i, zero_int)
    wp.expect_eq(data.d, zero_double)
    wp.expect_eq(data.v, zero_vec)
    wp.expect_eq(data.m, zero_mat)
    # the array member is expected to have a first dimension of 0
    wp.expect_eq(data.a.shape[0], 0)
    wp.expect_eq(data.s.f, zero_float)
430
+
431
+
432
# Kernel entry point that forwards a struct received from the host to
# check_default_attributes_func.
@wp.kernel
def check_default_attributes_kernel(data: DefaultAttribStruct):
    check_default_attributes_func(data)
435
+
436
+
437
# A struct default-constructed inside a kernel must have default-valued members.
@wp.kernel
def test_struct_default_attributes_kernel():
    default_struct = DefaultAttribStruct()

    check_default_attributes_func(default_struct)
443
+
444
+
445
# Struct whose members are assigned inside test_struct_mutate_attributes_kernel.
@wp.struct
class MutableStruct:
    param1: int
    param2: float
449
+
450
+
451
@wp.kernel
def test_struct_mutate_attributes_kernel():
    # Assign members of a kernel-local struct, then read them back.
    params = MutableStruct()
    params.param1 = 1
    params.param2 = 1.1

    wp.expect_eq(params.param1, 1)
    wp.expect_eq(params.param2, 1.1)
459
+
460
+
461
# Element type of the array held by ArrayStruct.
@wp.struct
class InnerStruct:
    i: int
464
+
465
+
466
# Struct holding an array whose elements are themselves structs (InnerStruct).
@wp.struct
class ArrayStruct:
    array: wp.array(dtype=InnerStruct)
469
+
470
+
471
@wp.kernel
def struct2_reader(test: ArrayStruct):
    # Element number idx of the struct array is expected to store idx + 1.
    idx = wp.tid()
    expected = idx + 1
    wp.expect_eq(expected, test.array[idx].i)
475
+
476
+
477
def test_nested_array_struct(test, device):
    """Pass a struct containing an array of structs to a kernel that reads it."""
    # Two elements holding 1 and 2, matching what struct2_reader expects.
    elements = []
    for stored_value in (1, 2):
        element = InnerStruct()
        element.i = stored_value
        elements.append(element)

    container = ArrayStruct()
    container.array = wp.array(elements, dtype=InnerStruct, device=device)

    wp.launch(struct2_reader, dim=2, inputs=[container], device=device)
488
+
489
+
490
# Struct with a single vec3 member; used by TestStruct.test_nested_vec_assignment.
@wp.struct
class VecStruct:
    value: wp.vec3
493
+
494
+
495
# Struct with a single float array; nested inside Foo2 (member `y`).
@wp.struct
class Bar2:
    z: wp.array(dtype=float)
498
+
499
+
500
# Struct with an array member and a nested struct that holds another array;
# used by test_convert_to_device.
@wp.struct
class Foo2:
    x: wp.array(dtype=float)
    y: Bar2
504
+
505
+
506
def test_convert_to_device(test, device):
    """Check that ``to()`` on a struct instance moves its arrays to another device.

    A ``Foo2`` instance with one direct array member and one array nested in a
    ``Bar2`` member is created on ``device`` and converted to the opposite kind
    of device (CPU -> ``"cuda:0"``, CUDA -> ``"cpu"``). When no such destination
    is available the function returns without checking anything.

    Args:
        test: The ``unittest.TestCase`` instance used for assertions.
        device: The source device the arrays are initially allocated on.
    """
    foo = Foo2()
    foo.x = wp.array((1.23, 2.34), dtype=float, device=device)
    foo.y = Bar2()
    foo.y.z = wp.array((3.45, 4.56), dtype=float, device=device)

    if device.is_cpu and wp.is_cuda_available():
        dst_device = "cuda:0"
    elif device.is_cuda and wp.is_cpu_available():
        dst_device = "cpu"
    else:
        return

    result = foo.to(dst_device)

    # Use TestCase assertions rather than bare ``assert`` statements: the latter
    # are stripped when Python runs with -O, which would make this test vacuous.
    test.assertEqual(result.x.device, dst_device)
    test.assertEqual(result.y.z.device, dst_device)
522
+
523
+
524
# One Empty-typed member followed by an int; instantiated by
# test_nested_empty_struct. (Empty is defined outside this part of the file.)
@wp.struct
class EmptyNest1:
    a: Empty
    z: int
528
+
529
+
530
# Two Empty-typed members followed by an int; instantiated by
# test_nested_empty_struct.
@wp.struct
class EmptyNest2:
    a: Empty
    b: Empty
    z: int
535
+
536
+
537
# Three Empty-typed members followed by an int; instantiated by
# test_nested_empty_struct.
@wp.struct
class EmptyNest3:
    a: Empty
    b: Empty
    c: Empty
    z: int
543
+
544
+
545
# Four Empty-typed members followed by an int; instantiated by
# test_nested_empty_struct.
@wp.struct
class EmptyNest4:
    a: Empty
    b: Empty
    c: Empty
    d: Empty
    z: int
552
+
553
+
554
# Five Empty-typed members followed by an int; instantiated by
# test_nested_empty_struct.
@wp.struct
class EmptyNest5:
    a: Empty
    b: Empty
    c: Empty
    d: Empty
    e: Empty
    z: int
562
+
563
+
564
# Six Empty-typed members followed by an int; instantiated by
# test_nested_empty_struct.
@wp.struct
class EmptyNest6:
    a: Empty
    b: Empty
    c: Empty
    d: Empty
    e: Empty
    f: Empty
    z: int
573
+
574
+
575
# Seven Empty-typed members followed by an int; instantiated by
# test_nested_empty_struct.
@wp.struct
class EmptyNest7:
    a: Empty
    b: Empty
    c: Empty
    d: Empty
    e: Empty
    f: Empty
    g: Empty
    z: int
585
+
586
+
587
# Eight Empty-typed members followed by an int; instantiated by
# test_nested_empty_struct.
@wp.struct
class EmptyNest8:
    a: Empty
    b: Empty
    c: Empty
    d: Empty
    e: Empty
    f: Empty
    g: Empty
    h: Empty
    z: int
598
+
599
+
600
# Generic kernel (argument annotated as Any) that checks the `z` member of the
# struct it receives equals 42. Concrete overloads for the EmptyNest* structs
# are registered with wp.overload.
@wp.kernel
def empty_nest_kernel(s: Any):
    wp.expect_eq(s.z, 42)
603
+
604
+
605
# Register one concrete overload of the generic kernel per EmptyNest* struct,
# in ascending order of nesting count.
for _nest_struct in (
    EmptyNest1,
    EmptyNest2,
    EmptyNest3,
    EmptyNest4,
    EmptyNest5,
    EmptyNest6,
    EmptyNest7,
    EmptyNest8,
):
    wp.overload(empty_nest_kernel, [_nest_struct])
613
+
614
+
615
def test_nested_empty_struct(test, device):
    """Launch empty_nest_kernel once for each EmptyNest* struct with ``z`` set to 42."""
    nest_structs = (
        EmptyNest1,
        EmptyNest2,
        EmptyNest3,
        EmptyNest4,
        EmptyNest5,
        EmptyNest6,
        EmptyNest7,
        EmptyNest8,
    )

    with wp.ScopedDevice(device):
        # Create and initialize every instance before the first launch.
        instances = []
        for nest_struct in nest_structs:
            instance = nest_struct()
            instance.z = 42
            instances.append(instance)

        for instance in instances:
            wp.launch(empty_nest_kernel, dim=1, inputs=[instance])

        wp.synchronize_device()
644
+
645
+
646
# Wraps StructFromAnotherModule (defined outside this part of the file;
# presumably imported from a different module, as its name suggests).
# Nested in DependentModuleImport_C.
@wp.struct
class DependentModuleImport_A:
    s: StructFromAnotherModule
649
+
650
+
651
# Second wrapper around StructFromAnotherModule, with the same layout as
# DependentModuleImport_A. Nested in DependentModuleImport_C.
@wp.struct
class DependentModuleImport_B:
    s: StructFromAnotherModule
654
+
655
+
656
# Nests both wrapper structs, so StructFromAnotherModule is reached through
# two different members. Argument type of test_dependent_module_import.
@wp.struct
class DependentModuleImport_C:
    a: DependentModuleImport_A
    b: DependentModuleImport_B
660
+
661
+
662
# Kernel whose only purpose is to take a DependentModuleImport_C argument;
# the body does no meaningful work.
@wp.kernel
def test_dependent_module_import(c: DependentModuleImport_C):
    wp.tid()  # nop, we're just testing codegen
665
+
666
+
667
def test_struct_array_hash(test, device):
    """Check how redefining a struct used as an array dtype affects the module hash.

    The same struct/kernel pair is defined three times under identical names.
    The module hash is recorded after each definition and compared: an
    identical redefinition must leave it unchanged, a changed member type must
    alter it.
    """
    # Ensure that the memory address of the struct does not affect the content hash

    @wp.struct
    class ContentHashStruct:
        i: int

    @wp.kernel
    def dummy_kernel(a: wp.array(dtype=ContentHashStruct)):
        i = wp.tid()

    # Baseline hash of the module that owns dummy_kernel.
    module_hash_0 = wp.get_module(dummy_kernel.__module__).hash_module()

    # Redefine ContentHashStruct to have the same members as before but a new memory address
    @wp.struct
    class ContentHashStruct:
        i: int

    @wp.kernel
    def dummy_kernel(a: wp.array(dtype=ContentHashStruct)):
        i = wp.tid()

    module_hash_1 = wp.get_module(dummy_kernel.__module__).hash_module()

    test.assertEqual(
        module_hash_1,
        module_hash_0,
        "Module hash should be unchanged when ContentHashStruct is redefined but unchanged.",
    )

    # Redefine ContentHashStruct to have different members. This time we should get a new hash.
    @wp.struct
    class ContentHashStruct:
        i: float

    @wp.kernel
    def dummy_kernel(a: wp.array(dtype=ContentHashStruct)):
        i = wp.tid()

    module_hash_2 = wp.get_module(dummy_kernel.__module__).hash_module()

    test.assertNotEqual(
        module_hash_2, module_hash_0, "Module hash should be different when ContentHashStruct redefined and changed."
    )
711
+
712
+
713
# Tests for garbage collection behavior with arrays in structs
# StructWithArray pairs a float array with a plain int member; it is the
# argument type of access_array_kernel and compute_loss_from_struct_array_kernel.
@wp.struct
class StructWithArray:
    data: wp.array(dtype=float)
    some_value: int
718
+
719
+
720
@wp.kernel
def access_array_kernel(s: StructWithArray, out: wp.array(dtype=float)):
    # Copy the leading element of the struct's array into `out` so the host can
    # verify the data is intact. s.data is expected to hold at least one element.
    first_value = s.data[0]
    out[0] = first_value
725
+
726
+
727
@wp.kernel
def compute_loss_from_struct_array_kernel(s_in: StructWithArray, loss_val: wp.array(dtype=float)):
    # Scalar loss for gradient testing: a weighted sum of the first two array
    # elements. s_in.data is expected to hold at least two elements.
    first_term = s_in.data[0] * 2.0  # example weight
    second_term = s_in.data[1] * 3.0  # example weight

    total = 0.0
    total += first_term
    total += second_term
    loss_val[0] = total
735
+
736
+
737
def test_struct_array_gc_direct_assignment(test, device):
    """
    Tests that an array assigned to a struct (with no other direct Python
    references) is not garbage collected prematurely.

    Args:
        test: The ``unittest.TestCase`` instance used for assertions.
        device: The device on which the arrays are allocated and the kernel runs.
    """
    wp.init()

    s = StructWithArray()
    s.some_value = 20

    # Create an array, then assign it to the struct.
    # After this assignment, 's.data' is the primary way to access it from
    # Python's perspective, though Warp's context should also hold a reference.
    local_array = wp.array([4.0, 5.0, 6.0], dtype=float, device=device)
    s.data = local_array
    del local_array  # Remove the direct Python reference

    # Force garbage collection
    gc.collect()

    # Attempt to access the array in a kernel
    out_wp = wp.zeros(1, dtype=float, device=device)
    try:
        wp.launch(kernel=access_array_kernel, dim=1, inputs=[s, out_wp], device=device)
        # Reading the result back is kept inside the try block so that a failure
        # surfacing during the read is still reported as an execution failure.
        first_value = out_wp.numpy()[0]
    except Exception as e:
        test.fail(f"Kernel execution failed after GC with direct assignment: {e}")

    # We expect to read 4.0 if the array is still valid. The check lives outside
    # the try block and uses a TestCase assertion so that (a) it is not stripped
    # under ``python -O`` and (b) a wrong value is not misreported as a kernel
    # execution failure by the broad ``except`` above.
    test.assertEqual(
        first_value, 4.0, "Array data was not accessible or incorrect after GC with direct assignment."
    )
766
+
767
+
768
def test_struct_array_gc_requires_grad_toggle(test, device):
    """
    Checks that the backward pass still runs after the ``requires_grad`` flag
    of an array held by a struct is switched off and a garbage collection is
    forced, i.e. that the array is not collected prematurely.
    """
    wp.init()

    holder = StructWithArray()
    holder.some_value = 10
    # Differentiable input array with content [1.0, 2.0, 3.0]
    holder.data = wp.array([1.0, 2.0, 3.0], dtype=float, device=device, requires_grad=True)

    loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)

    recorder = wp.Tape()
    with recorder:
        # Record a forward pass that derives the loss from holder.data
        wp.launch(
            kernel=compute_loss_from_struct_array_kernel,
            dim=1,
            inputs=[holder, loss],
            device=device,
        )

    # Expected loss = 1.0*2.0 + 2.0*3.0 = 2.0 + 6.0 = 8.0

    # With the forward pass recorded, switch requires_grad off and run the GC
    holder.data.requires_grad = False
    gc.collect()

    # A memory access violation here would indicate that the grad array was
    # garbage collected or that the struct was not updated correctly
    recorder.backward(loss=loss)
801
+
802
+
803
class TestStruct(unittest.TestCase):
    """Test case for Warp structs; the methods defined here run on the default device."""

    def test_struct_default_attributes_python(self):
        # A struct default-constructed on the Python side must reach the kernel
        # with default-valued members.
        default_struct = DefaultAttribStruct()

        wp.launch(check_default_attributes_kernel, dim=1, inputs=[default_struct])

    def test_nested_vec_assignment(self):
        # Component-wise writes to a vector member must be visible when the
        # struct is packed into an array.
        v = VecStruct()
        for index, component in enumerate((1.0, 2.0, 3.0)):
            v.value[index] = component

        packed = wp.array([v], dtype=VecStruct)
        expected = np.array([[[1.0, 2.0, 3.0]]])
        np.testing.assert_equal(packed.numpy().tolist(), expected)
819
+
820
+
821
devices = get_test_devices()


def _register_function_test(name, func, test_devices=devices):
    """Attach a host-side test function to TestStruct for the given devices."""
    add_function_test(TestStruct, name, func, devices=test_devices)


def _register_kernel_test(name, kernel, inputs, test_devices=devices):
    """Attach a single-thread (dim=1) kernel test to TestStruct for the given devices."""
    add_kernel_test(TestStruct, kernel=kernel, name=name, dim=1, inputs=inputs, devices=test_devices)


_register_function_test("test_step", test_step)
_register_function_test("test_step_grad", test_step_grad)
_register_kernel_test("test_empty", test_empty, [Empty()])
_register_kernel_test("test_uninitialized", test_uninitialized, [Uninitialized()])
_register_kernel_test("test_return", test_return, [])
_register_function_test("test_nested_struct", test_nested_struct)
_register_function_test("test_nested_mat", test_nested_mat)
_register_function_test("test_assign_view", test_assign_view)
_register_function_test("test_struct_attribute_error", test_struct_attribute_error)
_register_function_test("test_struct_inheritance_error", test_struct_inheritance_error)
_register_function_test("test_nested_array_struct", test_nested_array_struct)
_register_function_test("test_convert_to_device", test_convert_to_device)
_register_function_test("test_nested_empty_struct", test_nested_empty_struct)
_register_function_test("test_struct_math_conversions", test_struct_math_conversions)
_register_kernel_test("test_struct_default_attributes", test_struct_default_attributes_kernel, [])
_register_kernel_test("test_struct_mutate_attributes", test_struct_mutate_attributes_kernel, [])

# These kernels take array inputs, which have to be allocated on the device the
# test runs on, so they are registered one device at a time.
for device in devices:
    _register_kernel_test(
        "test_struct_instantiate",
        test_struct_instantiate,
        [wp.array([1], dtype=int, device=device)],
        test_devices=[device],
    )
    _register_kernel_test(
        "test_return_struct",
        test_return_struct,
        [wp.zeros(10, dtype=int, device=device)],
        test_devices=[device],
    )

_register_kernel_test("test_dependent_module_import", test_dependent_module_import, [DependentModuleImport_C()])

# Registered with devices=None rather than the per-device list.
_register_function_test("test_struct_array_hash", test_struct_array_hash, test_devices=None)
_register_function_test("test_struct_array_gc_requires_grad_toggle", test_struct_array_gc_requires_grad_toggle)
_register_function_test("test_struct_array_gc_direct_assignment", test_struct_array_gc_direct_assignment)


if __name__ == "__main__":
    # Clear the kernel cache before running the suite directly.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)