warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,551 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import contextlib
17
+ import io
18
+ import unittest
19
+ from typing import Any
20
+
21
+ import numpy as np
22
+
23
+ import warp as wp
24
+ from warp.tests.unittest_utils import *
25
+
26
+ # kernels are defined in the global scope, to ensure wp.Kernel objects are not GC'ed in the MGPU case
27
+ # kernel args are assigned array modes during codegen, so wp.Kernel objects generated during codegen
28
+ # must be preserved for overwrite tracking to function
29
+
30
+
31
# Element-wise square: y[i] = x[i] * x[i]. Reads x, writes y.
@wp.kernel
def square_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    xi = x[i]
    y[i] = xi * xi
35
+
36
+
37
# Element-wise copy z -> x. Reads z, writes (overwrites) x.
@wp.kernel
def overwrite_kernel_a(z: wp.array(dtype=float), x: wp.array(dtype=float)):
    i = wp.tid()
    src = z[i]
    x[i] = src
41
+
42
+
43
# (kernel READ) -> (kernel WRITE) failure case
def test_kernel_read_kernel_write(test, device):
    """Check that overwriting an array read by an earlier recorded launch emits a warning.

    ``a`` is read by ``square_kernel`` and then written by ``overwrite_kernel_a``
    while the same tape is recording. The overwrite warning must appear in the
    captured stdout. The previous ``verify_autograd_array_access`` setting is
    always restored.
    """
    previous_setting = wp.config.verify_autograd_array_access
    try:
        wp.config.verify_autograd_array_access = True

        a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True, device=device)
        b = wp.zeros_like(a)
        c = wp.array(np.array([-1.0, -2.0, -3.0]), dtype=float, requires_grad=True, device=device)

        tape = wp.Tape()
        captured = io.StringIO()

        # Enter the stdout redirect first, then the tape: same order as nested `with` blocks.
        with contextlib.redirect_stdout(captured), tape:
            wp.launch(square_kernel, a.shape, inputs=[a], outputs=[b], device=device)
            wp.launch(overwrite_kernel_a, c.shape, inputs=[c], outputs=[a], device=device)

        warning_text = "is being written to but has already been read from in a previous launch. This may corrupt gradient computation in the backward pass."
        test.assertIn(warning_text, captured.getvalue())

    finally:
        wp.config.verify_autograd_array_access = previous_setting
65
+
66
+
67
# Element-wise y[i] = 2 * x[i]. Reads x, writes y.
@wp.kernel
def double_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    value = x[i]
    y[i] = 2.0 * value
71
+
72
+
73
# Element-wise z[i] = 3 * y[i]. Reads y, writes z.
@wp.kernel
def triple_kernel(y: wp.array(dtype=float), z: wp.array(dtype=float)):
    i = wp.tid()
    value = y[i]
    z[i] = 3.0 * value
77
+
78
+
79
# Element-wise y[i] = 1.0 * w[i]. Reads w, writes (overwrites) y.
@wp.kernel
def overwrite_kernel_b(w: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    value = w[i]
    y[i] = 1.0 * value
83
+
84
+
85
# (kernel WRITE) -> (kernel READ) -> (kernel WRITE) failure case
def test_kernel_write_kernel_read_kernel_write(test, device):
    """Check that re-writing an intermediate array after it was read emits a warning.

    ``b`` is written by ``double_kernel``, read by ``triple_kernel``, and then
    written again by ``overwrite_kernel_b`` within one tape recording. The
    overwrite warning must appear in the captured stdout.
    """
    previous_setting = wp.config.verify_autograd_array_access
    try:
        wp.config.verify_autograd_array_access = True

        tape = wp.Tape()

        a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True, device=device)
        b, c, d = [wp.zeros_like(a) for _ in range(3)]

        captured = io.StringIO()
        with contextlib.redirect_stdout(captured), tape:
            wp.launch(double_kernel, a.shape, inputs=[a], outputs=[b], device=device)
            wp.launch(triple_kernel, b.shape, inputs=[b], outputs=[c], device=device)
            wp.launch(overwrite_kernel_b, d.shape, inputs=[d], outputs=[b], device=device)

        warning_text = "is being written to but has already been read from in a previous launch. This may corrupt gradient computation in the backward pass."
        test.assertIn(warning_text, captured.getvalue())

    finally:
        wp.config.verify_autograd_array_access = previous_setting
109
+
110
+
111
# Element-wise copy a -> b. Reads a, writes b.
@wp.kernel
def read_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
    i = wp.tid()
    value = a[i]
    b[i] = value
