warp-lang 0.9.0-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/tests/test_matmul_lite.py ADDED
@@ -0,0 +1,410 @@
+ # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import unittest
+
+ import numpy as np
+
+ import warp as wp
+ from warp.tests.unittest_utils import *
+
+ wp.init()
+
+ from warp.context import runtime  # noqa: E402
+
+
+ class gemm_test_bed_runner:
+     def __init__(self, dtype, device):
+         self.dtype = dtype
+         self.device = device
+
+     def alloc(self, m, n, k, batch_count):
+         rng = np.random.default_rng(42)
+         low = -4.5
+         high = 3.5
+         if batch_count == 1:
+             A = wp.array2d(
+                 np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
+                 dtype=self.dtype,
+                 device=self.device,
+                 requires_grad=True,
+             )
+             B = wp.array2d(
+                 np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
+                 dtype=self.dtype,
+                 device=self.device,
+                 requires_grad=True,
+             )
+             C = wp.array2d(
+                 np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
+                 dtype=self.dtype,
+                 device=self.device,
+                 requires_grad=True,
+             )
+             D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+         else:
+             A = wp.array3d(
+                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
+                 dtype=self.dtype,
+                 device=self.device,
+                 requires_grad=True,
+             )
+             B = wp.array3d(
+                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
+                 dtype=self.dtype,
+                 device=self.device,
+                 requires_grad=True,
+             )
+             C = wp.array3d(
+                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
+                 dtype=self.dtype,
+                 device=self.device,
+                 requires_grad=True,
+             )
+             D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+         return A, B, C, D
+
+     def run_and_verify(self, m, n, k, batch_count, alpha, beta):
+         A, B, C, D = self.alloc(m, n, k, batch_count)
+         ones = wp.zeros_like(D)
+         ones.fill_(1.0)
+
+         if batch_count == 1:
+             tape = wp.Tape()
+             with tape:
+                 wp.matmul(A, B, C, D, alpha, beta, False, self.device)
+             tape.backward(grads={D: ones})
+
+             D_np = alpha * (A.numpy() @ B.numpy()) + beta * C.numpy()
+             assert np.array_equal(D_np, D.numpy())
+
+             adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose())
+             adj_B_np = alpha * (A.numpy().transpose() @ ones.numpy())
+             adj_C_np = beta * ones.numpy()
+
+         else:
+             tape = wp.Tape()
+             with tape:
+                 wp.batched_matmul(A, B, C, D, alpha, beta, False, self.device)
+             tape.backward(grads={D: ones})
+
+             D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
+             assert np.array_equal(D_np, D.numpy())
+
+             adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
+             adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
+             adj_C_np = beta * ones.numpy()
+
+         assert np.array_equal(adj_A_np, A.grad.numpy())
+         assert np.array_equal(adj_B_np, B.grad.numpy())
+         assert np.array_equal(adj_C_np, C.grad.numpy())
+
+     def run(self):
+         Ms = [8]
+         Ns = [16]
+         Ks = [32]
+         batch_counts = [1]
+         betas = [1.0]
+         alpha = 1.0
+
+         for batch_count in batch_counts:
+             for m in Ms:
+                 for n in Ns:
+                     for k in Ks:
+                         for beta in betas:
+                             self.run_and_verify(m, n, k, batch_count, alpha, beta)
+
+
+ class gemm_test_bed_runner_transpose:
+     def __init__(self, dtype, device):
+         self.dtype = dtype
+         self.device = device
+
+     def alloc(self, m, n, k, batch_count):
+         rng = np.random.default_rng(42)
+         low = -4.5
+         high = 3.5
+         if batch_count == 1:
+             A = wp.array2d(
+                 np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
+                 dtype=self.dtype,
+                 device=self.device,
+                 requires_grad=True,
+             )
+             B = wp.array2d(
+                 np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
+                 dtype=self.dtype,
+                 device=self.device,
+                 requires_grad=True,
+             )
+             C = wp.array2d(
+                 np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
+                 dtype=self.dtype,
+                 device=self.device,
+                 requires_grad=True,
+             )
+             D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+             AT = wp.array2d(A.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
+             BT = wp.array2d(B.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
+         else:
+             A = wp.array3d(
+                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
+                 dtype=self.dtype,
+                 device=self.device,
+                 requires_grad=True,
+             )
+             B = wp.array3d(
+                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
+                 dtype=self.dtype,
+                 device=self.device,
+                 requires_grad=True,
+             )
+             C = wp.array3d(
+                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
+                 dtype=self.dtype,
+                 device=self.device,
+                 requires_grad=True,
+             )
+             D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+             AT = wp.array3d(A.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
+             BT = wp.array3d(B.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
+         return A, B, C, D, AT, BT
+
+     def run_and_verify(self, m, n, k, batch_count, alpha, beta):
+         A, B, C1, D1, AT1, BT1 = self.alloc(m, n, k, batch_count)
+         C2 = wp.clone(C1)
+         C3 = wp.clone(C1)
+         D2 = wp.clone(D1)
+         D3 = wp.clone(D1)
+         AT2 = wp.clone(AT1)
+         BT2 = wp.clone(BT1)
+         ones1 = wp.zeros_like(D1)
+         ones1.fill_(1.0)
+         ones2 = wp.zeros_like(D2)
+         ones2.fill_(1.0)
+         ones3 = wp.zeros_like(D3)
+         ones3.fill_(1.0)
+
+         if batch_count == 1:
+             ATT1 = AT1.transpose([1, 0])
+             BTT1 = BT1.transpose([1, 0])
+             ATT2 = AT2.transpose([1, 0])
+             BTT2 = BT2.transpose([1, 0])
+             tape = wp.Tape()
+             with tape:
+                 wp.matmul(A, BTT1, C1, D1, alpha, beta, False, self.device)
+                 wp.matmul(ATT1, B, C2, D2, alpha, beta, False, self.device)
+                 wp.matmul(ATT2, BTT2, C3, D3, alpha, beta, False, self.device)
+             tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
+
+             D_np = alpha * (A.numpy() @ B.numpy()) + beta * C1.numpy()
+             assert np.array_equal(D_np, D1.numpy())
+             assert np.array_equal(D_np, D2.numpy())
+             assert np.array_equal(D_np, D3.numpy())
+
+             adj_A_np = alpha * (ones1.numpy() @ B.numpy().transpose())
+             adj_B_np = alpha * (A.numpy().transpose() @ ones1.numpy())
+             adj_C_np = beta * ones1.numpy()
+
+         else:
+             ATT1 = AT1.transpose([0, 2, 1])
+             BTT1 = BT1.transpose([0, 2, 1])
+             ATT2 = AT2.transpose([0, 2, 1])
+             BTT2 = BT2.transpose([0, 2, 1])
+             tape = wp.Tape()
+             with tape:
+                 wp.batched_matmul(A, BTT1, C1, D1, alpha, beta, False, self.device)
+                 wp.batched_matmul(ATT1, B, C2, D2, alpha, beta, False, self.device)
+                 wp.batched_matmul(ATT2, BTT2, C3, D3, alpha, beta, False, self.device)
+             tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
+
+             D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C1.numpy()
+             assert np.array_equal(D_np, D1.numpy())
+             assert np.array_equal(D_np, D2.numpy())
+             assert np.array_equal(D_np, D3.numpy())
+
+             adj_A_np = alpha * np.matmul(ones1.numpy(), B.numpy().transpose((0, 2, 1)))
+             adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones1.numpy())
+             adj_C_np = beta * ones1.numpy()
+
+         assert np.array_equal(adj_A_np, A.grad.numpy())
+         assert np.array_equal(adj_A_np, ATT1.grad.numpy())
+         assert np.array_equal(adj_A_np, ATT2.grad.numpy())
+         assert np.array_equal(adj_B_np, B.grad.numpy())
+         assert np.array_equal(adj_B_np, BTT1.grad.numpy())
+         assert np.array_equal(adj_B_np, BTT2.grad.numpy())
+         assert np.array_equal(adj_C_np, C1.grad.numpy())
+         assert np.array_equal(adj_C_np, C2.grad.numpy())
+         assert np.array_equal(adj_C_np, C3.grad.numpy())
+
+     def run(self):
+         m = 8
+         n = 16
+         k = 32
+         batch_counts = [1, 4]
+         beta = 1.0
+         alpha = 1.0
+
+         for batch_count in batch_counts:
+             self.run_and_verify(m, n, k, batch_count, alpha, beta)
+
+
+ @unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
+ def test_f32(test, device):
+     gemm_test_bed_runner(wp.float32, device).run()
+     gemm_test_bed_runner_transpose(wp.float32, device).run()
+
+
+ @wp.kernel
+ def matrix_sum_kernel(arr: wp.array2d(dtype=float), loss: wp.array(dtype=float)):
+     i, j = wp.tid()
+     wp.atomic_add(loss, 0, arr[i, j])
+
+
+ @unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
+ def test_tape(test, device):
+     rng = np.random.default_rng(42)
+     low = -4.5
+     high = 3.5
+     m = 8
+     n = 16
+     k = 32
+     A = wp.array2d(
+         np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
+     )
+     B = wp.array2d(
+         np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
+     )
+     C = wp.array2d(
+         np.ceil(rng.uniform(low=low, high=high, size=(m, n))), dtype=float, device=device, requires_grad=True
+     )
+     D = wp.array2d(np.zeros((m, n)), dtype=float, device=device, requires_grad=True)
+     loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
+
+     # test tape
+     tape = wp.Tape()
+     with tape:
+         wp.matmul(A, B, C, D, device=device)
+         wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
+
+     tape.backward(loss=loss)
+     A_grad = A.grad.numpy()
+     tape.reset()
+
+     # test adjoint
+     D.grad = wp.array2d(np.ones((m, n)), dtype=float, device=device)
+     wp.adj_matmul(A, B, C, A.grad, B.grad, C.grad, D.grad, device=device)
+     assert_np_equal(A_grad, A.grad.numpy())
+
+     # test zero
+     tape.zero()
+     assert_array_equal(A.grad, wp.zeros_like(A))
+
+
+ @unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
+ def test_operator(test, device):
+     rng = np.random.default_rng(42)
+     low = -4.5
+     high = 3.5
+     m = 8
+     n = 16
+     k = 32
+     A = wp.array2d(
+         np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
+     )
+     B = wp.array2d(
+         np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
+     )
+     loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
+
+     # test tape
+     tape = wp.Tape()
+     with tape:
+         D = A @ B
+         wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
+
+     tape.backward(loss=loss)
+
+     # test adjoint
+     D.grad = wp.array2d(np.ones((m, n)), dtype=float, device=device)
+     B_transpose = wp.array2d(B.transpose().numpy(), dtype=float, device=device)
+
+     adj_A = D.grad @ B_transpose
+     assert_array_equal(adj_A, A.grad)
+
+     # test zero
+     tape.zero()
+     assert_array_equal(A.grad, wp.zeros_like(A))
+
+
+ @unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
+ def test_large_batch_count(test, device):
+     rng = np.random.default_rng(42)
+     low = -4.5
+     high = 3.5
+     m = 2
+     n = 3
+     k = 4
+     batch_count = 65535 * 2 + int(65535 / 2)
+     A = wp.array3d(
+         np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
+         dtype=float,
+         device=device,
+         requires_grad=True,
+     )
+     B = wp.array3d(
+         np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
+         dtype=float,
+         device=device,
+         requires_grad=True,
+     )
+     C = wp.array3d(
+         np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
+         dtype=float,
+         device=device,
+         requires_grad=True,
+     )
+     D = wp.array3d(np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True)
+     ones = wp.zeros_like(D)
+     ones.fill_(1.0)
+
+     alpha = 1.0
+     beta = 1.0
+
+     tape = wp.Tape()
+     with tape:
+         wp.batched_matmul(A, B, C, D, alpha=alpha, beta=beta, allow_tf32x3_arith=False, device=device)
+     tape.backward(grads={D: ones})
+
+     D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
+     assert np.array_equal(D_np, D.numpy())
+
+     adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
+     adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
+     adj_C_np = beta * ones.numpy()
+
+     assert np.array_equal(adj_A_np, A.grad.numpy())
+     assert np.array_equal(adj_B_np, B.grad.numpy())
+     assert np.array_equal(adj_C_np, C.grad.numpy())
+
+
+ devices = get_test_devices()
+
+
+ class TestMatmulLite(unittest.TestCase):
+     pass
+
+
+ add_function_test(TestMatmulLite, "test_f32", test_f32, devices=devices)
+ add_function_test(TestMatmulLite, "test_tape", test_tape, devices=devices)
+ add_function_test(TestMatmulLite, "test_operator", test_operator, devices=devices)
+ add_function_test(TestMatmulLite, "test_large_batch_count", test_large_batch_count, devices=devices)
+
+
+ if __name__ == "__main__":
+     wp.build.clear_kernel_cache()
+     unittest.main(verbosity=2, failfast=False)
warp/tests/test_mesh.py CHANGED
@@ -10,8 +10,7 @@ import unittest
  import numpy as np

  import warp as wp
- from warp.tests.test_base import *
-
+ from warp.tests.unittest_utils import *

  # fmt: off

@@ -223,9 +222,9 @@ def query_ray_kernel(


  def test_mesh_query_ray(test, device):
-     points = wp.array(POINT_POSITIONS, dtype=wp.vec3)
+     points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)

-     indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int)
+     indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
      mesh = wp.Mesh(points=points, indices=indices)
      expected_sign = -1.0
      wp.launch(
@@ -235,9 +234,10 @@ def test_mesh_query_ray(test, device):
              mesh.id,
              expected_sign,
          ],
+         device=device,
      )

-     indices = wp.array(LEFT_HANDED_FACE_VERTEX_INDICES, dtype=int)
+     indices = wp.array(LEFT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
      mesh = wp.Mesh(points=points, indices=indices)
      expected_sign = 1.0
      wp.launch(
@@ -247,21 +247,78 @@ def test_mesh_query_ray(test, device):
              mesh.id,
              expected_sign,
          ],
+         device=device,
      )


- def register(parent):
-     devices = get_test_devices()
+ def test_mesh_refit_graph(test, device):
+     points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)
+
+     indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
+     mesh = wp.Mesh(points=points, indices=indices)
+
+     wp.capture_begin(device, force_module_load=False)
+     try:
+         mesh.refit()
+     finally:
+         graph = wp.capture_end(device)
+
+     # replay
+     num_iters = 10
+     for _ in range(num_iters):
+         wp.capture_launch(graph)
+
+     wp.synchronize_device(device)
+
+
+ def test_mesh_exceptions(test, device):
+     # points and indices must be on same device
+     with test.assertRaises(RuntimeError):
+         points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device="cpu")
+         indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
+         wp.Mesh(points=points, indices=indices)
+
+     # points must be vec3
+     with test.assertRaises(RuntimeError):
+         points = wp.array(POINT_POSITIONS, dtype=wp.vec3d, device=device)
+         indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
+         wp.Mesh(points=points, indices=indices)
+
+     # velocities must be vec3
+     with test.assertRaises(RuntimeError):
+         points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)
+         velocities = wp.zeros(points.shape, dtype=wp.vec3d, device=device)
+         indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
+         wp.Mesh(points=points, indices=indices, velocities=velocities)
+
+     # indices must be int32
+     with test.assertRaises(RuntimeError):
+         points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)
+         indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=wp.int64, device=device)
+         wp.Mesh(points=points, indices=indices)
+
+     # indices must be 1d
+     with test.assertRaises(RuntimeError):
+         points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)
+         indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
+         indices = indices.reshape((3, -1))
+         wp.Mesh(points=points, indices=indices)
+
+
+ devices = get_test_devices()
+
+
+ class TestMesh(unittest.TestCase):
+     pass

