warp-lang 0.10.1-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/tests/test_grad.py CHANGED
@@ -5,9 +5,13 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import unittest
9
+ from typing import Any
10
+
8
11
  import numpy as np
12
+
9
13
  import warp as wp
10
- from warp.tests.test_base import *
14
+ from warp.tests.unittest_utils import *
11
15
 
12
16
  wp.init()
13
17
 
@@ -63,26 +67,26 @@ def test_for_loop_grad(test, device):
63
67
 
64
68
 
65
69
  def test_for_loop_graph_grad(test, device):
70
+ wp.load_module(device=device)
71
+
66
72
  n = 32
67
73
  val = np.ones(n, dtype=np.float32)
68
74
 
69
75
  x = wp.array(val, device=device, requires_grad=True)
70
76
  sum = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
71
77
 
72
- wp.force_load()
73
-
74
- wp.capture_begin()
78
+ wp.capture_begin(device, force_module_load=False)
79
+ try:
80
+ tape = wp.Tape()
81
+ with tape:
82
+ wp.launch(for_loop_grad, dim=1, inputs=[n, x, sum], device=device)
75
83
 
76
- tape = wp.Tape()
77
- with tape:
78
- wp.launch(for_loop_grad, dim=1, inputs=[n, x, sum], device=device)
79
-
80
- tape.backward(loss=sum)
81
-
82
- graph = wp.capture_end()
84
+ tape.backward(loss=sum)
85
+ finally:
86
+ graph = wp.capture_end(device)
83
87
 
84
88
  wp.capture_launch(graph)
85
- wp.synchronize()
89
+ wp.synchronize_device(device)
86
90
 
87
91
  # ensure forward pass outputs persist
88
92
  assert_np_equal(sum.numpy(), 2.0 * np.sum(x.numpy()))
@@ -90,7 +94,7 @@ def test_for_loop_graph_grad(test, device):
90
94
  assert_np_equal(x.grad.numpy(), 2.0 * val)
91
95
 
92
96
  wp.capture_launch(graph)
93
- wp.synchronize()
97
+ wp.synchronize_device(device)
94
98
 
95
99
 
96
100
  @wp.kernel