115
+
116
+
117
# Within a single launch: first writes a[tid] (from c squared), then reads the just-written
# a[tid] back into b. The write-then-read order inside one kernel is what the
# "(kernel WRITE -> READ)" test cases rely on, so the statement order must not change.
@wp.kernel
def writeread_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), c: wp.array(dtype=float)):
    tid = wp.tid()
    a[tid] = c[tid] * c[tid]  # write to a
    b[tid] = a[tid]  # then read from a
122
+
123
+
124
# (kernel READ) -> (kernel WRITE -> READ) failure case
def test_kernel_read_kernel_writeread(test, device):
    """Check that writing an array already read by a prior launch emits a warning.

    ``a`` is read by ``read_kernel`` and then written (and re-read) by
    ``writeread_kernel`` in the same tape recording. The overwrite warning must
    appear in the captured stdout.
    """
    previous_setting = wp.config.verify_autograd_array_access
    try:
        wp.config.verify_autograd_array_access = True

        count = 5
        a = wp.array(np.arange(count), dtype=float, requires_grad=True, device=device)
        b, c, d = [wp.zeros_like(a) for _ in range(3)]

        tape = wp.Tape()

        captured = io.StringIO()
        with contextlib.redirect_stdout(captured), tape:
            wp.launch(read_kernel, dim=count, inputs=[a, b], device=device)
            wp.launch(writeread_kernel, dim=count, inputs=[a, d, c], device=device)

        warning_text = "is being written to but has already been read from in a previous launch. This may corrupt gradient computation in the backward pass."
        test.assertIn(warning_text, captured.getvalue())

    finally:
        wp.config.verify_autograd_array_access = previous_setting
147
+
148
+
149
# Element-wise copy d -> a. Reads d, writes a.
@wp.kernel
def write_kernel(a: wp.array(dtype=float), d: wp.array(dtype=float)):
    i = wp.tid()
    value = d[i]
    a[i] = value
153
+
154
+
155
# (kernel WRITE -> READ) -> (kernel WRITE) failure case
def test_kernel_writeread_kernel_write(test, device):
    """Check that overwriting an array that was written-then-read earlier emits a warning.

    ``writeread_kernel`` writes and then reads ``a``. Afterwards ``write_kernel``
    writes ``a`` again within the same tape recording. The overwrite warning must
    appear in the captured stdout.
    """
    previous_setting = wp.config.verify_autograd_array_access
    try:
        wp.config.verify_autograd_array_access = True

        count = 5
        c = wp.array(np.arange(count), dtype=float, requires_grad=True, device=device)
        b, a, d = [wp.zeros_like(c) for _ in range(3)]

        tape = wp.Tape()

        captured = io.StringIO()
        with contextlib.redirect_stdout(captured), tape:
            wp.launch(writeread_kernel, dim=count, inputs=[a, b, c], device=device)
            wp.launch(write_kernel, dim=count, inputs=[a, d], device=device)

        warning_text = "is being written to but has already been read from in a previous launch. This may corrupt gradient computation in the backward pass."
        test.assertIn(warning_text, captured.getvalue())

    finally:
        wp.config.verify_autograd_array_access = previous_setting
178
+
179
+
180
# Reads a[idx] into a local and returns it. The read-into-local form is one of the
# access patterns exercised by test_nested_function_read_write, which expects `a` to be
# marked as read.
@wp.func
def read_func(a: wp.array(dtype=Any), idx: int):
    x = a[idx]
    return x
184
+
185
+
186
# Returns a tuple (1.0, b[idx]). This reads `b` directly inside a multi-value return;
# test_nested_function_read_write expects `b` to be marked as read.
@wp.func
def read_return_func(b: wp.array(dtype=Any), idx: int):
    return 1.0, b[idx]
189
+
190
+
191
# Writes 1.0 to c[idx] without reading `c`; test_nested_function_read_write expects
# `c` to remain un-read.
@wp.func
def write_func(c: wp.array(dtype=Any), idx: int):
    c[idx] = 1.0
194
+
195
+
196
# Combines the nested helpers so array access is tracked through function calls:
# `a` is read via read_func, `b` via read_return_func, and `c` is only written via
# write_func. Returns the sum of the values read plus the constant 1.0 from read_return_func.
@wp.func
def main_func(a: wp.array(dtype=float), b: wp.array(dtype=float), c: wp.array(dtype=float), idx: int):
    x = read_func(a, idx)
    y, z = read_return_func(b, idx)  # y is the constant 1.0, z is b[idx]
    write_func(c, idx)
    return x + y + z
202
+
203
+
204
# Per thread, stores main_func(a, b, c, i) into d[i]. Inside main_func, a and b are
# read and c is written.
@wp.kernel
def func_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), c: wp.array(dtype=float), d: wp.array(dtype=float)):
    i = wp.tid()
    result = main_func(a, b, c, i)
    d[i] = result