-     class TestMesh(parent):
-         pass

-     add_function_test(TestMesh, "test_mesh_read_properties", test_mesh_read_properties, devices=devices)
-     add_function_test(TestMesh, "test_mesh_query_point", test_mesh_query_point, devices=devices)
-     add_function_test(TestMesh, "test_mesh_query_ray", test_mesh_query_ray, devices=devices)
-     return TestMesh
+ add_function_test(TestMesh, "test_mesh_read_properties", test_mesh_read_properties, devices=devices)
+ add_function_test(TestMesh, "test_mesh_query_point", test_mesh_query_point, devices=devices)
+ add_function_test(TestMesh, "test_mesh_query_ray", test_mesh_query_ray, devices=devices)
+ add_function_test(TestMesh, "test_mesh_refit_graph", test_mesh_refit_graph, devices=get_unique_cuda_test_devices())
+ add_function_test(TestMesh, "test_mesh_exceptions", test_mesh_exceptions, devices=get_unique_cuda_test_devices())


  if __name__ == "__main__":
-     _ = register(unittest.TestCase)
+     wp.build.clear_kernel_cache()
      unittest.main(verbosity=2)
warp/tests/test_mesh_query_aabb.py CHANGED
@@ -5,10 +5,12 @@
  # distribution of this software and related documentation without an express
  # license agreement from NVIDIA CORPORATION is strictly prohibited.