@@ -115,75 +119,20 @@ def for_loop_nested_if_grad(n: int, x: wp.array(dtype=float), s: wp.array(dtype=
115
119
  def test_for_loop_nested_if_grad(test, device):
116
120
  n = 32
117
121
  val = np.ones(n, dtype=np.float32)
118
-
122
+ # fmt: off
119
123
  expected_val = [
120
- 2.0,
121
- 2.0,
122
- 2.0,
123
- 2.0,
124
- 2.0,
125
- 2.0,
126
- 2.0,
127
- 2.0,
128
- 4.0,
129
- 4.0,
130
- 4.0,
131
- 4.0,
132
- 4.0,
133
- 4.0,
134
- 4.0,
135
- 4.0,
136
- 6.0,
137
- 6.0,
138
- 6.0,
139
- 6.0,
140
- 6.0,
141
- 6.0,
142
- 6.0,
143
- 6.0,
144
- 8.0,
145
- 8.0,
146
- 8.0,
147
- 8.0,
148
- 8.0,
149
- 8.0,
150
- 8.0,
151
- 8.0,
124
+ 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
125
+ 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0,
126
+ 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0,
127
+ 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0,
152
128
  ]
153
129
  expected_grad = [
154
- 2.0,
155
- 2.0,
156
- 2.0,
157
- 2.0,
158
- 2.0,
159
- 2.0,
160
- 2.0,
161
- 2.0,
162
- 4.0,
163
- 4.0,
164
- 4.0,
165
- 4.0,
166
- 4.0,
167
- 4.0,
168
- 4.0,
169
- 4.0,
170
- 6.0,
171
- 6.0,
172
- 6.0,
173
- 6.0,
174
- 6.0,
175
- 6.0,
176
- 6.0,
177
- 6.0,
178
- 8.0,
179
- 8.0,
180
- 8.0,
181
- 8.0,
182
- 8.0,
183
- 8.0,
184
- 8.0,
185
- 8.0,
130
+ 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
131
+ 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0,
132
+ 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0,
133
+ 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0,
186
134
  ]
135
+ # fmt: on
187
136
 
188
137
  x = wp.array(val, device=device, requires_grad=True)
189
138
  sum = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
@@ -327,8 +276,7 @@ def gradcheck(func, func_name, inputs, device, eps=1e-4, tol=1e-2):
327
276
  numerical gradient computed using finite differences.
328
277
  """
329
278
 
330
- module = wp.get_module(func.__module__)
331
- kernel = wp.Kernel(func=func, key=func_name, module=module)
279
+ kernel = wp.Kernel(func=func, key=func_name)
332
280
 
333
281
  def f(xs):
334
282
  # call the kernel without taping for finite differences
@@ -371,7 +319,7 @@ def gradcheck(func, func_name, inputs, device, eps=1e-4, tol=1e-2):
371
319
 
372
320
 
373
321
  def test_vector_math_grad(test, device):
374
- np.random.seed(123)
322
+ rng = np.random.default_rng(123)
375
323
 
376
324
  # test unary operations
377
325
  for dim, vec_type in [(2, wp.vec2), (3, wp.vec3), (4, wp.vec4), (4, wp.quat)]:
@@ -387,14 +335,14 @@ def test_vector_math_grad(test, device):
387
335
 
388
336
  # run the tests with 5 different random inputs
389
337
  for _ in range(5):
390
- x = wp.array(np.random.randn(1, dim).astype(np.float32), dtype=vec_type, device=device)
338
+ x = wp.array(rng.random(size=(1, dim), dtype=np.float32), dtype=vec_type, device=device)
391
339
  gradcheck(check_length, f"check_length_{vec_type.__name__}", [x], device)
392
340
  gradcheck(check_length_sq, f"check_length_sq_{vec_type.__name__}", [x], device)
393
341
  gradcheck(check_normalize, f"check_normalize_{vec_type.__name__}", [x], device)
394
342
 
395
343
 
396
344
  def test_matrix_math_grad(test, device):
397
- np.random.seed(123)
345
+ rng = np.random.default_rng(123)
398
346
 
399
347
  # test unary operations
400
348
  for dim, mat_type in [(2, wp.mat22), (3, wp.mat33), (4, wp.mat44)]:
@@ -407,13 +355,13 @@ def test_matrix_math_grad(test, device):
407
355
 
408
356
  # run the tests with 5 different random inputs
409
357
  for _ in range(5):
410
- x = wp.array(np.random.randn(1, dim, dim).astype(np.float32), ndim=1, dtype=mat_type, device=device)
358
+ x = wp.array(rng.random(size=(1, dim, dim), dtype=np.float32), ndim=1, dtype=mat_type, device=device)
411
359
  gradcheck(check_determinant, f"check_length_{mat_type.__name__}", [x], device)
412
360
  gradcheck(check_trace, f"check_length_sq_{mat_type.__name__}", [x], device)
413
361
 
414
362
 
415
363
  def test_3d_math_grad(test, device):
416
- np.random.seed(123)
364
+ rng = np.random.default_rng(123)
417
365
 
418
366
  # test binary operations
419
367
  def check_cross(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
@@ -463,7 +411,9 @@ def test_3d_math_grad(test, device):
463
411
 
464
412
  # run the tests with 5 different random inputs
465
413
  for _ in range(5):
466
- x = wp.array(np.random.randn(2, 3).astype(np.float32), dtype=wp.vec3, device=device, requires_grad=True)
414
+ x = wp.array(
415
+ rng.standard_normal(size=(2, 3), dtype=np.float32), dtype=wp.vec3, device=device, requires_grad=True
416
+ )
467
417
  gradcheck(check_cross, "check_cross_3d", [x], device)
468
418
  gradcheck(check_dot, "check_dot_3d", [x], device)
469
419
  gradcheck(check_mat33, "check_mat33_3d", [x], device, eps=2e-2)
@@ -473,6 +423,28 @@ def test_3d_math_grad(test, device):
473
423
  gradcheck(check_rot_quat_inv, "check_rot_quat_inv_3d", [x], device)
474
424
 
475
425
 
426
+ def test_multi_valued_function_grad(test, device):
427
+ rng = np.random.default_rng(123)
428
+
429
+ @wp.func
430
+ def multi_valued(x: float, y: float, z: float):
431
+ return wp.sin(x), wp.cos(y) * z, wp.sqrt(z) / wp.abs(x)
432
+
433
+ # test multi-valued functions
434
+ def check_multi_valued(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
435
+ tid = wp.tid()
436
+ v = vs[tid]
437
+ a, b, c = multi_valued(v[0], v[1], v[2])
438
+ out[tid] = a + b + c
439
+
440
+ # run the tests with 5 different random inputs
441
+ for _ in range(5):
442
+ x = wp.array(
443
+ rng.standard_normal(size=(2, 3), dtype=np.float32), dtype=wp.vec3, device=device, requires_grad=True
444
+ )
445
+ gradcheck(check_multi_valued, "check_multi_valued_3d", [x], device)
446
+
447
+
476
448
  def test_mesh_grad(test, device):
477
449
  pos = wp.array(
478
450
  [
@@ -502,19 +474,17 @@ def test_mesh_grad(test, device):
502
474
  c = mesh.points[k]
503
475
  return wp.length(wp.cross(b - a, c - a)) * 0.5
504
476
 
477
+ @wp.kernel
505
478
  def compute_area(mesh_id: wp.uint64, out: wp.array(dtype=wp.float32)):
506
479
  wp.atomic_add(out, 0, compute_triangle_area(mesh_id, wp.tid()))
507
480
 
508
- module = wp.get_module(compute_area.__module__)
509
- kernel = wp.Kernel(func=compute_area, key="compute_area", module=module)
510
-
511
481
  num_tris = int(len(indices) / 3)
512
482
 
513
483
  # compute analytical gradient
514
484
  tape = wp.Tape()
515
485
  output = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
516
486
  with tape:
517
- wp.launch(kernel, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
487
+ wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
518
488
 
519
489
  tape.backward(loss=output)
520
490
 
@@ -531,13 +501,13 @@ def test_mesh_grad(test, device):
531
501
  pos = wp.array(pos_np, dtype=wp.vec3, device=device)
532
502
  mesh = wp.Mesh(points=pos, indices=indices)
533
503
  output.zero_()
534
- wp.launch(kernel, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
504
+ wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
535
505
  f1 = output.numpy()[0]
536
506
  pos_np[i, j] -= 2 * eps
537
507
  pos = wp.array(pos_np, dtype=wp.vec3, device=device)
538
508
  mesh = wp.Mesh(points=pos, indices=indices)
539
509
  output.zero_()
540
- wp.launch(kernel, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
510
+ wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
541
511
  f2 = output.numpy()[0]
542
512
  pos_np[i, j] += eps
543
513
  fd_grad[i, j] = (f1 - f2) / (2 * eps)
@@ -545,27 +515,126 @@ def test_mesh_grad(test, device):
545
515
  assert np.allclose(ad_grad, fd_grad, atol=1e-3)
546
516
 
547
517
 
548
- def register(parent):
549
- devices = get_test_devices()
518
+ @wp.func
519
+ def name_clash(a: float, b: float) -> float:
520
+ return a + b
521
+
522
+
523
+ @wp.func_grad(name_clash)
524
+ def adj_name_clash(a: float, b: float, adj_ret: float):
525
+ # names `adj_a` and `adj_b` must not clash with function args of generated function
526
+ adj_a = 0.0
527
+ adj_b = 0.0
528
+ if a < 0.0:
529
+ adj_a = adj_ret
530
+ if b > 0.0:
531
+ adj_b = adj_ret
532
+
533
+ wp.adjoint[a] += adj_a
534
+ wp.adjoint[b] += adj_b
535
+
536
+
537
+ @wp.kernel
538
+ def name_clash_kernel(
539
+ input_a: wp.array(dtype=float),
540
+ input_b: wp.array(dtype=float),
541
+ output: wp.array(dtype=float),
542
+ ):
543
+ tid = wp.tid()
544
+ output[tid] = name_clash(input_a[tid], input_b[tid])
545
+
546
+
547
+ def test_name_clash(test, device):
548
+ # tests that no name clashes occur when variable names such as `adj_a` are used in custom gradient code
549
+ with wp.ScopedDevice(device):
550
+ input_a = wp.array([1.0, -2.0, 3.0], dtype=wp.float32, requires_grad=True)
551
+ input_b = wp.array([4.0, 5.0, -6.0], dtype=wp.float32, requires_grad=True)
552
+ output = wp.zeros(3, dtype=wp.float32, requires_grad=True)
553
+
554
+ tape = wp.Tape()
555
+ with tape:
556
+ wp.launch(name_clash_kernel, dim=len(input_a), inputs=[input_a, input_b], outputs=[output])
557
+
558
+ tape.backward(grads={output: wp.array(np.ones(len(input_a), dtype=np.float32))})
559
+
560
+ assert_np_equal(input_a.grad.numpy(), np.array([0.0, 1.0, 0.0]))
561
+ assert_np_equal(input_b.grad.numpy(), np.array([1.0, 1.0, 0.0]))
562
+
563
+
564
+ @wp.struct
565
+ class NestedStruct:
566
+ v: wp.vec2
567
+
568
+
569
+ @wp.struct
570
+ class ParentStruct:
571
+ a: float
572
+ n: NestedStruct
573
+
574
+
575
+ @wp.func
576
+ def noop(a: Any):
577
+ pass
578
+
579
+
580
+ @wp.func
581
+ def sum2(v: wp.vec2):
582
+ return v[0] + v[1]
583
+
584
+
585
+ @wp.kernel
586
+ def test_struct_attribute_gradient_kernel(src: wp.array(dtype=float), res: wp.array(dtype=float)):
587
+ tid = wp.tid()
588
+
589
+ p = ParentStruct(src[tid], NestedStruct(wp.vec2(2.0 * src[tid])))
590
+
591
+ # test that we are not losing gradients when accessing attributes
592
+ noop(p.a)
593
+ noop(p.n)
594
+ noop(p.n.v)
595
+
596
+ res[tid] = p.a + sum2(p.n.v)
597
+
598
+
599
+ def test_struct_attribute_gradient(test_case, device):
600
+ src = wp.array([1], dtype=float, requires_grad=True)
601
+ res = wp.empty_like(src)
602
+
603
+ tape = wp.Tape()
604
+ with tape:
605
+ wp.launch(test_struct_attribute_gradient_kernel, dim=1, inputs=[src, res])
606
+
607
+ res.grad.fill_(1.0)
608
+ tape.backward()
609
+
610
+ test_case.assertEqual(src.grad.numpy()[0], 5.0)
611
+
612
+
613
+ devices = get_test_devices()
614
+
550
615
 
551
- class TestGrad(parent):
552
- pass
616
+ class TestGrad(unittest.TestCase):
617
+ pass
553
618
 
554
- # add_function_test(TestGrad, "test_while_loop_grad", test_while_loop_grad, devices=devices)
555
- add_function_test(TestGrad, "test_for_loop_nested_for_grad", test_for_loop_nested_for_grad, devices=devices)
556
- add_function_test(TestGrad, "test_scalar_grad", test_scalar_grad, devices=devices)
557
- add_function_test(TestGrad, "test_for_loop_grad", test_for_loop_grad, devices=devices)
558
- add_function_test(TestGrad, "test_for_loop_graph_grad", test_for_loop_graph_grad, devices=wp.get_cuda_devices())
559
- add_function_test(TestGrad, "test_for_loop_nested_if_grad", test_for_loop_nested_if_grad, devices=devices)
560
- add_function_test(TestGrad, "test_preserve_outputs_grad", test_preserve_outputs_grad, devices=devices)
561
- add_function_test(TestGrad, "test_vector_math_grad", test_vector_math_grad, devices=devices)
562
- add_function_test(TestGrad, "test_matrix_math_grad", test_matrix_math_grad, devices=devices)
563
- add_function_test(TestGrad, "test_3d_math_grad", test_3d_math_grad, devices=devices)
564
- add_function_test(TestGrad, "test_mesh_grad", test_mesh_grad, devices=devices)
565
619
 
566
- return TestGrad
620
+ # add_function_test(TestGrad, "test_while_loop_grad", test_while_loop_grad, devices=devices)
621
+ add_function_test(TestGrad, "test_for_loop_nested_for_grad", test_for_loop_nested_for_grad, devices=devices)
622
+ add_function_test(TestGrad, "test_scalar_grad", test_scalar_grad, devices=devices)
623
+ add_function_test(TestGrad, "test_for_loop_grad", test_for_loop_grad, devices=devices)
624
+ add_function_test(
625
+ TestGrad, "test_for_loop_graph_grad", test_for_loop_graph_grad, devices=get_unique_cuda_test_devices()
626
+ )
627
+ add_function_test(TestGrad, "test_for_loop_nested_if_grad", test_for_loop_nested_if_grad, devices=devices)
628
+ add_function_test(TestGrad, "test_preserve_outputs_grad", test_preserve_outputs_grad, devices=devices)
629
+ add_function_test(TestGrad, "test_vector_math_grad", test_vector_math_grad, devices=devices)
630
+ add_function_test(TestGrad, "test_matrix_math_grad", test_matrix_math_grad, devices=devices)
631
+ add_function_test(TestGrad, "test_3d_math_grad", test_3d_math_grad, devices=devices)
632
+ add_function_test(TestGrad, "test_multi_valued_function_grad", test_multi_valued_function_grad, devices=devices)
633
+ add_function_test(TestGrad, "test_mesh_grad", test_mesh_grad, devices=devices)
634
+ add_function_test(TestGrad, "test_name_clash", test_name_clash, devices=devices)
635
+ add_function_test(TestGrad, "test_struct_attribute_gradient", test_struct_attribute_gradient, devices=devices)
567
636
 
568
637
 
569
638
  if __name__ == "__main__":
570
- c = register(unittest.TestCase)
639
+ wp.build.clear_kernel_cache()
571
640
  unittest.main(verbosity=2, failfast=False)
@@ -0,0 +1,176 @@
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
10
+ import numpy as np
11
+
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
14
+
15
+ wp.init()
16
+
17
+
18
# Atomic counter increment that records, per thread, the counter value it
# observed.  The recorded value lets the replay phase of the backward pass
# reproduce the exact index each thread used in the forward pass.
@wp.func
def reversible_increment(
    counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
):
    # wp.atomic_add returns the value held *before* the addition.
    prev_count = wp.atomic_add(counter, counter_index, value)
    # Remember which slot this thread claimed so replay can reuse it.
    thread_values[tid] = prev_count
    return prev_count
28
+
29
+
30
# Replay version of reversible_increment: instead of performing the atomic
# add again, return the counter value recorded for this thread during the
# forward pass, so the backward pass sees identical indices.
@wp.func_replay(reversible_increment)
def replay_reversible_increment(
    counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
):
    # counter, counter_index and value are unused here; the signature must
    # match reversible_increment for the replay registration to apply.
    return thread_values[tid]
35
+
36
+
37
def test_custom_replay_grad(test, device):
    # Check that a custom replay function yields correct gradients for a
    # kernel whose output indices come from an atomic counter.
    n = 128
    inputs = wp.array(np.arange(n, dtype=np.float32), device=device, requires_grad=True)
    outputs = wp.zeros_like(inputs)
    counter = wp.zeros(1, dtype=wp.int32, device=device)
    thread_ids = wp.zeros(n, dtype=wp.int32, device=device)

    @wp.kernel
    def run_atomic_add(
        input: wp.array(dtype=float),
        counter: wp.array(dtype=int),
        thread_values: wp.array(dtype=int),
        output: wp.array(dtype=float),
    ):
        tid = wp.tid()
        # Each thread claims a unique output slot via the reversible counter.
        slot = reversible_increment(counter, 0, 1, thread_values, tid)
        output[slot] = input[slot] ** 2.0

    tape = wp.Tape()
    with tape:
        wp.launch(
            run_atomic_add, dim=n, inputs=[inputs, counter, thread_ids], outputs=[outputs], device=device
        )

    # Seed the output adjoints with ones and check d(x^2)/dx == 2x.
    seed = wp.array(np.ones(n, dtype=np.float32), device=device)
    tape.backward(grads={outputs: seed})
    assert_np_equal(inputs.grad.numpy(), 2.0 * inputs.numpy(), tol=1e-4)
63
+
64
+
65
@wp.func
def overload_fn(x: float, y: float):
    # Two arbitrary outputs, used to exercise custom gradients for
    # multi-valued functions.
    first = x * 3.0 + y / 3.0
    second = y**2.5
    return first, second
68
+
69
+
70
# Custom gradient for the (float, float) overload of overload_fn.  The
# formulas are deliberately arbitrary (not the true derivatives) so the test
# can verify that this custom adjoint — not the autodiff one — is invoked.
@wp.func_grad(overload_fn)
def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
    wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
    wp.adjoint[y] += y * adj_ret1 * 3.0
74
+
75
+
76
# Simple Warp struct (a scalar plus a 3-vector) used to exercise custom
# gradients for struct-valued function arguments.
@wp.struct
class MyStruct:
    # scalar field; the struct overload of overload_fn raises it to 0.5
    scalar: float
    # 3-component vector consumed by the struct overload of overload_fn
    vec: wp.vec3
80
+
81
+
82
@wp.func
def overload_fn(x: MyStruct):
    # Struct overload: scaled product of the vector components, the vector
    # length, and the square root of the scalar field.
    v = x.vec
    prod = v[0] * v[1] * v[2] * 4.0
    return prod, wp.length(v), x.scalar**0.5
85
+
86
+
87
# Custom gradient for the MyStruct overload.  As with the float overload,
# the formulas are intentionally arbitrary; the test asserts exactly these
# expressions (with unit output adjoints) to prove this adjoint was used.
@wp.func_grad(overload_fn)
def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: float):
    wp.adjoint[x.scalar] += x.scalar * adj_ret0 * 10.0
    wp.adjoint[x.vec][0] += adj_ret0 * x.vec[1] * x.vec[2] * 20.0
    wp.adjoint[x.vec][1] += adj_ret1 * x.vec[0] * x.vec[2] * 30.0
    wp.adjoint[x.vec][2] += adj_ret2 * x.vec[0] * x.vec[1] * 40.0
93
+
94
+
95
@wp.kernel
def run_overload_float_fn(
    xs: wp.array(dtype=float), ys: wp.array(dtype=float), output0: wp.array(dtype=float), output1: wp.array(dtype=float)
):
    # Apply the (float, float) overload element-wise across the input arrays.
    tid = wp.tid()
    a, b = overload_fn(xs[tid], ys[tid])
    output0[tid] = a
    output1[tid] = b
103
+
104
+
105
@wp.kernel
def run_overload_struct_fn(xs: wp.array(dtype=MyStruct), output: wp.array(dtype=float)):
    # Sum the three outputs of the struct overload into one scalar so a
    # single output adjoint of 1.0 reaches every return value.
    tid = wp.tid()
    a, b, c = overload_fn(xs[tid])
    output[tid] = a + b + c
110
+
111
+
112
def test_custom_overload_grad(test, device):
    # Verify that user-supplied gradient overloads are applied for both the
    # (float, float) overload and the MyStruct overload of overload_fn.
    dim = 3

    # --- (float, float) overload ------------------------------------------
    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True)
    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True)
    out0_float = wp.zeros(dim)
    out1_float = wp.zeros(dim)
    tape = wp.Tape()
    with tape:
        wp.launch(run_overload_float_fn, dim=dim, inputs=[xs_float, ys_float], outputs=[out0_float, out1_float])
    ones = np.ones(dim)
    tape.backward(
        grads={
            out0_float: wp.array(ones, dtype=wp.float32),
            out1_float: wp.array(ones, dtype=wp.float32),
        }
    )
    # Expected values follow the (arbitrary) formulas in overload_fn_grad.
    assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
    assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)

    # --- struct overload ---------------------------------------------------
    struct_data = (
        ((1.0, 2.0, 3.0), 4.0),
        ((5.0, 6.0, 7.0), -1.0),
        ((8.0, 9.0, 10.0), 19.0),
    )
    structs = []
    for vec, scalar in struct_data:
        s = MyStruct()
        s.vec = wp.vec3(vec[0], vec[1], vec[2])
        s.scalar = scalar
        structs.append(s)
    xs_struct = wp.array(structs, dtype=MyStruct, requires_grad=True)
    out_struct = wp.zeros(dim)
    tape = wp.Tape()
    with tape:
        wp.launch(run_overload_struct_fn, dim=dim, inputs=[xs_struct], outputs=[out_struct])
    tape.backward(grads={out_struct: wp.array(np.ones(dim), dtype=wp.float32)})
    xs_struct_np = xs_struct.numpy()
    struct_grads = xs_struct.grad.numpy()
    # Expected values mirror the custom struct adjoint formulas
    # (g[0] is the scalar field, g[1] the vec3 field).
    # fmt: off
    assert_np_equal(
        np.array([g[0] for g in struct_grads]),
        np.array([g[0] * 10.0 for g in xs_struct_np]))
    assert_np_equal(
        np.array([g[1][0] for g in struct_grads]),
        np.array([g[1][1] * g[1][2] * 20.0 for g in xs_struct_np]))
    assert_np_equal(
        np.array([g[1][1] for g in struct_grads]),
        np.array([g[1][0] * g[1][2] * 30.0 for g in xs_struct_np]))
    assert_np_equal(
        np.array([g[1][2] for g in struct_grads]),
        np.array([g[1][0] * g[1][1] * 40.0 for g in xs_struct_np]))
    # fmt: on
161
+
162
+
163
devices = get_test_devices()


class TestGradCustoms(unittest.TestCase):
    # Test methods are attached dynamically below via add_function_test.
    pass


add_function_test(TestGradCustoms, "test_custom_replay_grad", test_custom_replay_grad, devices=devices)
add_function_test(TestGradCustoms, "test_custom_overload_grad", test_custom_overload_grad, devices=devices)
172
+
173
+
174
if __name__ == "__main__":
    # Start from a clean kernel cache so stale compiled modules cannot mask
    # codegen/gradient regressions in this run.
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=False)