208
+
209
+
210
# test various ways one might write to or read from an array inside warp functions
def test_nested_function_read_write(test, device):
    """Check read tracking through nested ``wp.func`` calls.

    After one ``func_kernel`` launch on the tape, the arrays only read by helper
    functions (``a``, ``b``) must be marked read. The arrays only written
    (``c`` via ``write_func``, ``d`` by the kernel) must not be.
    """
    previous_setting = wp.config.verify_autograd_array_access
    try:
        wp.config.verify_autograd_array_access = True

        a = wp.zeros(5, dtype=float, requires_grad=True, device=device)
        b, c, d = [wp.zeros_like(a) for _ in range(3)]

        tape = wp.Tape()
        with tape:
            wp.launch(func_kernel, dim=5, inputs=[a, b, c, d], device=device)

        # Checked in the same order as before: a, b, c, d.
        for arr, should_be_read in ((a, True), (b, True), (c, False), (d, False)):
            test.assertEqual(arr._is_read, should_be_read)

    finally:
        wp.config.verify_autograd_array_access = previous_setting
233
+
234
+
235
# Copies x into y element by element, but goes through partial indexing: x[i, j] and
# y[i, j] index a 3D array with only two indices, giving a lower-dimensional slice
# that is then indexed by k. Used by test_multidimensional_indexing to check that
# read/write modes are still tracked after indexing (x read, y not read).
@wp.kernel
def slice_kernel(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float)):
    i, j, k = wp.tid()
    x_slice = x[i, j]
    val = x_slice[k]

    y_slice = y[i, j]
    y_slice[k] = val  # write through the slice of y
243
+
244
+
245
# test updating array r/w mode after indexing
def test_multidimensional_indexing(test, device):
    """Check that reads/writes through partially-indexed 3D arrays are tracked.

    ``slice_kernel`` reads ``x`` and writes ``y`` only via 2-index slices.
    After the launch, ``x`` must be marked read and ``y`` must not.
    """
    previous_setting = wp.config.verify_autograd_array_access
    try:
        wp.config.verify_autograd_array_access = True

        row = np.arange(3, dtype=float)
        host_data = np.tile(row, (3, 3, 1))  # shape (3, 3, 3)
        x = wp.array3d(host_data, dtype=float, requires_grad=True, device=device)
        y = wp.zeros_like(x)

        tape = wp.Tape()
        with tape:
            wp.launch(slice_kernel, dim=(3, 3, 3), inputs=[x, y], device=device)

        test.assertEqual(x._is_read, True)
        test.assertEqual(y._is_read, False)

    finally:
        wp.config.verify_autograd_array_access = previous_setting
266
+
267
+
268
# In-place increment x[tid] += 1.0. test_in_place_operators expects this `+=` form to
# leave x marked as not read, so the augmented assignment must stay as written.
@wp.kernel
def inplace_a(x: wp.array(dtype=float)):
    tid = wp.tid()
    x[tid] += 1.0