+ import unittest
+
  import numpy as np

  import warp as wp
- from warp.tests.test_base import *
+ from warp.tests.unittest_utils import *

  wp.init()

@@ -96,7 +98,6 @@ def test_compute_bounds(test, device):

      lower_view = lowers.numpy()
      upper_view = uppers.numpy()
-     wp.synchronize()

      # Confirm the bounds of each triangle are correct.
      test.assertTrue(lower_view[0][0] == 0)
@@ -148,8 +149,6 @@ def test_mesh_query_aabb_count_overlap(test, device):
          device=device,
      )

-     wp.synchronize()
-
      view = counts.numpy()

      # 2 triangles that share a vertex having overlapping AABBs.
@@ -188,8 +187,6 @@ def test_mesh_query_aabb_count_nonoverlap(test, device):
          device=device,
      )

-     wp.synchronize()
-
      view = counts.numpy()

      # AABB query only returns one triangle at a time, the triangles are not close enough to overlap.
@@ -197,29 +194,48 @@
      test.assertTrue(c == 1)


- def register(parent):
-     devices = get_test_devices()
+ def test_mesh_query_aabb_codegen_adjoints_with_select(test, device):
+     def kernel_fn(
+         mesh: wp.uint64,
+     ):
+         v = wp.vec3(0.0, 0.0, 0.0)

-     class TestMeshQueryAABBMethods(parent):
-         pass
+         if True:
+             query = wp.mesh_query_aabb(mesh, v, v)
+         else:
+             query = wp.mesh_query_aabb(mesh, v, v)
+
+     wp.Kernel(func=kernel_fn)
+
+
+ devices = get_test_devices()
+
+
+ class TestMeshQueryAABBMethods(unittest.TestCase):
+     pass

-     add_function_test(TestMeshQueryAABBMethods, "test_compute_bounds", test_compute_bounds, devices=devices)
-     add_function_test(
-         TestMeshQueryAABBMethods,
-         "test_mesh_query_aabb_count_overlap",
-         test_mesh_query_aabb_count_overlap,
-         devices=devices,
-     )
-     add_function_test(
-         TestMeshQueryAABBMethods,
-         "test_mesh_query_aabb_count_nonoverlap",
-         test_mesh_query_aabb_count_nonoverlap,
-         devices=devices,
-     )

-     return TestMeshQueryAABBMethods
+ add_function_test(TestMeshQueryAABBMethods, "test_compute_bounds", test_compute_bounds, devices=devices)
+ add_function_test(
+     TestMeshQueryAABBMethods,
+     "test_mesh_query_aabb_count_overlap",
+     test_mesh_query_aabb_count_overlap,
+     devices=devices,
+ )
+ add_function_test(
+     TestMeshQueryAABBMethods,
+     "test_mesh_query_aabb_count_nonoverlap",
+     test_mesh_query_aabb_count_nonoverlap,
+     devices=devices,
+ )
+ add_function_test(
+     TestMeshQueryAABBMethods,
+     "test_mesh_query_aabb_codegen_adjoints_with_select",
+     test_mesh_query_aabb_codegen_adjoints_with_select,
+     devices=devices,
+ )


  if __name__ == "__main__":
-     c = register(unittest.TestCase)
+     wp.build.clear_kernel_cache()
      unittest.main(verbosity=2)