272
+
273
+
274
# In-place accumulate x[tid] += y[tid]. test_in_place_operators expects x (the `+=`
# target) to stay un-read while y (the right-hand side) is marked read.
@wp.kernel
def inplace_b(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    x[tid] += y[tid]
278
+
279
+
280
+ # in-place operators are treated as write
281
+ def test_in_place_operators(test, device):
282
+ saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
283
+ try:
284
+ wp.config.verify_autograd_array_access = True
285
+
286
+ a = wp.zeros(3, dtype=float, requires_grad=True, device=device)
287
+ b = wp.zeros_like(a)
288
+
289
+ tape = wp.Tape()
290
+
291
+ with tape:
292
+ wp.launch(inplace_a, dim=3, inputs=[a], device=device)
293
+
294
+ test.assertEqual(a._is_read, False)
295
+
296
+ tape.reset()
297
+ a.zero_()
298
+
299
+ with tape:
300
+ wp.launch(inplace_b, dim=3, inputs=[a, b], device=device)
301
+
302
+ test.assertEqual(a._is_read, False)
303
+ test.assertEqual(b._is_read, True)
304
+
305
+ finally:
306
+ wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting
307
+
308
+
309
+ def test_views(test, device):
310
+ saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
311
+ try:
312
+ wp.config.verify_autograd_array_access = True
313
+
314
+ a = wp.zeros((3, 3), dtype=float, requires_grad=True, device=device)
315
+ test.assertEqual(a._is_read, False)
316
+
317
+ a.mark_write()
318
+
319
+ b = a.view(dtype=int)
320
+ test.assertEqual(b._is_read, False)
321
+
322
+ c = b.flatten()
323
+ test.assertEqual(c._is_read, False)
324
+
325
+ c.mark_read()
326
+ test.assertEqual(a._is_read, True)
327
+
328
+ finally:
329
+ wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting
330
+
331
+
332
+ def test_reset(test, device):
333
+ saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
334
+ try:
335
+ wp.config.verify_autograd_array_access = True
336
+
337
+ a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True, device=device)
338
+ b = wp.zeros_like(a)
339
+
340
+ tape = wp.Tape()
341
+ with tape:
342
+ wp.launch(kernel=write_kernel, dim=3, inputs=[b, a], device=device)
343
+
344
+ tape.backward(grads={b: wp.ones(3, dtype=float, device=device)})
345
+
346
+ test.assertEqual(a._is_read, True)
347
+ test.assertEqual(b._is_read, False)
348
+
349
+ tape.reset()
350
+
351
+ test.assertEqual(a._is_read, False)
352
+ test.assertEqual(b._is_read, False)
353
+
354
+ finally:
355
+ wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting
356
+
357
+
358
+ # wp.copy uses wp.record_func. Ensure array modes are propagated correctly.
359
+ def test_copy(test, device):
360
+ saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
361
+ try:
362
+ wp.config.verify_autograd_array_access = True
363
+
364
+ a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True, device=device)
365
+ b = wp.zeros_like(a)
366
+
367
+ tape = wp.Tape()
368
+
369
+ with tape:
370
+ wp.copy(b, a)
371
+
372
+ test.assertEqual(a._is_read, True)
373
+ test.assertEqual(b._is_read, False)
374
+
375
+ finally:
376
+ wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting
377
+
378
+
379
+ # write after read warning with in-place operators within a kernel
380
+ def test_in_place_operators_warning(test, device):
381
+ saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
382
+ try:
383
+ wp.config.verify_autograd_array_access = True
384
+
385
+ with contextlib.redirect_stdout(io.StringIO()) as f:
386
+
387
+ @wp.kernel
388
+ def inplace_c(x: wp.array(dtype=float)):
389
+ tid = wp.tid()
390
+ x[tid] = 1.0
391
+ a = x[tid]
392
+ x[tid] += a
393
+
394
+ a = wp.zeros(3, dtype=float, requires_grad=True, device=device)
395
+
396
+ tape = wp.Tape()
397
+ with tape:
398
+ wp.launch(inplace_c, dim=3, inputs=[a], device=device)
399
+
400
+ expected = "is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
401
+ test.assertIn(expected, f.getvalue())
402
+
403
+ finally:
404
+ wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting
405
+
406
+
407
+ # (kernel READ -> WRITE) failure case
408
+ def test_kernel_readwrite(test, device):
409
+ saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
410
+ try:
411
+ wp.config.verify_autograd_array_access = True
412
+
413
+ with contextlib.redirect_stdout(io.StringIO()) as f:
414
+
415
+ @wp.kernel
416
+ def readwrite_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
417
+ tid = wp.tid()
418
+ b[tid] = a[tid] * a[tid]
419
+ a[tid] = 1.0
420
+
421
+ a = wp.array(np.arange(5), dtype=float, requires_grad=True, device=device)
422
+ b = wp.zeros_like(a)
423
+
424
+ tape = wp.Tape()
425
+ with tape:
426
+ wp.launch(readwrite_kernel, dim=5, inputs=[a, b], device=device)
427
+
428
+ expected = "is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
429
+ test.assertIn(expected, f.getvalue())
430
+
431
+ finally:
432
+ wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting
433
+
434
+
435
+ # (kernel READ -> func WRITE) codegen failure case
436
+ def test_kernel_read_func_write(test, device):
437
+ saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
438
+ try:
439
+ wp.config.verify_autograd_array_access = True
440
+
441
+ with contextlib.redirect_stdout(io.StringIO()) as f:
442
+
443
+ @wp.func
444
+ def write_func_2(x: wp.array(dtype=float), idx: int):
445
+ x[idx] = 2.0
446
+
447
+ @wp.kernel
448
+ def read_kernel_func_write(x: wp.array(dtype=float), y: wp.array(dtype=float)):
449
+ tid = wp.tid()
450
+ a = x[tid]
451
+ write_func_2(x, tid)
452
+ y[tid] = a
453
+
454
+ a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True, device=device)
455
+ b = wp.zeros_like(a)
456
+
457
+ tape = wp.Tape()
458
+ with tape:
459
+ wp.launch(kernel=read_kernel_func_write, dim=3, inputs=[a, b], device=device)
460
+
461
+ expected = "written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
462
+ test.assertIn(expected, f.getvalue())
463
+
464
+ finally:
465
+ wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting
466
+
467
+
468
+ @wp.func
469
+ def atomic_func(
470
+ a: wp.array(dtype=wp.int32),
471
+ b: wp.array(dtype=wp.int32),
472
+ c: wp.array(dtype=wp.int32),
473
+ d: wp.array(dtype=wp.int32),
474
+ i: int,
475
+ ):
476
+ wp.atomic_add(a, i, 1)
477
+ wp.atomic_sub(b, i, 1)
478
+ wp.atomic_min(c, i, 1)
479
+ wp.atomic_max(d, i, 3)
480
+
481
+
482
+ @wp.kernel(enable_backward=False)
483
+ def atomic_kernel(
484
+ a: wp.array(dtype=wp.int32), b: wp.array(dtype=wp.int32), c: wp.array(dtype=wp.int32), d: wp.array(dtype=wp.int32)
485
+ ):
486
+ i = wp.tid()
487
+ atomic_func(a, b, c, d, i)
488
+
489
+
490
+ # atomic operations should mark arrays as WRITE
491
+ def test_atomic_operations(test, device):
492
+ saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
493
+ try:
494
+ wp.config.verify_autograd_array_access = True
495
+
496
+ a = wp.array((1, 2, 3), dtype=wp.int32, device=device)
497
+ b = wp.array((1, 2, 3), dtype=wp.int32, device=device)
498
+ c = wp.array((1, 2, 3), dtype=wp.int32, device=device)
499
+ d = wp.array((1, 2, 3), dtype=wp.int32, device=device)
500
+
501
+ wp.launch(atomic_kernel, dim=a.shape, inputs=(a, b, c, d), device=device)
502
+
503
+ test.assertEqual(atomic_kernel.adj.args[0].is_write, True)
504
+ test.assertEqual(atomic_kernel.adj.args[1].is_write, True)
505
+ test.assertEqual(atomic_kernel.adj.args[2].is_write, True)
506
+ test.assertEqual(atomic_kernel.adj.args[3].is_write, True)
507
+
508
+ finally:
509
+ wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting
510
+
511
+
512
+ class TestOverwrite(unittest.TestCase):
513
+ pass
514
+
515
+
516
+ devices = get_test_devices()
517
+
518
+ add_function_test(TestOverwrite, "test_kernel_read_kernel_write", test_kernel_read_kernel_write, devices=devices)
519
+ add_function_test(
520
+ TestOverwrite,
521
+ "test_kernel_write_kernel_read_kernel_write",
522
+ test_kernel_write_kernel_read_kernel_write,
523
+ devices=devices,
524
+ )
525
+ add_function_test(
526
+ TestOverwrite, "test_kernel_read_kernel_writeread", test_kernel_read_kernel_writeread, devices=devices
527
+ )
528
+ add_function_test(
529
+ TestOverwrite, "test_kernel_writeread_kernel_write", test_kernel_writeread_kernel_write, devices=devices
530
+ )
531
+ add_function_test(TestOverwrite, "test_nested_function_read_write", test_nested_function_read_write, devices=devices)
532
+ add_function_test(TestOverwrite, "test_multidimensional_indexing", test_multidimensional_indexing, devices=devices)
533
+ add_function_test(TestOverwrite, "test_in_place_operators", test_in_place_operators, devices=devices)
534
+ add_function_test(TestOverwrite, "test_views", test_views, devices=devices)
535
+ add_function_test(TestOverwrite, "test_reset", test_reset, devices=devices)
536
+
537
+ add_function_test(TestOverwrite, "test_copy", test_copy, devices=devices)
538
+ add_function_test(TestOverwrite, "test_atomic_operations", test_atomic_operations, devices=devices)
539
+
540
+ # Some warning are only issued during codegen, and codegen only runs on cuda_0 in the MGPU case.
541
+ cuda_device = get_cuda_test_devices(mode="basic")
542
+
543
+ add_function_test(
544
+ TestOverwrite, "test_in_place_operators_warning", test_in_place_operators_warning, devices=cuda_device
545
+ )
546
+ add_function_test(TestOverwrite, "test_kernel_readwrite", test_kernel_readwrite, devices=cuda_device)
547
+ add_function_test(TestOverwrite, "test_kernel_read_func_write", test_kernel_read_func_write, devices=cuda_device)
548
+
549
+ if __name__ == "__main__":
550
+ wp.build.clear_kernel_cache()
551
+ unittest.main(verbosity=2)